GAN、WGAN通俗基本原理
來自專欄機器學習和chatbot備忘錄1 人贊了文章
一、基本的GAN
GAN 的主要思想:通過神經網路去模擬現有數據的分布(極大似然估計),從而達到生成的目的。優化方法:通過判別器不斷優化。
給定某分布 ,利用 模擬 分布。
從中sample出 數據,優化 使得L最大:
對等式兩邊取log,個人認為上式少了從Pdata中Sample出n個樣本的概率,下圖中的約等式將這一部分包含了進去,然後再加上Pdata的信息熵(下式中為減去,是因為把符號提到了前面,變成了減號,這一項和 無關,不影響結果),這一項看起來有點像KL散度了,或者叫相對熵。
上式來說極大似然函數很難去優化,怎麼樣去優化這個 ,GAN解決了這個問題。利用一個判別器去衡量PG生成的數據和真實數據之間的差距,然後不斷優化判別器和生成器,直至判別器判別不出來。
其中V(G,D)為
V(G,D)到達最大的意思就是說,把來自真實分布的數據和來自生成器生成的數據分辨開,式子中的對應關係為,前一項表示來自Pdata的數據,D(x)的值要大,後一項表示來自Pg的數據,D(x)要小,這樣才能到達最大,maxV(G,D)是用來優化判別器的。而對於給定的不同的G(生成器),每個G都會有一個maxV(G,D),然後再在不同的G中,選擇一個使得maxV(G,D)最小的那個G,這一步用來優化生成器,這是上圖G = arg min maxV(G,D)的含義。理解了的話,下圖G1,G2,G3應該選擇第三個。
根據上面的思想,優化就分為三步 1.優化判別器,2.優化生成器,3.repeat
1.優化判別器:對於給定生成器下,求解最優的D,即求解maxV(G,D)
上式,只有一個未知數D,上式對D求導,可得最大時,D為
再把D帶回V(G,D)得到maxV(G,D),下面就是關於G的式子:
整理,得到關於KL散度的式子:
由於JS散度的定義是P,Q和兩者平均值的散度,所以上式可以化成
JS散度的取值範圍在(0,log2)之間,所以V( , )取值在(-2log2,0)之間。
2.優化生成器
上步求解出了判別器最優的值,並帶回原式,接下來V( , )只和生成器有關,優化生成器,也就是求解使得真實分布和生成的分布差異最小的生成器。
因為V( , )關於G的函數,求最小值,用梯度下降法:
其中L(G)為V( , )。V( , )也就是js散度,更新G去減小js散度。
在實際操作中:
(1).初始化 ,
(2).每次迭代時:
從 中sample出m個樣本{},從中sample出m個noise樣本,輸入生成器中獲得{},
更新判別器的參數,使判別結果最大化:
V=
對上式求導,利用梯度上升更新。上述步驟是為了更新判別器,使判別器效果更好。需要重複k次,為了可以使得V收斂,這樣可以保證找到maxV(G,D),。
更新生成器:
V=
更新生成器和前面一項沒有關係,只和後面一項有關,後面一項log(1-x)的圖像在x接近0的時候,趨近於0。不利於開始的更新,所以換成-log(x),圖像的趨勢都是一致的,在x接近0時,梯度很大,有利於更新。所以優化目標變成了下面的式子:
V =
對上式求導,利用梯度下降去更新 。
基本GAN存在的問題:
1.兩個分布很遠時JS散度都是log2,沒有漸變的過程,不利於梯度更新
2.判別器訓練的太強,會使得JS散度都是log2,參照第一條,不利於更新,可能需要加點雜訊進去才可以。
3.只能學到一部分的分布,不能學到全部的分布。
二、WGAN
WGAN的優化,沒有使用JS散度
WGAN 的優化公式為:
限定判別器的範圍(若是不給上界的話,優化判別器時會使得來自 的數據分的很開,不利於優化。)
判別器要符合1-Lipschitz函數
Lipschitz Function:
||f(x1)-f(x2)|| K||x1-x2||,k=1為1-Lipschitz
就是輸出的變化給了一個上界。
實際操作是加上weight clip限值權重的範圍在(-c,c)之間。
相比較於基本的GAN 更新的地方在去掉了所有的log,做了cilp:
三、Improve GAN
improve GAN 對WGAN 進行了優化,WGAN的weight clip會出現方正邊角的問題,就像1正則。
Improve GAN 的優化目標為:
將1-Lipschitz變成了懲罰項,會使得D(x)關於所有x的斜率絕對值之和都會接近1,這樣才能保證優化時,懲罰項為0,優化目標最大。而其中x-penalty在Pdata和Pg的連線之間。
推薦閱讀:
TAG:深度學習DeepLearning | 機器學習 | 生成對抗網路GAN |