Controllable Invariance through Adversarial Feature Learning

前言

上周在公眾號裡面說要回歸了,回歸的第一期要寫《Autumn is coming——GAN眼中的四季變化》,對GAN做圖像翻譯(編輯)做一些總結。後來發現把圖像編輯(很多圖像編輯任務也可以看出圖像翻譯)也加進來,工作量有點大。這個我慢慢寫。囊括的文章可以先列出來(不完整):

  • The Conditional Analogy GAN: Swapping Fashion Articles on People Images
  • Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
  • Learning to Discover Cross-Domain Relations with Generative Adversarial Networks
  • Unsupervised Cross-Domain Image Generation
  • DualGAN: Unsupervised Dual Learning for Image-to-Image Translation
  • Face Aging With Conditional Generative Adversarial Networks
  • Fader Networks: Manipulating Images by Sliding Attributes
  • GeneGAN: Learning Object Transfiguration and Attribute Subspace from Unpaired Data
  • Neural Photo Editing with Introspective Adversarial Networks
  • Invertible Conditional GANs for image editing
  • Image-to-Image Translation with Conditional Adversarial Networks

歡迎給我推薦列表裡面沒有的圖像翻譯的文章。

這個總結還在寫,先放一個截圖:

廢話了這麼多,回歸今天的主題,今天要說的這篇文章是Controllable Invariance through Adversarial Feature Learning,利用GAN學習具有某些不變性的特徵,也就是提取跟分類相關的特徵,而忽略與分類無關的屬性,比如:訓練一個網路分類微笑還是嚴肅,這個分類問題與性別、膚色、戴眼鏡等無關,可以要求分類器提取的特徵與這些屬性無關(相互獨立)。

Invariant Feature Learning by GAN

與傳統的GAN在樣本層面對抗不同,這篇文章是想通過GAN提取到具有某些不變性的特徵,在特徵的層面進行對抗。為方便描述,姑且稱文章的方法叫IFL-GAN吧。

我們有圖像 (X, y) ,圖像 X 具有一些跟分類無關的屬性 s ,IFL-GAN的目標是學習一個特徵提取器,能夠提取與 s 相互獨立的特徵 h ,也就是 P(y|h,s) = P(y|h)

具體來說,IFL-GAN涉及三個player:E(Encoder),C(Classifier),D(Discriminator)。

E的任務是編碼,C的目標當然就是分類,D則是從E編碼的特徵 h 中把與分類不相關的屬性 s 預測出來。我們的目的是編碼與屬性 s 無關,D想盡辦法想從 h 中預測出屬性 s ,E就需要抵抗住它的「攻擊」,儘可能地將屬性 s 從編碼中剝離出去。形象的說就是:

E:我幹活乾淨利索還高效,不拖泥帶水! (底層民工,被剝削階級)

D:你說你不拖泥帶水,我不信!(監工,總是看E不爽)

C:是不是拖泥帶水我不管,我只關心你是不是把活幹完了。(老闆)

將這個對抗寫成minmax博弈就是

min_{E,C} max_D J(E,C,D) = mathbb{E}_{X,s,y sim p(X,s,y)} [gamma log q_D(s|h=E(X,s)) - log q_C(y|h=E(X,s))]

也就是

C^* = max_C mathbb{E}_{X,s,y sim p(X,s,y)} log q_C(y|h=E(X,s))

D^* = max_D mathbb{E}_{X,s,y sim p(X,s,y)} log q_D(s|h=E(X,s))

E^* = min_E mathbb{E}_{X,s,y sim p(X,s,y)} [gamma log q_D(s|h=E(X,s)) - log q_C(y|h=E(X,s))]

其中, gamma 是一個超參,控制不變性限制(invariant constraint)的強度。

Loss推導

細心的讀者或許已經發現了,模型示例圖涉及到 s1-s ,它們是怎麼來的?

注意到上面的loss都是通過似然函數來定義的,分類器的loss( log q_C 項)就是交叉熵(這其實是一個結論,它的推導跟下面一樣)。下面的推導假設 s 是單一屬性,多個屬性的情況通過假定屬性之間相互獨立得到解決。

如果我們要求屬性 s 都是二值的, s in {0, 1} ,跟GAN的loss推導一樣,一個簡單的做法是,假設 s 滿足伯努利分布,D的輸出是伯努利分布的成功概率,於是

log q_D(s|h) = log [(D(h))^{s}(1-D(h))^{1-s}] = s cdot log D(h) + (1-s) cdot log (1-D(h))

可以推廣到多個獨立屬性的情形:

begin{align} log q_D(s|h) = sum_{i=1}^{|s|} log [(D(h)_i)^{s_i}(1-D(h)_i)^{1-s_i}] &= sum_{i=1}^{|s|} s_i cdot log D(h)_i + (1-s_i) cdot log (1-D(h)_i)  &= langle s, log(D(h)) rangle + langle 1-s, log(1-D(h)) rangle end{align}

IFL-GAN文章只討論了每個屬性都是二值的,也就是上面的情形,那麼連續的情形怎麼做呢?

理論上只要給出了分布,似然函數就可以計算。下面給出多元獨立高斯分布下的推導:

mu, sigma = D(h)  log q_D(s|h) = sum_{i=1}^{|s|} -0.5log (2pi) - log sigma_i - frac{(s_i - mu_i)^2}{2sigma_i^2}

如果屬性 s 的方差已知,那麼似然函數其實就是重構誤差。

呵呵,至此,作者的一個future work我們已經做完了。。。

那structured的情形呢?(逃......其實差不多,只是寫起來比較麻煩。

實驗

論文中的實驗結果我就不放了。

我做了一個復現,做的是MNIST+SVHN數字分類。在這個任務中,與分類無關的屬性就是它是背景乾淨的手寫體( s=0 )還是背景複雜的列印體( s=1 )。MNIST是單通道圖像,SVHN是RGB圖像,我在實驗中簡單地將MNIST圖像複製擴充成三通道圖像。

實驗結果呢,效果還不錯。不過,相同的分類器下,分類準確率並沒有直接混在一起訓練高(還沒有進一步調參)。下面的曲線是混合訓練,測MNIST和SVHN各自準確率的結果。

這個實驗調參還沒有做完,以及還沒有評估它提取的特徵更好,此外,屬性只用了一維對於這兩個數據來說,似乎有點少,可以補充一部分背景乾淨的列印體數字,或許效果會更好。等實驗整理好了會掛github上。

Related Work

這裡並不想討論跟invariant feature learning相關的文章,我們來看一篇做法跟它幾乎一樣的文章,文章幾乎是同期出來的。一篇是CMU學生做的,一篇是FB做的。

這篇文章名字叫Fader Networks: Manipulating Images by Sliding Attributes。除了任務不同,它的做法簡直跟IFL-GAN如出一轍。fader nets是做圖到圖編輯的,而IFL-GAN是做特徵提取,或者說分類任務。先來看一下fader nets模型。

Encoder:我幹活很高效,但不是我的活我不幹 。

Discriminator:這些,這些,還有這些都歸你做!

Decoder:你們打架我也不管,我只關心你是不是把活幹完了。

如果重構誤差用MSE,那麼它們連目標函數都可以寫成一樣的,MSE可以看成高斯分布的log likelihood(推導可以看上一節 s 是高斯分布的情形)。兩個模型是一樣的,這裡就不展開細講了(反正後面寫圖像翻譯總結還要提到fader nets......)

參考文獻

  1. Xie Q, Dai Z, Du Y, et al. Controllable Invariance through Adversarial Feature Learning[J]. arXiv preprint arXiv:1705.11122, 2017.
  2. Lample G, Zeghidour N, Usunier N, et al. Fader Networks: Manipulating Images by Sliding Attributes[J]. arXiv preprint arXiv:1706.00409, 2017.

推薦閱讀:

用語音mfcc參數作為特徵,利用SVM來進行分類判斷聲音是否是嬰兒哭聲?
在機器學習的項目中,特徵是如何被找出來的?
像微博這種短文本的分析,用什麼方法提取特徵比較好呢?
有哪些利用高維空間及特性解決低維空間的問題的方法和例子?

TAG:生成对抗网络GAN | 特征提取 |