GAN、WGAN通俗基本原理

GAN、WGAN通俗基本原理

來自專欄機器學習和chatbot備忘錄1 人贊了文章

一、基本的GAN

GAN 的主要思想:通過神經網路去模擬現有數據的分布(極大似然估計),從而達到生成的目的。優化方法:通過判別器不斷優化。

給定某分布 P_{data} ,利用 P_{G}(x;	heta) 模擬 P_{data} 分布。

P_{data}中sample出 x^{1},x^{2},x^{3}....x^{n} 數據,優化 	heta 使得L最大:

對等式兩邊取log,個人認為上式少了從Pdata中Sample出n個樣本的概率,下圖中的約等式將這一部分包含了進去,然後再加上Pdata的信息熵(下式中為減去,是因為把符號提到了前面,變成了減號,這一項和 	heta 無關,不影響結果),這一項看起來有點像KL散度了,或者叫相對熵。

上式來說極大似然函數很難去優化,怎麼樣去優化這個 P_{G} ,GAN解決了這個問題。利用一個判別器去衡量PG生成的數據和真實數據之間的差距,然後不斷優化判別器和生成器,直至判別器判別不出來。

其中V(G,D)為

V(G,D)到達最大的意思就是說,把來自真實分布的數據和來自生成器生成的數據分辨開,式子中的對應關係為,前一項表示來自Pdata的數據,D(x)的值要大,後一項表示來自Pg的數據,D(x)要小,這樣才能到達最大,maxV(G,D)是用來優化判別器的。而對於給定的不同的G(生成器),每個G都會有一個maxV(G,D),然後再在不同的G中,選擇一個使得maxV(G,D)最小的那個G,這一步用來優化生成器,這是上圖G = arg min maxV(G,D)的含義。理解了的話,下圖G1,G2,G3應該選擇第三個。

根據上面的思想,優化就分為三步 1.優化判別器,2.優化生成器,3.repeat

1.優化判別器:對於給定生成器下,求解最優的D,即求解maxV(G,D)

上式,只有一個未知數D,上式對D求導,可得最大時,D為

再把D帶回V(G,D)得到maxV(G,D),下面就是關於G的式子:

整理,得到關於KL散度的式子:

由於JS散度的定義是P,Q和兩者平均值的散度,所以上式可以化成

JS散度的取值範圍在(0,log2)之間,所以V( G ,D^{*} )取值在(-2log2,0)之間。

2.優化生成器

上步求解出了判別器最優的值,並帶回原式,接下來V( G ,D^{*} )只和生成器有關優化生成器,也就是求解使得真實分布和生成的分布差異最小的生成器。

因為V( G ,D^{*} )關於G的函數,求最小值,用梯度下降法:

其中L(G)為V( G ,D^{*} )。V( G ,D^{*} )也就是js散度,更新G去減小js散度。

在實際操作中

(1).初始化 	heta_{d} ,	heta_{g}

(2).每次迭代時:

P_{data}(x) 中sample出m個樣本{x^{1},x^{2},x^{3}....x^{m}},從P_{prior}(z)中sample出m個noise樣本,輸入生成器中獲得{	ilde{x}^{1},	ilde{x}^{2},	ilde{x}^{3} .... 	ilde{x}^{m}}, 	ilde{x} = G(z;	heta)

更新判別器的參數	heta_{d},使判別結果最大化:

V= frac{1}{m}sum_{i=1}^{m}{log(D(x^i))}+frac{1}{m}sum_{i=1}^{m}{log(1-D(	ilde{x}^i))}

對上式求導,利用梯度上升更新	heta_{d}上述步驟是為了更新判別器,使判別器效果更好。需要重複k次,為了可以使得V收斂,這樣可以保證找到maxV(G,D),。

更新生成器

V= frac{1}{m}sum_{i=1}^{m}{log(D(x^i))}+frac{1}{m}sum_{i=1}^{m}{log(1-D(G(z^i))}

更新生成器和前面一項沒有關係,只和後面一項有關,後面一項log(1-x)的圖像在x接近0的時候,趨近於0。不利於開始的更新,所以換成-log(x),圖像的趨勢都是一致的,在x接近0時,梯度很大,有利於更新。所以優化目標變成了下面的式子:

V = -frac{1}{m}sum_{i=1}^{m}{log(D(G(z^i))}

對上式求導,利用梯度下降去更新 	heta_{g}

基本GAN存在的問題

1.兩個分布很遠時JS散度都是log2,沒有漸變的過程,不利於梯度更新

2.判別器訓練的太強,會使得JS散度都是log2,參照第一條,不利於更新,可能需要加點雜訊進去才可以。

3.只能學到一部分的分布,不能學到全部的分布。


二、WGAN

WGAN的優化,沒有使用JS散度

WGAN 的優化公式為:

限定判別器的範圍(若是不給上界的話,優化判別器時會使得來自 P_{data}和P_{G} 的數據分的很開,不利於優化。)

判別器要符合1-Lipschitz函數

Lipschitz Function:

||f(x1)-f(x2)|| leq K||x1-x2||,k=1為1-Lipschitz

就是輸出的變化給了一個上界。

實際操作是加上weight clip限值權重的範圍在(-c,c)之間。

相比較於基本的GAN 更新的地方在去掉了所有的log,做了cilp


三、Improve GAN

improve GAN 對WGAN 進行了優化,WGAN的weight clip會出現方正邊角的問題,就像1正則。

Improve GAN 的優化目標為:

將1-Lipschitz變成了懲罰項,會使得D(x)關於所有x的斜率絕對值之和都會接近1,這樣才能保證優化時,懲罰項為0,優化目標最大。而其中x-penalty在Pdata和Pg的連線之間。

推薦閱讀:

TAG:深度學習DeepLearning | 機器學習 | 生成對抗網路GAN |