如何形象又有趣的講解對抗神經網路(GAN)是什麼?

自從Goodfellow在2014年提出了對抗神經網路後在這個這個領域十分火熱,也經常刷知乎的時候看到相關文章,搜了一些資料後對其大致意思明白了,但其具體實現是如何的不太清楚,比如生成照片的網路G是如何構建的?有對這個方面很了解的知友告知一下嗎?


最形象的應該就是公式推導吧,認真看完,應該理解會很清楚哈~~下面是自己理解的推導,希望對你有用哈~~

生成對抗網路分為:1.生成模型,2.鑒別模型。其中,生成模型從無到有不斷地生成數據,而鑒別模型不斷鑒別生成器生成的模型;二者不斷對抗,生成模型拚命生成不讓鑒別模型識別出來的數據,鑒別模型拚命鑒別生成模型生成的數據;二者不斷成長,得到最好的生成模型鑒別模型

具體推倒如下:

首先是符號說明, 注意GAN,主要學習的是數據的分布,最終得到的是兩個一樣的數據分布。

定義數據分布,generator,discriminator等輸出

定義鑒別模型:

當鑒別模型輸出D(x)為1時,即可以輕鬆判別數據,此時上式取值最大。

當鑒別模型輸出D(G(z))為0時,即鑒別模型輕鬆地鑒別出生成模型的數據,此時上式取值最大。

故而為了鑒別模型越來越好,定義以下目標函數:

很顯然,最好的鑒別模型是使得V(G,D)最大的,即:

當鑒別模型取最好的時候,最好的生成模型即使得目標函數最小的,如下:

然後這個問題就轉變成了最大最小問題:

這個問題真的有最優解嗎?下面證明這個問題。

先證明有最優鑒別模型:

得到最優鑒別模型是

下面我們再來考慮一下GAN最終的目的是,得到生成模型可以生成非常逼真的數據,也就是說和真實數據分布一樣的數據,此時鑒別模型的輸出為:

其中,數據分布一樣

當DG輸出為0.5時,說明鑒別模型已經完全分不清真實數據和GAN生成的數據了,此時就是得到了最優生成模型了。

下面證明,生成模型存在:

充分性:

必要性:

上式最終可以轉化為KL散度,如下:

KL散度永遠大於等於0,可以知道目標函數最終最優值為-log4。

以上即是,GAN證明的推倒。

參考文獻:

Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.

An Annotated Proof of Generative Adversarial Networks with Implementation Notes


其實 GAN 網路的巧妙在於其設計思維,而技術上是對現有演算法的組合,沒什麼神秘的。既然題主對其大致意思已經了解了,那我就舉一個構建實例。GAN 網路主要由兩個網路合成。

生成網路

輸入為隨機數,輸出為生成數據。比如說,輸入一個一維隨機數,輸出一張 28x28 (784 維) 的圖片 (MNIST)。

網路實現用最 vanilla 的多層神經網路即可。記得不宜超過三層,否則梯度消失梯度爆炸的問題你懂的。中間的激活函數用當下最時興的 Relu 就好。輸出層需要使用其他激活函數,目的是為了生成數據的取值範圍與真實數據相似,具體使用什麼函數視情況而定。下面給出一個可能的實現方案:

(?, 1) 
ightarrow (1, 128) 
ightarrow relu 
ightarrow (128, 128) 
ightarrow relu 
ightarrow (128, 784) 
ightarrow tanh

判別網路

現在,我們把生成網路生成的數據稱為假數據,對應的,來自真實數據集的數據稱為真數據。判別網路輸入為數據(真或假),輸出一個判別概率。需注意的是,這裡判別的是圖像的真偽,而非圖像的類別。還以 MNIST 為例。輸入一個圖片後,我們並不要認圖片上畫的是啥數字,而是判別圖像到底來自於真實數據集,還是生成網路的胡亂合成。所以輸出一個一維條件概率(伯努利分布的概率參數)就好了。

網路實現同樣可用最基本的多層神經網路。下面給出一個可能的方案。

(?, 784) 
ightarrow (784, 128) 
ightarrow relu 
ightarrow (128, 128) 
ightarrow relu 
ightarrow (128, 1) 
ightarrow sigmoid

loss 函數

既然有倆網路,那麼我們就有倆 loss 函數對應之。生成網路用

mathcal L_{G} = mathbb H(1, D(G(mathbf z)))

