你的 GAN 是不是少了什麼?
來自專欄 NLPer 的成長之路17 人贊了文章
Goodfellow 給一篇剛上 arXiv 的 The relativistic discriminator: a key element missing from standard GAN 文章點了贊,這篇文章就是一篇閱讀筆記以及附帶簡單的代碼實現。
Intuition
文章說標準的 GAN(SGAN) 在 Generator(下文用 G 代替) 生成的樣本越來越逼真的時候 Discriminator 缺了個東西。什麼呢?相對的概念,先來看下面這張圖片:
可以看到,我們的 real data 是麵包,fake data 是柯基(好萌啊2333),? 越大則說明是麵包的概率越大。圖中列出了三種情況:
- 真的麵包,真的柯基:這種情況二者的區別很明顯,因而?
- 真的麵包,柯基屁股(很像麵包):這種情況區別就沒有第一種那麼明顯了,因而 有所下降
- 像狗的麵包,真的柯基:和第二種情況類似,相對的區別度降低了, 同樣有所下降
有了一個模糊的印象之後,我們展開說說這個相對究竟是什麼。
Arguments of RaGAN
相對,一言以蔽之:Discriminator(下文中用 D 代替)衡量樣本真實性的時候,應該要同時利用 real data 和 fake data,衡量的由絕對的真假變成相對的為真或為假的概率。 作者從三個方面論述了其觀點:
Priori Knowledge
先驗知識的利用,即每次我們餵給 Discriminator(下文中用 D 代替) 的樣本中,基本上是一半 real data,一半 fake(Generator generated)。也就是說,不知道這個前提的話,那麼如果 G 生成的樣本(比如說圖片)能夠以假亂真的話,那麼 D 會認為所有的樣本都是 real 的,而如果知道這個前提,那麼當 fake 比 real 更 real 的時候,discrinimator 應該給 real samples 打低分(認為他們是 fake) 而不是認為所有的 samples 都是 real。因為在看到了更 real 的 fake samples 之後,相對地,利用先驗知識,我們會認為不那麼 real 的 samples(比如狗麵包)是 fake 的。而 SGAN 的訓練中並沒有利用到這一部分先驗知識。
Divergence Optimization
我們知道,SGAN 在訓練 D 的時候事實上是在 minimize 生成器分布和真實數據分布的 JS-Divergence,而 JS-Divergence 在兩個分布相同時達到最小值,其表現是 D 無法區分 real data 和 fake data,認為其為真的概率均為 0.5;JS-Divergence 在兩個分布差異較大的時候較大,即 D 認為 real data 為真的概率為 1,而認為 fake data 為真的概率為 0。理想的訓練過程是如下 (C) 圖所示:
但事實上,我們在訓練的時候一味地希望生成的圖片足夠逼真,即一直在將 D(fake_data)
向 1
推,而不管 D(real_data)
。這有做是達不到最小值的,這樣的 optimization 過程是存在問題的,WGAN 論文中似乎也有提到這一點。
Gradient
WGAN 等一系列對 GAN 做了改進的 GAN 被稱為 IPM-based GAN,作者將其梯度和標準的 GAN 進行了對比:
下面是 IPM-based GAN 的梯度:
當下面這些條件滿足的時候,兩式相等:
- 在訓練 D 的時候,?
- 訓練 G 的時候,?
- ? ,其中 ? 是一類實值函數(這個一般都能滿足)
考慮到 IPM-based GAN 相對於 SGAN 具有更好的穩定性,可以推斷,如果將 SGAN 推向 IPM-based GAN,能夠提高其穩定性。怎麼才能達到這個轉變呢?如果我們認為 D 足夠好,即能夠在訓練 G 時做到認為 ? (這是一個比較強的假設,但是在訓練一開始是能夠滿足的),而在訓練 D 的時候 ?,如開頭圖片的 (b) 所示那種情況,那麼唯一缺少的條件就是 ?。但是在 SGAN 中,? 只跟 real data 有關,也即是絕對的真假,如果 G 生成的樣本能夠影響 ? ,讓它變得不那麼真實(即相對應地變小),也即相對的真假。這個時候, 才可能變成 0。總結起來一句話就是,在 ? 提升的時候 ?相應地減少(這是一種相對的變化),這樣 GAN 的訓練才能更加穩定。
Methods
Relativistic GAN
如何把這種相對為真和為假的概率考慮進去?很簡單,只要把 logits 做一個簡單相減,即:
這裡用 sigmoid
函數將 logits 轉化成概率,然後再取對數;我們可以很容易地把這個式子進行泛化,即用更一般的函數(不一定是似然函數)來替換它。
Relativistic average GAN
不過這個時候,我們發現 GAN 中 D 的功能已經悄然發生的改變,由原來的:衡量輸入數據為真的概率,變成了輸入的數據與其對立類型隨機的一個樣本(如果輸入為 real data,那麼衡量其比 fake data 更像真的概率,反之亦然)相比更像真的的概率。那麼,更一般地,我們利用對立類型的平均真實度來作為更可靠的參照對象。
所謂平均真實度,就是對整個(或者一個 mini-batch) real data 和 fake data 求其 ? 的數學期望,這樣的估計能夠更加整體的反映訓練在這一時刻生成器生成的 fake data 的真實程度。同樣,我們可以輕易地將其泛化到別的函數之上。最後整個演算法流程如下:
Toy Demo
原作給了很多例子來說明 RaGAN 的效果,並且也提供了相應的代碼。這裡我就拿 MNIST 做一個小小的 Demo,在原版的 GAN 的 demo 上做了一個很小的改動,就是把 loss 計算中的 logits 進行對應的相減。
Generator
# Generator NetZ = tf.placeholder(tf.float32, shape=[None, 100])?G_W1 = tf.Variable(xavier_init([100, 128]), name="G_W1")G_W2 = tf.Variable(xavier_init([128, 784]), name="G_W2")?G_b1 = tf.Variable(tf.zeros([128]), name="G_b1")G_b2 = tf.Variable(tf.zeros([784]), name="G_21")?theta_G = [G_W1, G_W2, G_b1, G_b2]?def generator(z): G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 G_prob = tf.nn.sigmoid(G_log_prob)? return G_prob
Discriminator
# Discriminator NetD_W1 = tf.Variable(xavier_init([784, 128]), name="D_W1")D_b1 = tf.Variable(tf.zeros(shape=[128]), name="D_b1")?D_W2 = tf.Variable(xavier_init([128, 1]), name="D_W2")D_b2 = tf.Variable(tf.zeros(shape=[1]), name="D_b2")?theta_D = [D_W1, D_W2, D_b1, D_b2]?def discriminator(x): D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) D_logit = tf.matmul(D_h1, D_W2) + D_b2 D_prob = tf.nn.sigmoid(D_logit)? return D_prob, D_logit
RaGAN
# Ra GAN simple version mnist = input_data.read_data_sets(MNIST_data, one_hot=True)X = tf.placeholder(tf.float32, shape=[None, 784], name="X")??G_sample = generator(Z)D_real, D_logit_real = discriminator(X)D_fake, D_logit_fake = discriminator(G_sample)# discriminator # log( real_logits - fake_logits)D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real - D_logit_fake, labels=tf.ones_like(D_logit_fake)))?# generator loss# log(fake_logits - real_logits )G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake - D_logit_real, labels=tf.ones_like(D_logit_fake)))?# optimizerD_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
Train
def sample_Z(m, n): return np.random.uniform(-1., 1., size=[m, n])??steps = 1000001mb_size = 128Z_dim = 100??def plot(samples): fig = plt.figure(figsize=(4, 4)) gs = gridspec.GridSpec(4, 4) gs.update(wspace=0.05, hspace=0.05)? for i, sample in enumerate(samples): ax = plt.subplot(gs[i]) plt.axis(off) ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_aspect(equal) plt.imshow(sample.reshape(28, 28), cmap=Greys_r)? plt.show(block=False) return fig??with tf.Session() as sess: sess.run(tf.global_variables_initializer()) num = 0 for i in range(steps): X_mb, _ = mnist.train.next_batch(mb_size) _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={ X: X_mb, Z: sample_Z(mb_size, Z_dim) }) # 注意,這裡的 feed_dict 和原來不一樣了 _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={ X: X_mb, Z: sample_Z(mb_size, Z_dim) }) # 下面的 vanilla GAN 的 loss # _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})? if i % 100 == 0: print("Step %d" % i) print("G loss: %f" % G_loss_curr) print("D loss: %f" % D_loss_curr) print()? if i % 1000 == 0: samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)}) fig = plot(samples) fname = {}.png.format(str(num).zfill(5)) plt.savefig(fname, bbox_inches=tight) print(saved image + fname) num += 1 # plt.clf() plt.close(fig)
Summary
GAN 的難訓練是臭名昭著了,作者通過考慮引入相對這個概念來使得訓練過程變得更加穩定。這類見微知著的工作,還有前不久的 IndRNN 我覺得也算是一類,都是作者對某一類盛行的模型進行細緻地分析之後提出了很簡單卻又不平凡的改動,希望能夠有更多這樣的工作出現!
推薦閱讀: