一文概覽用於數據集增強的生成對抗網路架構
編者按:機器學習開發者Pedro Ferreira介紹了他在jungle.ai進行的生成對抗網路(GAN)應用研究,回顧了有助於數據集增強的GAN領域的研究進展,並授權【論智】把內容分享給國內讀者。
生成對抗網路(Generative Adversarial Network,GAN)迅猛地佔領了機器學習社區。優雅的理論基礎和在計算機視覺領域不斷提升的優越表現使其成為近年來機器學習最活躍的研究課題之一。事實上,Facebook AI Research的領導人Yann Lecun在2016年說過,「在我看來,GAN及其新提出的變體是機器學習在過去10年最有意思的想法。」想要了解這一課題的最新進展,請參閱這篇The GAN Zoo(GAN動物園)。
儘管GAN已被證明是很出色的圖像生成模型,例如生成面部圖像和卧室圖像,GAN尚未在其他數據集上進行過廣泛測試,例如由工廠提供的數據集,其中包含大量來自生產線上的感測器的測量值。不同於諸如圖片之類的靜態數據,這樣的數據集甚至可能包括時序信息,機器學習模型需要利用這些時序信息預測未來的事件。在這類數據上應用生成模型可能很有用,例如,如果我們的預測模型需要更多樣本進行訓練以提升其概括性。另外,如果我們提出一個可以生成優質合成數據的模型,那麼這個模型必定學習到了原始數據的潛在結構。既然模型學習到了潛在結構,預測模型就可以將該表示作為新特徵集來利用!
本文將介紹一些可能有助於數據集增強的GAN體系結構,包括樣本增強和特徵增強。讓我們從基本的GAN開始。
生成對抗網路
GAN模型由兩部分組成:生成器(generator)和判別器(discriminator)。這裡我們認為它們都是由參數確定的神經網路:G和D。判別網路的參數為最大化正確區分真實數據和偽造數據(生成網路偽造的數據)的概率這一目標而優化,而生成網路的目標是最大化判別網路不能識別其偽造的樣本的概率。
生成網路如此產生樣本:接受一個輸入向量z,該向量取樣自一個潛分布(latent distribution),應用由網路定義的函數G至該向量,得到G(z)。判別網路交替接受G(z)和x(一個真實數據樣本),輸出輸入為真的概率。
通過適當的超參數調優和足夠的訓練迭代次數,生成網路和判別網路將一起收斂(通過梯度下降方法進行參數更新)至描述偽造數據的分布和取樣真實數據的分布相一致的點。
本文接下來的部分將通過基於MNIST數據集生成新數字或編碼原始數字至潛空間來演示GAN是如何工作的。我們也會看下如何將GAN應用到類別數據和時序數據上。
作為開始,下面是一個在MNIST數據集上訓練的、基於多層感知器(MLP)的簡單GAN模型生成的一些樣本。
GAN並非盡善盡美
儘管GAN能如我們所見的那樣工作,在實踐中,GAN有一些缺點,自Ian Goodfellow等在2014年發表GAN的原始論文起,如何克服GAN的缺點一直是研究的熱點。GAN的主要缺點涉及它的訓練,GAN因極難訓練而聲名狼藉:首先,GAN的訓練高度依賴超參數。其次,也是最重要的,(生成網路和判別網路的)損失函數不提供必要的信息:儘管生成的樣本可能已經開始貼切地重現真實數據——顯著逼近真實數據的分布——一般而言無法通過損失的趨勢來指示這一表現。這意味著我們不能基於損失運行skopt之類的超參數優化器,相反必須手工迭代調優,真是可恥。
GAN架構的另一個缺點和它的功能有關。使用圖一顯示的基於原始的交叉熵損失的GAN,我們無法:
- 控制生成什麼數據。
- 生成類別數據。
- 訪問潛空間以便將其作為特徵使用。
生成類別數據對GAN而言是一個特大難題。Ian Goodfellow在這個reddit帖子中以非常直觀的方式解釋了這一點:
僅當合成數據基於連續數值時,你才能對合成數據作出微小的改動。基於離散數值無法作出微小的改動。
例如,如果你輸出的圖像的像素值為1.0,你可以在下一步將該像素值改為1.0001.如果你輸出單詞「企鵝」,你無法在下一步將其修改為「企鵝 + .001」,因為並不存在「企鵝 + .001」這樣的單詞。你需要經歷從「企鵝」到「鴕鳥」的整個過程。
關鍵的想法是,生成網路不可能從一個實體(如「企鵝」)一路前進到另一個實體(如「鴕鳥」)。因為兩者之間的空間出現實體的概率為0,判別網路可以輕易地識別出該空間內的樣本是不真實的,因而它不可能被生成網路所愚弄。
GAN變體
為了解決原始GAN的問題,研發了一些其他的訓練方式和架構。下面將加以簡要介紹。這些介紹的目標是讓你對如何應用這些方法至結構化數據(比如Kaggle競賽中的數據)有所了解。
條件GAN
前面提到的GAN能生成看起來像MNIST數據集中的隨機數字。但是如果我們想生成特定數字呢?只需在訓練過程中做出一個小小的改動,我們就能告訴生成網路生成我們所要求的數字。在每次迭代中,生成網路的輸入不僅包括z,還包括指明數字的one-hot編碼向量。同樣,判別網路的輸入不僅包括真實樣本或偽造樣本,還包括同樣的標籤向量。
基於與前述GAN相同的流程,但是加上了這一輸出上的微小改動,條件GAN(CGAN)學習生成以輸入的標籤為條件的樣本。
讓我們為每個數字生成一個樣本!在潛空間取樣時,我們同時輸入一個one-hot編碼的向量指明我們所需的分類。對所有10個分類中的數字進行這一過程,得到圖四的結果:
Wasserstein GAN
Wasserstein GAN(WGAN)是最流行的GAN之一,它改變了目標,從而提高了訓練穩定性和可解釋性(損失和樣本質量的相關性),同時能夠生成類別數據。關鍵點在於,生成網路的目標是逼近真實數據分布,因此衡量分布間的距離的指標很重要,因為該指標將是最小化的目標。WGAN選擇了Wasserstein距離。Wasserstein距離也稱為推土機(Earth-Mover)距離。另外,WGAN實際上採用的是Wasserstein距離的近似。WGAN選擇Wasserstein距離是因為Wasserstein距離能在Kullback-Leibler散度和Jensen-Shannon散度無法收斂的分布上收斂。如果你對理論感興趣,可以看下原始論文或這篇出色總結Read-through: Wasserstein GAN。
在實現層面,總結一下逼近Wasserstein距離意味著什麼:
- 判別器的輸出不再是概率了,這也是將判別器改名為批評者(critic)的動機。
- 判別器的參數截斷至某個閾值(或者進行梯度懲罰)。
- 在每個訓練迭代中,判別器的參數比生成器的參數更新更頻繁。
用於類別數據的Wasserstein GAN
WGAN論文的作者展示了通過這種方式訓練的GAN顯示了訓練上的穩定性和可解釋性,但之後有研究證明,Wasserstein距離的使用賦予了GAN生成類別(categorical)數據的能力(即,並非圖像之類的連續值數據,甚至不是像用1表示周日、用2表示周一這樣的整型編碼數據)。當在這類數據上訓練原始的GAN時,判別網路的損失會在多次迭代中保持較低的水平,而生成網路的損失會不停增長。而WGAN在類別數據上訓練的方式和在連續值數據一樣。
我們只需如此做(圖五是一個例子):數據集中的每個類別變數都對應一個生成網路的softmax輸出,該輸出的維度和可能的離散值數目相等。判別網路並不接受one-hot編碼的softmax輸出作為輸入,相反,將原始的softmax輸出當做一組連續值變數,傳給判別網路作為輸入。這樣訓練就能收斂!在測試時,只需one-hot編碼生成網路的離散輸出即可生成偽造的類別數據。
上圖中的類別變數1為3個可能值中的1個,類別變數2為2個可能值中的1個。此外還有1個連續變數。
圖六展示了一個在類別值的數據集上訓練基於梯度懲罰的WGAN的例子,你可以在圖中看到穩定的、收斂的損失函數的美麗曲線。這一個例子是在Kaggle競賽中的Sberbank Russian Housing Market數據集(俄羅斯聯邦儲蓄銀行的房產市場數據集)上訓練的,該數據集同時包含連續變數和類別變數。
當然,你也可以組合WGAN和CGAN,以監督學習的方式訓練WGAN,以生成以分類標籤為條件的樣本!
注意:Cramer GAN進一步改進了Wasserstein GAN,其目標是提供質量更優的樣本,同時提高訓練穩定性。是否能用它生成類別數據是以後的研究課題。
雙向GAN
儘管WGAN看上去解決了很多問題,但它不允許訪問數據的潛空間表示。尋找這樣的表示可能很有幫助,不僅是因為可以通過在潛空間的連續移動控制生成什麼樣的數據,還因為可以通過潛空間提取特徵。
雙向GAN(Bidirectional GAN,BiGAN)是解決這一問題的一個嘗試。它如此工作:不僅學習一個生成式網路,同時學習一個編碼網路E,該編碼網路映射數據至生成網路的潛空間。對抗配置中,使用一個判別網路應對生成任務和編碼任務。BiGAN的作者展示了,在這一限制下,G和E這一對網路形成了一個自動編碼器(autoencoder):通過E編碼數據樣本,再通過G解碼,可以得到原始樣本。
InfoGAN
之前我們看到,CGAN允許調節生成網路以根據標籤生成樣本。不過,是否可以通過在GAN的潛空間中強制一個類別化的結構,以完全無監督的方式學習辨別數字呢?可不可以設置一個連續的代碼空間,讓我們可以訪問這一空間以描述數據樣本的連續語義變體呢?(在MNIST的例子中,連續語義變體可能是數字的寬度和斜度。)
上述兩個問題的答案都是可以。比那更好的是:我們可以同時做到這兩點。真相是,我們可以施加任何我們發現有用的代碼空間分布,然後訓練GAN編碼這些分布中有意義的特性。每份代碼將學習包含數據的不同語義特性,結果等效於信息退相干(information disentanglement)。
允許我們這麼乾的GAN是InfoGAN。簡單來說,InfoGAN試圖最大化生成網路代碼空間和推斷網路輸出的共同信息。推斷網路可以簡單配置為判別網路的一個輸出層,共享其他參數,意味著它是算力免費(computationally free)的。一旦訓練完成,InfoGAN的判別網路的推斷輸出層可以用來提取特徵,或者,如果代碼空間包含標籤信息,可以用來分類!
創建一個配有兩個代碼空間的InfoGAN——一個連續的二維空間和一個離散的十維空間——我們能夠以離散代碼為條件生成特定的數字,同時以連續代碼為條件生成特定風格的數字,生成如圖九所示的數據。注意,在整個無監督學習計劃中,沒有標籤的位置——在潛空間中施加一個類別分布足以讓模型學習編碼該分布的標籤信息!
對抗自動編碼器
對抗自動編碼器(Adversarial Autoencoder,AAE)結合了自動編碼器和GAN。這一模型優化兩個目標:其一,最小化通過編碼網路P和解碼網路Q的數據x的重建錯誤。其二,通過對抗訓練在代碼P(x)上施加一個先驗分布,在對抗訓練中,P為生成網路。所以,優化P和Q以最小化x和Q(z)的距離,其中z是自動編碼器的代碼空間向量,同時優化作為GAN的P和D,以迫使代碼空間P(x)匹配預先定義的結構。這可以看成對自動編碼器的正則化,迫使它學習有意義、結構化、內聚的代碼空間(而不是斷裂的代碼空間,參考Geoffrey Hinton講座筆記第76頁),以允許進行有效的特徵提取和降維。同時,由於在代碼上施加了一個已知先驗分布,從該先驗分布取樣,並將樣本傳給解碼網路Q,形成了一個生成式建模計劃!
讓我們在自動編碼器的對抗訓練中,在代碼空間上施加一個標準差為5的二維高斯分布。取樣該空間的相鄰點,得到一些生成數字的連續變體集。
我們還可以基於標籤訓練AAE,以強制標籤和數字風格信息的退相干。這樣,通過固定想要的標籤,施加的連續潛空間中的變體將導致不同風格的同一數字。以數字八為例:
很明顯,相鄰點間存在有意義的關係!為我們的數據集增強問題生成樣本時,這一性質可能會提供便利。
時序數據?
現實世界的結構化數據常常包含時序。在這樣的數據中,每個樣本和之前的樣本間存在某種依賴關係。經常選擇基於循環神經網路的模型來處理這種數據,原因是它們具備建模這種數據的內在能力。在我們的GAN模型中利用這些神經網路,在原則上可以產生更高質量的樣本和特徵!
循環GAN
讓我們將之前的GAN中的MLP替換為RNN,就像這篇論文所做的那樣。具體而言,我們將採用RNN的變體長短時記憶(LSTM)單元(事實上我們在談論深度學習最最時髦的行話——哎喲,我又這麼幹了),在波形(Waves)數據集上進行訓練。這一數據集包含偏移、頻率、振幅不同的一維正弦信號和鋸齒信號,所有信號的時步相同。從RNN的視角來看,每個樣本包含一個30時步的波形。
我們的CGAN的生成網路和判別網路都將採用基於LSTM的神經網路,將其轉化為一個RCGAN。我們將訓練該RCGAN學習按需生成正弦、鋸齒波形。
訓練之後,我們也將查看潛空間中的變體是如何產生生成樣本特性體現的連續變化的。具體而言,如果我們施加一個二維正態分布潛空間,並將分類標籤固定為正弦波形,我們將得到圖十四中顯示的樣本。其中,我們能很明顯地看到頻率和振幅由低到高的連續變化,這意味著RCGAN學習到了一個有意義的潛空間!
儘管在GAN中使用RNN對生成實值的序列化數據很有用,它仍然無法用於離散序列,是否可以配合RNN使用Wasserstein距離尚不清楚(在RNN上施加Lipschitz限制是以後的研究課題)。SeqGAN和最近的ARAE的目標是解決這一問題。
結論
我們看到,在因為GAN具有生成非常酷的圖像的能力而生成的那些大驚小怪的報道(看過沒有?)之外,一些架構也可能有助於處理更一般的機器學習問題,包括連續和離散的數據。本文的目的是介紹這一想法,並不打算嚴格地比較這些多用途生成式模型,不過本文確實證明了應該進行這樣的涉及GAN的研究。
原文地址: Towards data set augmentation with GANs
歡迎專註公眾號:論智(jqr_AI)
推薦閱讀: