手把手教你理解和實現生成式對抗神經網路(GAN)

手把手教你理解和實現生成式對抗神經網路(GAN)

來自專欄景略集智19 人贊了文章

生成式對抗神經網路(GAN)是目前深度學習研究中最活躍的領域之一,原因正是其能夠生成非常逼真的合成結果。在本文,我們會學習 GAN 的工作原理,然後用 TensorFlow 實現一個簡單的 GAN。文章結構如下

  • GAN 的基本理念和工作原理
  • 實現一個 GAN 模型,能從一個簡單分布中生成數據
  • 可視化和分析 GAN 的各個方面,更好地理解其背後原理

本文代碼地址見文末。

GAN 工作原理

GAN 的基本理念其實非常簡單,其核心由兩個目標互相衝突的神經網路組成,這兩個網路會以越來越複雜的方法來「矇騙」對方。這種情況可以理解為博弈論中的極大極小博弈樹

我們以一個偽造貨幣的例子形象地解釋 GAN 的工作原理。

在這個過程中,我們想像有兩類人:警察和罪犯。我們看看他們的之間互相衝突的目標:

  • 罪犯的目標:他的主要目標就是想出偽造貨幣的複雜方法,從而讓警察無法區分假幣和真幣。
  • 警察的目標:他的主要目標就是想出辨別貨幣的複雜方法,這樣就能夠區分假幣和真幣。

隨著這個過程不斷繼續,警察會想出越來越複雜的技術來鑒別假幣,罪犯也會想出越來越複雜的技術來偽造貨幣。這就是 GAN 中「對抗過程」的基本理念。

GAN 充分利用「對抗過程」訓練兩個神經網路,這兩個網路會互相博弈直至達到一種理想的平衡狀態,我們這個例子中的警察和罪犯就相當於這兩個神經網路。其中一個神經網路叫做生成器網路 G(Z),它會使用輸入隨機雜訊數據,生成和已有數據集非常接近的數據;另一個神經網路叫鑒別器網路 D(X),它會以生成的數據作為輸入,嘗試鑒別出哪些是生成的數據,哪些是真實數據。鑒別器的核心是實現二元分類,輸出的結果是輸入數據來自真實數據集(和合成數據或虛假數據相對)的概率。

整個過程的目標函數從正式意義上可以寫為:

我們在前面所說的 GAN 最終能達到一種理想的平衡狀態,是指生成器應該能模擬真實的數據,鑒別器輸出的概率應該為 0.5, 即生成的數據和真實數據一致。也就是說,它不確定來自生成器的新數據是真實還是虛假,二者的概率相等。

你可能很好奇,為什麼我們會需要這麼複雜的學習過程?學習這樣一種模型的好處是什麼?

GAN 的創造者 Ian Goodfellow 曾提到過,GAN 以及所有生成方法的背後靈感源自著名物理學家理查德·費曼的一句名言:

我不能創造的東西,我就不理解(What I cannot create, I do not understand)。

當然 Ian Googfellow 這裡指的是機器:「AI不能創造的東西,它就不理解。」

這是息息相關的,因為如果我們能從模型中生成真實的數據分布,那麼就意味著我們能知道理解該模型的方方面面的信息。大多時候,這些真實的數據分布包含數百萬張圖像,我們可以用具有數千個參數的模型生成它們,模型的這些參數能夠捕捉給定圖像的本質。

在實際生活中,GAN 有很多用途,我們在後面會講到。

實現 GAN

在這部分,我們會生成一個非常簡單的數據分布,試著學習一個生成器函數,它會用 GAN 模型從該數據分布中生成數據。整個部分分成 3 個小部分。首先,我們會寫一個基本函數生成一個二次型分布(真實數據分布),然後,我們寫一個生成器以及一個鑒別器。最後我們會用數據以對抗的方式訓練這兩個神經網路。

本次實現的目標就是學習一種新函數,能夠從和訓練數據一樣的分布中生成數據。我們預期訓練中生成器網路應當能生成遵循二次型分布的數據,這裡在下個部分會更詳細的解釋。雖然我們是以非常簡單的數據分布著手,但 GAN 很容易地能延伸為從很複雜的數據集中生成數據,比如生成手寫字、明星臉部、動物等等。

生成訓練數據

我們通過用 Numpy 庫生成隨機樣本,然後用一種函數生成第二個數據分布來實現我們的真實數據集。為了簡單起見,我們將函數設為二次函數。當然你可以用這裡的代碼生成具有更多維度或特徵之間關係更為複雜的數據集。

import numpy as npdef get_y(x): return 10 + x*xdef sample_data(n=10000, scale=100): data = [] x = scale*(np.random.random_sample((n,))-0.5) for i in range(n): yi = get_y(x[i]) data.append([x[i], yi]) return np.array(data)

生成的數據非常簡單,畫出來會如下所示:

實現生成器和鑒別器網路

我們現在用 TensorFlow 層實現生成器和鑒別器網路。首先用如下函數實現生成器網路:

def generator(Z,hsize=[16, 16],reuse=False): with tf.variable_scope("GAN/Generator",reuse=reuse): h1 = tf.layers.dense(Z,hsize[0],activation=tf.nn.leaky_relu) h2 = tf.layers.dense(h1,hsize[1],activation=tf.nn.leaky_relu) out = tf.layers.dense(h2,2)return out

該函數以 placeholder 為隨機樣本(Z),數組 hsize 為 2 個隱藏層中的神經元數量,變數 reuse 則用於重新使用同樣的網路層。使用這些輸入,函數會創建一個具有 2 個隱藏層和給定數量節點的全連接神經網路。函數的輸出為一個 2 維向量,對應我們試著去學習的真實數據集的維度。對上面的函數也很容易修改,添加更多隱藏層、不同類型的層、不同的激活和輸出映射等。

我們用如下函數實現鑒別器網路:

def discriminator(X,hsize=[16, 16],reuse=False): with tf.variable_scope("GAN/Discriminator",reuse=reuse): h1 = tf.layers.dense(X,hsize[0],activation=tf.nn.leaky_relu) h2 = tf.layers.dense(h1,hsize[1],activation=tf.nn.leaky_relu) h3 = tf.layers.dense(h2,2) out = tf.layers.dense(h3,1)return out, h3

該函數會將輸入 placeholder 認作來自真實數據集向量空間的樣本,這些樣本可能是真實樣本,也可能是生成器網路所生成的樣本。和上面的生成器網路一樣,它也會使用 hsize 和 reuse 為輸入。我們在鑒別器中使用 3 個隱藏層,前兩個層的大小和輸入中一致,將第三個隱藏層的大小修改為 2,這樣我們就能在 2 維平面上查看轉換後的特徵空間,在後面部分會講到。該函數的輸出是給定 X 和最後一層的輸出(鑒別器從 X 中學習的特徵轉換)的 logit 預測(logit 就是神經網路模型中的 W * X矩陣)。該 logit 函數和 S 型函數正好相反,後者常用於表示幾率(變數為 1 的概率和為 0 的概率之間的比率)的對數。

對抗式訓練

出於訓練目的,我們將如下佔位符 x 和 z 分別定義為真實樣本和隨機雜訊樣本:

X = tf.placeholder(tf.float32,[None,2])Z = tf.placeholder(tf.float32,[None,2])

對於從生成器中生成樣本以及向鑒別器中輸入真實和生成的樣本,我們還需要創建出它們的圖形。通過使用前面定義的函數和佔位符可以做到這一點:

G_sample = generator(Z)r_logits, r_rep = discriminator(X)f_logits, g_rep = discriminator(G_sample,reuse=True)

使用生成的數據和真實數據的 logit,我們將生成器和鑒別器的損失函數定義如下:

disc_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=r_logits,labels=tf.ones_like(r_logits)) + tf.nn.sigmoid_cross_entropy_with_logits(logits=f_logits,labels=tf.zeros_like(f_logits)))gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=f_logits,labels=tf.ones_like(f_logits)))

這些損失是基於 sigmoid 交叉熵的損失,其使用我們前面定義的方程。這種損失函數在離散分類問題中很常見,它將 logit(由我們的鑒別器網路給定)作為輸入,為每個樣本預測正確的標籤,然後會計算每個樣本的誤差。我們使用的是 TensorFlow 中實現的該種損失函數逇優化版,比直接計算交叉熵更加穩定。更多詳情可以看看相關 TensorFlow API:

tensorflow.org/api_docs

接下來,我們用上面所述的損失函數和生成器及鑒別器網路函數中定義的網路層範圍,定義這兩個網路的優化器。在兩個神經網路中我們使用 RMSProp 優化器,學習率設為 0.001,範圍則為我們只為給定網路所獲取的權重或變數。

