什麼是生成式對抗網路 GAN

生成式對抗網路有兩部分組成

生成者,generator ,G同學,就像一個造假鈔的人

和鑒定者 discriminator ,D同學,就像驗鈔員

一開始G同學不知道怎麼印錢,印出來的錢是這樣的

但是D同學也不怎麼樣,也只能勉強分別出來這是假鈔

訓練的過程是G同學和D同學一起成長的過程,

就像周伯通聯繫左右互搏,兩者共同進步,一起成長.最後G同學做出來的假鈔越來越真,而D同學的驗鈔水平也越來越強

https://www.zhihu.com/video/936929994686054400

綠色的曲線是真實數據的分布圖,是平均值為4,標準差為0.5的高斯分布,就是我們剛才說的真錢的分布

紅色曲線是生產者生產出來的錢的分布

藍色的是驗鈔員,認為的所有的真錢的分布

可以看到三條曲線分布越來越相近

G同學成了造假高手 D同學成了驗鈔專家

Data sample 是真實的數據,Generator Sample 是G同學 generator 生產出來的數據

D同學 discriminator,它的輸入是"真錢",和"假錢"

D 同學的輸出是一個概率,數值在0到1 之間

G同學的輸入是隨機數據, noise

下面我們來看一下code

首先是真實數據,這是一個均值為4,標準差為0.5的高斯分布

class DataDistribution(object): def __init__(self): self.mu = 4 self.sigma = 0.5 def sample(self, N): samples = np.random.normal(self.mu, self.sigma, N) samples.sort() return samples

generator非常簡單,就是一個線性函數加上非線性函數softplus, 再加上一個線性函數

h_dim 是線性函數輸出的size,最後的輸出的維度是[batch_size * 1]

def linear(input, output_dim, scope=None, stddev=1.0): with tf.variable_scope(scope or linear): w = tf.get_variable( w, [input.get_shape()[1], output_dim], initializer=tf.random_normal_initializer(stddev=stddev) ) b = tf.get_variable( b, [output_dim], initializer=tf.constant_initializer(0.0) ) return tf.matmul(input, w) + bdef generator(input, h_dim): h0 = tf.nn.softplus(linear(input, h_dim, g0)) h1 = linear(h0, 1, g1) return h1

discriminator 稍微複雜一些,用到了線性函數和非線性函數的結合,同時還加入了minibatch, 這樣就可以一批一起看,而不是一個一個點去看,最後輸出的維度是[batch_size * 1]

def discriminator(input, h_dim, minibatch_layer=True): h0 = tf.nn.relu(linear(input, h_dim * 2, d0)) h1 = tf.nn.relu(linear(h0, h_dim * 2, d1)) # without the minibatch layer, the discriminator needs an additional layer # to have enough capacity to separate the two distributions correctly h2 = minibatch(h1) h3 = tf.sigmoid(linear(h2, 1, scope=d3)) return h3def minibatch(input, num_kernels=5, kernel_dim=3): x = linear(input, num_kernels * kernel_dim, scope=minibatch, stddev=0.02) activation = tf.reshape(x, (-1, num_kernels, kernel_dim)) diffs = tf.expand_dims(activation, 3) - tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0) abs_diffs = tf.reduce_sum(tf.abs(diffs), 2) minibatch_features = tf.reduce_sum(tf.exp(-abs_diffs), 2) return tf.concat([input, minibatch_features], 1)

完整的代碼

AYLIEN/gan-intro


推薦閱讀:

機器學習總結-K近鄰(KNN)演算法
直覺版:CNN解讀part1
機器學習開放課程(一):使用Pandas探索數據分析
機器學習總結-Logistic 回歸

TAG:生成對抗網路GAN | TensorFlow | 機器學習 |