送書 | AI插畫師:如何用基於PyTorch的生成對抗網路生成動漫頭像?
本文由 「AI前線」原創,原文鏈接:送書 | AI插畫師:如何用基於PyTorch的生成對抗網路生成動漫頭像?
作者|陳雲編輯|Natalie
AI 前線導讀:」2016 年是屬於 TensorFlow 的一年,憑藉谷歌的大力推廣,TensorFlow 佔據了各大媒體的頭條。2017 年年初,PyTorch 的橫空出世吸引了研究人員極大的關注,PyTorch 簡潔優雅的設計、統一易用的介面、追風逐電的速度和變化無方的靈活性給人留下深刻的印象。
本文節選自《深度學習框架 PyTorch 入門與實踐》第 7 章,為讀者講解當前最火爆的生成對抗網路(GAN),帶領讀者從零開始實現一個動漫頭像生成器,能夠利用 GAN 生成風格多變的動漫頭像。注意啦,文末有送書福利!」
生成對抗網路(Generative Adversarial Net,GAN)是近年來深度學習中一個十分熱門的方向,卷積網路之父、深度學習元老級人物 LeCun Yan 就曾說過「GAN is the most interesting idea in the last 10 years in machine learning」。尤其是近兩年,GAN 的論文呈現井噴的趨勢,GitHub 上有人收集了各種各樣的 GAN 變種、應用、研究論文等,其中有名稱的多達數百篇。作者還統計了 GAN 論文發表數目隨時間變化的趨勢,如圖 7-1 所示,足見 GAN 的火爆程度。
圖 7-1 GAN 的論文數目逐月累加圖
GAN 的原理簡介
GAN 的開山之作是被稱為「GAN 之父」的 Ian Goodfellow 發表於 2014 年的經典論文 Generative Adversarial Networks ,在這篇論文中他提出了生成對抗網路,並設計了第一個 GAN 實驗——手寫數字生成。
GAN 的產生來自於一個靈機一動的想法:
「What I cannot create, I do not understand.」(那些我所不能創造的,我也沒有真正地理解它。)
—Richard Feynman
類似地,如果深度學習不能創造圖片,那麼它也沒有真正地理解圖片。當時深度學習已經開始在各類計算機視覺領域中攻城略地,在幾乎所有任務中都取得了突破。但是人們一直對神經網路的黑盒模型表示質疑,於是越來越多的人從可視化的角度探索卷積網路所學習的特徵和特徵間的組合,而 GAN 則從生成學習角度展示了神經網路的強大能力。GAN 解決了非監督學習中的著名問題:給定一批樣本,訓練一個系統能夠生成類似的新樣本。
生成對抗網路的網路結構如圖 7-2 所示,主要包含以下兩個子網路。
- 生成器(generator):輸入一個隨機雜訊,生成一張圖片。
- 判別器(discriminator):判斷輸入的圖片是真圖片還是假圖片。
圖 7-2 生成對抗網路結構圖
訓練判別器時,需要利用生成器生成的假圖片和來自真實世界的真圖片;訓練生成器時,只用雜訊生成假圖片。判別器用來評估生成的假圖片的質量,促使生成器相應地調整參數。
生成器的目標是儘可能地生成以假亂真的圖片,讓判別器以為這是真的圖片;判別器的目標是將生成器生成的圖片和真實世界的圖片區分開。可以看出這二者的目標相反,在訓練過程中互相對抗,這也是它被稱為生成對抗網路的原因。
上面的描述可能有點抽象,讓我們用收藏齊白石作品(齊白石作品如圖 7-3 所示)的書畫收藏家和假畫販子的例子來說明。假畫販子相當於是生成器,他們希望能夠模仿大師真跡偽造出以假亂真的假畫,騙過收藏家,從而賣出高價;書畫收藏家則希望將贗品和真跡區分開,讓真跡流傳於世,銷毀贗品。這裡假畫販子和收藏家所交易的畫,主要是齊白石畫的蝦。齊白石畫蝦可以說是畫壇一絕,歷來為世人所追捧。
圖 7-3 齊白石畫蝦圖真跡
在這個例子中,一開始假畫販子和書畫收藏家都是新手,他們對真跡和贗品的概念都很模糊。假畫販子仿造出來的假畫幾乎都是隨機塗鴉,而書畫收藏家的鑒定能力很差,有不少贗品被他當成真跡,也有許多真跡被當成贗品。
首先,書畫收藏家收集了一大堆市面上的贗品和齊白石大師的真跡,仔細研究對比,初步學習了畫中蝦的結構,明白畫中的生物形狀彎曲,並且有一對類似鉗子的「螯足」,對於不符合這個條件的假畫全部過濾掉。當收藏家用這個標準到市場上進行鑒定時,假畫基本無法騙過收藏家,假畫販子損失慘重。但是假畫販子自己仿造的贗品中,還是有一些矇騙過關,這些矇騙過關的贗品中都有彎曲的形狀,並且有一對類似鉗子的「螯足」。於是假畫販子開始修改仿造的手法,在仿造的作品中加入彎曲的形狀和一對類似鉗子的「螯足」。除了這些特點,其他地方例如顏色、線條都是隨機畫的。假畫販子製造出的第一版贗品如圖 7-4 所示。
圖 7-4 假畫販子製造的第一版贗品
當假畫販子把這些畫拿到市面上去賣時,很容易就騙過了收藏家,因為畫中有一隻彎曲的生物,生物前面有一對類似鉗子的東西,符合收藏家認定的真跡的標準,所以收藏家就把它當成真跡買回來。隨著時間的推移,收藏家買回越來越多的假畫,損失慘重,於是他又閉門研究贗品和真跡之間的區別,經過反覆比較對比,他發現齊白石畫蝦的真跡中除了有彎曲的形狀,蝦的觸鬚蔓長,通身作半透明狀,並且畫的蝦的細節十分豐富,蝦的每一節之間均呈白色狀。
收藏家學成之後,重新出山,而假畫販子的仿造技法沒有提升,所製造出來的贗品被收藏家輕鬆識破。於是假畫販子也開始嘗試不同的畫蝦手法,大多都是徒勞無功,不過在眾多嘗試之中,還是有一些贗品騙過了收藏家的眼睛。假畫販子發現這些仿製的贗品觸鬚蔓長,通身作半透明狀,並且畫的蝦的細節十分豐富,如圖 7-5 所示。於是假畫販子開始大量仿造這種畫,並拿到市面上銷售,許多都成功地騙過了收藏家。
圖 7-5 假畫販子製造的第二版贗品
收藏家再度損失慘重,被迫關門研究齊白石的真跡和贗品之間的區別,學習齊白石真跡的特點,提升自己的鑒定能力。就這樣,通過收藏家和假畫販子之間的博弈,收藏家從零開始慢慢提升了自己對真跡和贗品的鑒別能力,而假畫販子也不斷地提高自己仿造齊白石真跡的水平。收藏家利用假畫販子提供的贗品,作為和真跡的對比,對齊白石畫蝦真跡有了更好的鑒賞能力;而假畫販子也不斷嘗試,提升仿造水平,提升仿造假畫的質量,即使最後製造出來的仍屬於贗品,但是和真跡相比也很接近了。收藏家和假畫販子二者之間互相博弈對抗,同時又不斷促使著對方學習進步,達到共同提升的目的。
在這個例子中,假畫販子相當於一個生成器,收藏家相當於一個判別器。一開始生成器和判別器的水平都很差,因為二者都是隨機初始化的。訓練過程分為兩步交替進行,第一步是訓練判別器(只修改判別器的參數,固定生成器),目標是把真跡和贗品區分開;第二步是訓練生成器(只修改生成器的參數,固定判別器),為的是生成的假畫能夠被判別器判別為真跡(被收藏家認為是真跡)。這兩步交替進行,進而分類器和判別器都達到了一個很高的水平。訓練到最後,生成器生成的蝦的圖片(如圖 7-6 所示)和齊白石的真跡幾乎沒有差別。
圖 7-6 生成器生成的蝦
下面我們來思考網路結構的設計。判別器的目標是判斷輸入的圖片是真跡還是贗品,所以可以看成是一個二分類網路,參考第 6 章中 Dog vs. Cat 的實驗,我們可以設計一個簡單的卷積網路。生成器的目標是從雜訊中生成一張彩色圖片,這裡我們採用廣泛使用的 DCGAN(Deep Convolutional Generative Adversarial Networks)結構,即採用全卷積網路,其結構如圖 7-7 所示。網路的輸入是一個 100 維的雜訊,輸出是一個 3×64×64 的圖片。這裡的輸入可以看成是一個 100×1×1 的圖片,通過上卷積慢慢增大為 4×4、8×8、16×16、32×32 和 64×64。上卷積,或稱轉置卷積,是一種特殊的卷積操作,類似於卷積操作的逆運算。當卷積的 stride 為 2 時,輸出相比輸入會下採樣到一半的尺寸;而當上卷積的 stride 為 2 時,輸出會上採樣到輸入的兩倍尺寸。這種上採樣的做法可以理解為圖片的信息保存於 100 個向量之中,神經網路根據這 100 個向量描述的信息,前幾步的上採樣先勾勒出輪廓、色調等基礎信息,後幾步上採樣慢慢完善細節。網路越深,細節越詳細。
圖 7-7 DCGAN 中生成器網路結構圖
在 DCGAN 中,判別器的結構和生成器對稱:生成器中採用上採樣的卷積,判別器中就採用下採樣的卷積,生成器是根據雜訊輸出一張 64×64×3 的圖片,而判別器則是根據輸入的 64×64×3 的圖片輸出圖片屬於正負樣本的分數(概率)。
用 GAN 生成動漫頭像
本節將用 GAN 實現一個生成動漫人物頭像的例子。在日本的技術博客網站上 有個博主(估計是一位二次元的愛好者),利用 DCGAN 從 20 萬張動漫頭像中學習,最終能夠利用程序自動生成動漫頭像,生成的圖片效果如圖 7-8 所示。源程序是利用 Chainer 框架實現的,本節我們嘗試利用 PyTorch 實現。
圖 7-8 DCGAN 生成的動漫頭像
原始的圖片是從網站中爬取的,並利用 OpenCV 從中截取頭像,處理起來比較麻煩。這裡我們使用知乎用戶何之源爬取並經過處理的 5 萬張圖片。可以從本書配套程序的 README.MD 的百度網盤鏈接下載所有的圖片壓縮包,並解壓縮到指定的文件夾中。需要注意的是,這裡圖片的解析度是 3×96×96,而不是論文中的 3×64×64,因此需要相應地調整網路結構,使生成圖像的尺寸為 96。
我們首先來看本實驗的代碼結構。
接著來看 model.py 中是如何定義生成器的。
可以看出生成器的搭建相對比較簡單,直接使用 nn.Sequential 將上卷積、激活、池化等操作拼接起來即可,這裡需要注意上卷積 ConvTransposed2d 的使用。當 kernel size 為 4、stride 為 2、padding 為 1 時,根據公式 H_out=(H_in-1)*stride-2*padding+kernel_size
,輸出尺寸剛好變成輸入的兩倍。最後一層採用 kernel size 為 5、stride 為 3、padding 為 1,是為了將 32×32 上採樣到 96×96,這是本例中圖片的尺寸,與論文中 64×64 的尺寸不一樣。最後一層用 Tanh 將輸出圖片的像素歸一化至 -1~1,如果希望歸一化至 0~1,則需使用 Sigmoid。接著我們來看判別器的網路結構。
可以看出判別器和生成器的網路結構幾乎是對稱的,從卷積核大小到 padding、stride 等設置,幾乎一模一樣。例如生成器的最後一個卷積層的尺度是(5,3,1),判別器的第一個卷積層的尺度也是(5,3,1)。另外,這裡需要注意的是生成器的激活函數用的是 ReLU,而判別器使用的是 LeakyReLU,二者並無本質區別,這裡的選擇更多是經驗總結。每一個樣本經過判別器後,輸出一個 0~1 的數,表示這個樣本是真圖片的概率。在開始寫訓練函數前,先來看看模型的配置參數。
這些只是模型的默認參數,還可以利用 Fire 等工具通過命令行傳入,覆蓋默認值。另外,我們也可以直接使用 opt.attr,還可以利用 IDE/IPython 提供的自動補全功能,十分方便。這裡的超參數設置大多是照搬 DCGAN 論文的默認值,作者經過大量實驗,發現這些參數能夠更快地訓練出一個不錯的模型。
當我們下載完數據之後,需要將所有圖片放在一個文件夾,然後將該文件夾移動至 data 目錄下(請確保 data 下沒有其他的文件夾)。這種處理方式是為了能夠直接使用 torchvision 自帶的 ImageFolder 讀取圖片,而不必自己寫 Dataset。數據讀取與載入的代碼如下:
可見,用 ImageFolder 配合 DataLoader 載入圖片十分方便。
在進行訓練之前,我們還需要定義幾個變數:模型、優化器、雜訊等。
在載入預訓練模型時,最好指定 map_location。因為如果程序之前在 GPU 上運行,那麼模型就會被存成 torch.cuda.Tensor,這樣載入時會默認將數據載入至顯存。如果運行該程序的計算機中沒有 GPU,載入就會報錯,故通過指定 map_location 將 Tensor 默認載入入內存中,待有需要時再移至顯存中。
下面開始訓練網路,訓練步驟如下。
(1)訓練判別器。
- 固定生成器
- 對於真圖片,判別器的輸出概率值儘可能接近 1
- 對於生成器生成的假圖片,判別器儘可能輸出 0
(2)訓練生成器。
- 固定判別器
- 生成器生成圖片,儘可能讓判別器輸出 1
(3)返回第一步,循環交替訓練。
這裡需要注意以下幾點。
- 訓練生成器時,無須調整判別器的參數;訓練判別器時,無須調整生成器的參數。
- 在訓練判別器時,需要對生成器生成的圖片用 detach 操作進行計算圖截斷,避免反向傳播將梯度傳到生成器中。因為在訓練判別器時我們不需要訓練生成器,也就不需要生成器的梯度。
- 在訓練分類器時,需要反向傳播兩次,一次是希望把真圖片判為 1,一次是希望把假圖片判為 0。也可以將這兩者的數據放到一個 batch 中,進行一次前向傳播和一次反向傳播即可。但是人們發現,在一個 batch 中只包含真圖片或只包含假圖片的做法最好。
- 對於假圖片,在訓練判別器時,我們希望它輸出為 0;而在訓練生成器時,我們希望它輸出為 1。因此可以看到一對看似矛盾的代碼:error_d_fake = criterion(fake_output, fake_labels) 和 error_g = criterion(fake_output, true_labels)。其實這也很好理解,判別器希望能夠把假圖片判別為 fake_label,而生成器則希望能把它判別為 true_label,判別器和生成器互相對抗提升。
接下來就是一些可視化的代碼。每次可視化使用的雜訊都是固定的 fix_noises,因為這樣便於我們比較對於相同的輸入,生成器生成的圖片是如何一步步提升的。另外,由於我們對輸入的圖片進行了歸一化處理(-1~1),在可視化時則需要將它還原成原來的 scale(0~1) 。
除此之外,還提供了一個函數,能載入預訓練好的模型,並利用雜訊隨機生成圖片。
完整的代碼請參考本書的附帶樣例代碼 chapter7/AnimeGAN。參照 README.MD 中的指南配置環境,並準備好數據,而後用如下命令即可開始訓練:
如果使用 visdom 的話,此時打開 http://[your ip]:8097 就能看到生成的圖像。
訓練完成後,我們可以利用生成網路隨機生成動漫頭像,輸入命令如下:
實驗結果分析
實驗結果如圖 7-9 所示,分別是訓練 1 個、10 個、20 個、30 個、40 個、200 個 epoch 之後神經網路生成的動漫頭像。需要注意的是,每次生成器輸入的雜訊都是一樣的,所以我們可以對比在相同的輸入下,生成圖片的質量是如何慢慢改善的。
剛開始生成的圖像比較模糊(1 個 epoch),但是可以看出圖像已經有面部輪廓。
繼續訓練 10 個 epoch 之後,生成的圖多了很多細節信息,包括頭髮、顏色等,但是總體還是很模糊。
訓練 20 個 epoch 之後,細節繼續完善,包括頭髮的紋理、眼睛的細節等,但還是有不少塗抹的痕迹。
訓練到第 40 個 epoch 時,已經能看出明顯的面部輪廓和細節,但還是有塗抹現象,並且有些細節不夠合理,例如眼睛一大一小,面部的輪廓扭曲嚴重。
當訓練到 200 個 epoch 之後,圖片的細節已經十分完善,線條更流暢,輪廓更清晰,雖然還有一些不合理之處,但是已經有不少圖片能夠以假亂真了。
圖 7-9 GAN 生成的動漫頭像
類似的生成動漫頭像的項目還有「用 DRGAN 生成高清的動漫頭像」,效果如圖 7-10 所示。但遺憾的是,由於論文中使用的數據涉及版權問題,未能公開。這篇論文的主要改進包括使用了更高質量的圖片數據和更深、更複雜的模型。
圖 7-10 用 DRGAN 生成的動漫頭像
本章講解的樣常式序還可以應用到不同的生成圖片場景中,只要將訓練圖片改成其他類型的圖片即可,例如 LSUN 客房圖片集、MNIST 手寫數據集或 CIFAR10 數據集等。事實上,上述模型還有很大的改進空間。在這裡,我們使用的全卷積網路只有四層,模型比較淺,而在 ResNet 的論文發表之後,也有不少研究者嘗試在 GAN 的網路結構中引入 Residual Block 結構,並取得了不錯的視覺效果。感興趣的讀者可以嘗試將示例代碼中的單層卷積修改為 Residual Block,相信可以取得不錯的效果。
近年來,GAN 的一個重大突破在於理論研究。論文 Towards Principled Methods for Training Generative Adversarial Networks 從理論的角度分析了 GAN 為何難以訓練,作者隨後在另一篇論文 Wasserstein GAN 中針對性地提出了一個更好的解決方案。但是 Wasserstein GAN 這篇論文在部分技術細節上的實現過於隨意,所以隨後又有人有針對性地提出 Improved Training of Wasserstein GANs,更好地訓練 WGAN。後面兩篇論文分別用 PyTorch 和 TensorFlow 實現,代碼可以從 GitHub 上搜索到。筆者當初也嘗試用 100 行左右的代碼實現了 Wasserstein GAN,感興趣的讀者可以去了解 。
隨著 GAN 研究的逐漸成熟,人們也嘗試把 GAN 用於工業實際問題之中,而在眾多相關論文中,最令人印象深刻的就是 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks ,論文中提出了一種新的 GAN 結構稱為 CycleGAN。CycleGAN 利用 GAN 實現風格遷移、黑白圖像彩色化,以及馬和斑馬相互轉化等,效果十分出眾。論文的作者用 PyTorch 實現了所有代碼,並開源在 GitHub 上,感興趣的讀者可以自行查閱。
本章主要介紹 GAN 的基本原理,並帶領讀者利用 GAN 生成動漫頭像。GAN 有許多變種,GitHub 上有許多利用 PyTorch 實現的各種 GAN,感興趣的讀者可以自行查閱。
作者介紹
陳雲,Python 程序員、Linux 愛好者和 PyTorch 源碼貢獻者。主要研究方向包括計算機視覺和機器學習。「2017 知乎看山杯機器學習挑戰賽」一等獎,「2017 天池醫療 AI 大賽」第八名。熱衷於推廣 PyTorch,並有豐富的使用經驗,活躍於 PyTorch 論壇和知乎相關板塊。
福利!福利!我們將給 AI 前線的粉絲送出《深度學習框架 PyTorch 入門與實踐》紙質書籍 10 本!在本文下方留言給出你想要這本書的理由,我們會邀請你加入贈書群,本次獲獎名單由抽獎小程序隨機抽取,2 月 6 日(周二)上午 10 點開獎,獲獎者每人獲得一本。另附京東購買地址,戳「閱讀原文」!
更多乾貨內容,可關注AI前線,ID:ai-front,後台回復「AI」、「TF」、「大數據」可獲得《AI前線》系列PDF迷你書和技能圖譜。
推薦閱讀:
※計算機視覺與影視業邂逅
※國內AI陷入自嗨,谷歌發現第二太陽!人工智慧找到外星人已不遠?
※人工智慧爆紅,能否成為聯想轉型的一支奇兵?
※郝景芳《人之彼岸》,人類迎接未來世界的正確姿態
※微軟AI單憑文字就可作畫,誰最先受到衝擊?