gen_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope="GAN/Generator")disc_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope="GAN/Discriminator")gen_step = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(gen_loss,var_list = gen_vars) # G Train stepdisc_step = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(disc_loss,var_list = disc_vars) # D Train step

然後我們按照所需數量的步驟訓練這兩個網路:

for i in range(100001): X_batch = sample_data(n=batch_size) Z_batch = sample_Z(batch_size, 2) _, dloss = sess.run([disc_step, disc_loss], feed_dict={X: X_batch, Z: Z_batch}) _, gloss = sess.run([gen_step, gen_loss], feed_dict={Z: Z_batch})print "Iterations: %d Discriminator loss: %.4f Generator loss: %.4f"%(i,dloss,gloss)

可以對上面的代碼進行修改,獲得更複雜的訓練過程,比如為鑒別器或生成器更新運行多個步驟,獲取真實和生成的樣本的特徵,繪製生成的樣本。這裡的操作可以參考我的代碼庫。

分析GAN

可視化訓練損失

為了能更好地理解訓練過程,我們可以每 10 次迭代後就繪製出訓練損失。從下圖中我們可以看到損失逐漸下降,到了訓練末尾幾乎保持不變。當鑒別器和生成器網路的損失達到這種幾乎不再變動的狀態時,就表示模型達到了平衡狀態。

可視化訓練期間的樣本

我們還可以在訓練中每 1000 次迭代後繪製出真實樣本和生成的樣本,這些可視化圖能清晰地展現生成器網路首先以輸入和數據集向量空間的隨機初始映射開始,然後慢慢發展至模擬真實數據集樣本。你可以看到,「虛假」樣本開始看起來越來越像「真實」數據分布。

可視化生成器更新

在這部分,我們將對抗式訓練期間更新生成器網路權重的影響進行可視化。通過繪製鑒別器網路最後一個隱藏層的激活狀況完成這一步。我們將最後一個隱藏層大小設置為 2,這樣就能在無需降維(即將輸入樣本轉換為不同的向量空間)的情況下很容易地繪製圖形。我們對可視化鑒別器網路所學習的特徵轉換函數很感興趣,正是這個函數讓模型區分真實和虛假數據。

我們繪製出鑒別器網路最後一層學習的真實樣本及生成樣本的特徵准換圖形,並且分別繪出更新生成器網路的權重之前和之後兩種情況的圖形。另外還繪製出在輸入樣本的特徵轉換後獲取的數據點的形心(centroid)。最後,我們分別計算真實數據和虛假數據的數據點形心,以及更新生成器之間和之後的數據點形心。從可視化圖形中可以發現如下現象:

  • 和預期一樣,真實數據樣本的轉換特徵並沒有出現變化。從圖形中可以看到它們完全一致。
  • 從形心可視化圖中可以看到,生成的數據樣本的特徵形心幾乎和真實數據樣本的特徵形心走勢一致。
  • 還可以看到,隨著迭代次數增加,真實樣本的轉換特徵混入越來越多的生成樣本的轉換特徵。這也在預期之內,因為在訓練鑒別器網路末尾時,應該就無法區分真實樣本和生成樣本。因此,在訓練結束時,這兩種樣本的轉換特徵應當一致。

結語

本文我們用 TensorFlow 實現了一種概念驗證 GAN 模型,能從非常簡單的數據分布中生成數據。在你自己練習時,建議將文中代碼進行如下調整:

  • 對鑒別器更新之前和之後的情況進行可視化
  • 改變各個層級的激活函數,看看訓練和生成樣本的區別
  • 添加更多層以及不同類型的層,看看對訓練時間和訓練穩定性的影響
  • 調整代碼以生成數據,包含來自兩個不同曲線的數據
  • 調整以上代碼,能夠處理更複雜的數據,比如 MNIST,CIFAR-10,等等。

本文代碼地址:

github.com/aadilh/blogs

可能你還敢興趣:

景略集智:你會有貓的:AI幫你生成各種各樣的喵。?

zhuanlan.zhihu.com圖標


參考資料:

blog.paperspace.com/imp


推薦閱讀:

圖網路——融合推理與學習的全新深度學習架構 | AI&Society第八期預告
細粒度分類SCDA&Recurrent Attention Convolutional Neural Networ
【深度學習系列】遷移學習Transfer Learning
FCNN全卷積神經網路補充
人工神經網路

TAG:生成對抗網路GAN | 神經網路 | 深度學習DeepLearning |