G 代表 Generative; D 代表 Discriminative;mathbb H 代表交叉熵,這也是常用演算法之一,如果題主對於其意義有何不解,網上有大把資料。 mathbf z 是輸入生成網路的隨機數,那麼 G(mathbf z) 就是生成網路合成的假數據,D(G(mathbf z)) 則是對這個假數據的判別概率。這個 loss 用大白話來說,我生成網路的目標就是要你判別網路覺得我合成的數據是真的!(概率 1

判別網路的 loss 函數用

mathcal L_{D} = mathbb H(1, D(mathbf x)) + mathbb H(0, D(G(mathbf z)))

mathbf x 為真實數據。這個 loss 說的是,我判別網路就是要將真數據歸為真,假數據歸為假,既不想放過一個假數據,也不想錯殺一個真數據。

可見,這兩個 loss 的定義非常直覺化。對抗這個稱呼就是這麼來的。

訓練:

訓練我們用兩步走,先優化一次 mathcal L_{D} 再優化一次 mathcal L_{G} ,如此往複直到題主滿意。兩步走的訓練演算法與 Goodfellow 最初論文中的演算法不太一樣,不過結果是基本「等價」的。

超參:

GAN 網路對超參的敏感是眾所周知的。上面提供的超參絕對不能保證能生成令人滿意的結果。我只是拍腦袋想的。。。但是,題主絕對可以得到一些啟發性的結果讓自己對 GAN 網路有進一步的了解。

實現

用框架如 Tensorflow 之類顯然是最快的實現方法。不用自己算 analytic 導數,不用糾結矩陣的維度,還有什麼比這更舒心的。


先形象的解釋一下GAN是什麼。。。

我覺得很像。。。考試,確切地說是語文考試 = =

就好比,一個要考語文的學生,他在不斷練習解答閱讀題的過程中,不斷揣摩閱讀題的套路,希望在考試中用同樣的套路拿到高分。

當他參加語文考試後,批卷老師會將這名學生的答案與標準答案對照,以此來給出分數,此外,一個好的批卷老師,本身也要對標準答案有足夠理解。

如果這個學生想得到最高分的話,那他要做的就是,和答案一模一樣 = =

//換成GAN解釋一下

G(學生) 通過輸入 z (漢字,我們假設這是個已經會組織漢字的學生),將 z (漢字)不斷變換 輸出儘可能接近 real data (標準答案)的 sample(考生的答案),有一個 D (批卷老師)來判斷夠不夠接近,然後D(批卷老師)也不斷學習real data(標準答案),加深對real data(標準答案)的理解,給出更公平的評分

//額外需要補充的

有點不同的地方,在GAN的訓練中,D 一直在給G批卷子。。。

"比如生成照片的網路G是如何構建的?"

G 有個數字輸入z(好比參加語文考試要用漢字寫,目標圖片在計算機中用數值表示,G也要有個數字基礎),G本身是個神經網路(可以是MLP可以是VAE可以是balabala。。。反正G要有學習能力啊),G大膽的嘗試著生成一發圖像,並且交給了D,D 比較了一下 G生成的圖像 和 真實圖像的差別(呃計算機語言的話,叫距離),給出一個評分並返回(Backprop )給G,而這個返回的過程,就是G學習的過程

然後G生成了第二發圖像,交給D,D研究下和真實數據的差距,嗯變小了但還是有,給個比上回好點的分數,返回給G

訓練了n萬次後,G和D都變強了。。。


GAN的生成器G是通過辨別器D的指導下,通過迭代,變成可以把隨機分布變成樣本分布的生成器。

首先判別器D開始訓練,假如訓練到最優,它的表達式應該是訓練樣本概率分布/(生成樣本概率分布+訓練樣本概率分布),之後訓練G開始訓練,它的最優解當且僅當G生成樣本的概率分布=訓練樣本的概率分布。

理論上一次迭代就可以,工程實踐應該是不停的交替迭代。


作者在文章裡面給出一個有趣(同時也不很清晰)的比方:

The generative model can be thought of as analogous to a team of counterfeiters,

trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistiguishable from the genuine articles.

翻譯過來就是:生成模型類似於一群貨幣造假者,他們生產假幣並在沒有發現的情況下使用假幣。而判別模型類似於警察,他們則會去檢驗假幣。 這兩者進行競技的過程使得造假者不斷提高造假技術,同時警察不斷提高偵查(假幣)的技術。這個不斷循環的過程,直到警察無法辨認出假幣而結束。

詳細論文正在看,希望以後能再回來作進一步完善。(一個直接的體會是,第二手資料適合做了解,第一手資料才適合做深入的閱讀。)


看一個代碼吧。


推薦閱讀:

amazon picking challenge(APC)2016中識別和運動規劃的主流演算法是什麼?
48個深度學習相關的平台和開源工具包,一定有很多你不知道的?
如何看待Baidu的Deep Speech 2語音識別系統入選MIT科技評論十大突破?
對於圖像識別和語音識別,其各自的深度學習框架的實現差異大嗎,假如理解了其中之一,轉向另一邊容易嗎?
現有的語音識別技術能否達到自動輸出嚴式國際音標的水平?

TAG:人工智慧 | 機器學習 | 深度學習DeepLearning | 生成對抗網路GAN |