Meta Learning 入門:MAML 和 Reptile
本文介紹我最近學習的兩個 Meta Learning 的演算法:MAML 和 Reptile。原始論文分別見:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks 和 Reptile: a Scalable Metalearning Algorithm 。文章內容結合了李宏毅老師的課程、toy example,以及我個人的理解,可能文字有點多,但力求通俗易懂。
背景
Meta Learning 中文翻譯為「元學習」,它研究的不是如何提升模型解決某項具體的任務(分類,回歸,檢測)的能力,而是研究如何提升模型解決一系列任務的能力。
如果把訓練演算法類比成學生在學校的學習,那麼傳統的機器學習任務對應的是不同科目,例如數學、語文、英語,每個科目上訓練一個模型。而 Meta Learning 則是要提升一個學生整體的學習能力,讓學生學會學習(就是所謂的 learn to learn)。就像所有的學生都上一樣的課,做一樣的作業,可偏偏有的學生各科成績都好,有的學生偏科,而有的學生各科成績都差。
- 各科成績都好的學生,說明他大腦 Meta Learning 的能力強,可以迅速適應不同科目的學習任務。
- 而對於偏科的學生,他們大腦的 Meta Learning 能力就相對弱一些,只能學習某項具體的任務,換個任務就不 work 了。對這種學生,老師的建議一般是:「在弱科上多花一點時間」,可這麼做是有風險的,最糟糕的一種情況是:弱勢科目沒學好,強勢科目成績反而下降了。可以看到,現如今大多數深度神經網路都是「偏科生」,且不說分類、回歸這樣差別較大的任務對應的網路模型完全不同,即使同樣是分類任務,把人臉識別網路架構用在分類 ImageNet 數據上,就未必能達到很高的準確率。
- 至於各科成績都差的學生,說明他們不但 Meta Learning 能力弱,在任何科目上的學習能力都弱,需要被老師重點關照……
Meta Learning 的演算法有很多,有些高大上的演算法可以針對不同的訓練任務,輸出不同的神經網路結構和超參數,例如 Neural Architecture Search (NAS) 和 AutoML。這些演算法大多都相當複雜,我們普通人難以實現。另外一種比較容易實現的 Meta Learning 演算法,就是本文要介紹的 MAML 和 Reptile,它們不改變深度神經網路的結構,只改變網路的初始化參數。
從網路參數初始化談起
訓練神經網路的第一步是初始化參數。當前大多數深度學習框架都收錄了不同的參數初始化方法,例如均勻分布、正太分布,或者用 xavier_uniform
,kaiming_uniform
,xavier_normal
,kaiming_normal
等演算法。除了用隨機數,也可以用預訓練的網路參數來初始化神經網路,也就是所謂 transfer learning,或者更準確地說是 fine-tuning 的技術。如果不了解 fine-tuning 技術,可以閱讀這篇博客:Building powerful image classification models using very little data,它通過微調預訓練的 VGG-16 網路,用較少的數據訓練了一個高精度的貓狗分類器(這是我當年跑通的第一個 deep learning 演算法,從此走上煉丹的不歸路)。fine-tuning 之所以能 work,是因為預訓練的神經網路本身就有很強的特徵提取能力,能夠提取很多有含義的特徵,例如毛皮,耳朵,鼻子,眼睛,分辨貓狗,只需要知道這些特徵是如何組合的就好了,這比從頭開始學習如何提取毛皮、耳朵、鼻子等特徵要高效得多。
利用預訓練的網路進行參數初始化,相當於賦予了網路很多先驗知識。類比我們人類,讓一個小學沒畢業的人去聽高等數學,顯然他是無法聽懂的;而讓一個高考數學滿分的高中畢業生去聽,他可能要學得輕鬆得多。如果忽略智商因素,我們人類的大腦從結構上說都是大同小異,為啥表現差別那麼大呢?因為它們積累的知識量不同,後者積累的知識更多,也就是常說的「基礎紮實」,換成神經網路的術語,就是後者的網路只需要 fine-tune 一下就好了,而前者需要 train from scratch ,要補很多課才行。
通過上面的例子我們發現,預訓練的網路比隨機初始化的網路有更強的學習能力,因此 fine-tuning 也算是一種 Meta Learning 的演算法。它和我們今天要介紹的 MAML 以及 Reptile 都是通過初始化網路參數,使神經網路獲得更強的學習能力,從而在少量數據上訓練後就能有很好的性能。
訓練數據:以 task 為基本單位
對於傳統的機器學習問題,每個模型通常只用來解決一個任務,要麼是人臉識別,要麼是物體分類,要麼是物體檢測,等等。即使有多輸出的網路,例如同時檢測人臉的位置和關鍵點,本質上其訓練任務(task)還是一個,換一個任務又要從頭開始訓練(不考慮 fine-tuning 的技術)。而 MAML 專註於提升模型整體的學習能力,而不是解決某個具體問題的能力,因此,它的訓練數據是以 task 為基本單位的,每個 task 都有自己獨立的損失函數。訓練時,不停地在不同的 task 上切換,從而達到初始化網路參數的目的,最終得到的模型,面對新的 task 時可以學習得更快。
還拿學生的例子類比,Meta Learning 相當於讓學生學習多門功課,比如第一節數學,第二節英語,第三節歷史,等等……每門功課是一個 task。聽說山東高考要考一門「基本能力測試」,這門課會學很多特別雜的知識(類比既檢測人臉的位置,又檢測人臉關鍵點坐標的網路),但這歸根結底只是一門課而已,並不能因為它學的知識雜就變成了好幾門課。
下圖是李宏毅老師在 PPT 中給的例子,Task 1 是貓狗分類任務,而 Task 2 是蘋果橘子分類任務。Learning Algorithm F 即 Meta Learning 的演算法,它在 Task 1 上經過訓練,吐出一個分類器 用於分類貓和狗,在測試集上算出的損失為
,在 Task 2 上又吐出一個分類器
,損失函數是
。

我們始終要牢記,Meta Learning 最終的目的是要讓模型獲得一個良好的初始化參數。這個初始化參數在訓練 task 上表現或許並不出色,但以這個參數為起點,去學習新的 task 時,學得會又快又好。而普通的 Learning,則是著眼於解決當前的 task,不會考慮如何面對新的 task。
損失函數的奧妙:初始化參數掌控全場,分任務參數各自為營
李宏毅老師的 PPT 給出了 MAML 損失函數 (1) 和普通訓練 (2) 的損失函數的對比。其中 代表網路的基礎參數,
代表基於網路基礎參數學習到的第 n 個分任務的參數,
代表以
為參數的分任務 n 的損失函數。
MAML 中,不同分任務對應一套不同的參數,它們都是基於網路參數 通過梯度下降得到的。這些參數和初始化參數
很像,但又不完全一樣,因為
是最終我們需要的初始化參數,它不能受制於某個分任務,而是要掌控全場。
而普通的學習任務,如公式 (2) 不同的分任務對應了同一套參數 ,這樣得到的
將是照顧到所有分任務的全局最優解。但在當前所有訓練 task 上的全局最優解,一定是好的初始化值么?下面兩圖就是反例。

上圖中橫坐標代表網路參數,縱坐標代表損失函數。淺綠和墨綠兩條曲線代表兩個 task 的損失函數隨參數變化曲線。我們發現如果用公式 (2),我們將得到左圖的效果:在靠近墨綠色曲線的附近,找到可以使兩個任務損失函數之和最小的參數 。但是如果用它作為初始化參數,在 task 2 上訓練,我們最終將收斂到 task 2 左側的 local minima,而不是其右側更小的 global minima。而使用公式 (1) 時,
會掌控全場,它不太在意訓練集上的損失,而是著眼於最大化網路的學習能力。如右圖所示,這個
就是很好的初始化參數,往左往右,皆可到達不同任務的 global minima。即:左側圖片得到的全局最優解,並不是模型學習能力最強的地方,也就是說,用這個全局最優解初始化網路參數,再在新任務上做 fine-tuning,最終得到的模型可能不會收斂到新任務的最優解。這就是 MAML 優於 fine-tuning 的地方。
再舉個例子,最近教育界常常講通識教育,素質教育,不能唯成績論,防止學生變成做題考試的機器。這裡面就暗合了 Meta Learning 的道理。我們在學校中學了很多門課,傻子都知道以後工作了並不會門門都用得上,但我們還是得學它們。因為我們要通過學習這些看似無用的課,獲取一種寶貴的能力:那就是學習新知識的能力。我們傳統的教育,認為門門功課考高分就是優秀,導致產生一些死讀書,讀死書的學生,最後人生髮展並不是很好。現在教育改革的目標,就是教會學生自主學習,著眼於學生未來的發展,而不糾結於把一個類型的題練八百遍,在規定時間按照規定步驟做出來。(至於是不是可行,以及如何通過高考公平地選拔人才,那是另一個話題了)。
評價標準
既然 Meta Learning 是 learn to learn,那麼如何證明 Meta Learning 演算法的有效性呢?顯而易見,只需要證明用這種演算法得到的網路模型學習能力很強就行了。具體到我們的 MAML 和 Reptile,只需要證明,用它們這些演算法初始化之後的神經網路,在新的任務上訓練,其收斂速率與準確率比從隨機初始化的神經網路要高。
這裡所謂「新任務」,一般是指難度比較大的任務,畢竟難度大的任務才有區分度嘛,要是都像 MNIST 數據那麼簡單,隨便一訓練就 99% 的準確率,也看不出網路初始化參數所起的作用了。因此一般用 few-shot learning 的任務來評估 Meta Learning 演算法的有效性。所謂 few-shot learning,就是指每類只有少量訓練數據的學習任務(MNIST 每個數字都有上萬張訓練圖片,因此不是 few-shot learning)。數據集 Omniglot:brendenlake/omniglot,是一個類似 MNIST 的手寫數據集,如下圖所示。該數據集包含 1623 類,每類只有 20 個訓練數據,因此它屬於 few-shot learning 的範疇,經常作為 benchmark 用來衡量 Meta Learning 演算法的性能。

Omniglot 的用法是這樣的,從其中採樣 N 個類,每個類有 K 個訓練兩本,組成一個訓練任務(task),稱為 N-ways K-shot classification。然後再從剩下的類中,繼續重複上一步的採樣,構建第二個 task,最終構建了 m 個 task。把這 m 個 task 分成訓練 task 和測試 task,在訓練 task 上訓練 Meta Learning 的演算法,然後再用測試 task 評估 Meta Learning 得到的演算法的學習能力。
MAML
上文中講了那麼多鋪墊,終於可以正式開講 MAML 演算法了。拋棄原始論文中複雜的符號表示,還是採用李宏毅老師 PPT 中直觀的圖片,左側為 MAML 演算法,右側為傳統的預訓練演算法。

代表網路初始化的參數,我們採樣一個任務(或者好幾個任務作為一個 batch,圖中顯示的是一個任務的情況),編號為 m。如左圖第一個綠色箭頭所示:基於
計算網路在任務 m 上的損失函數,然後用梯度下降法優化
,以 learning rate
得到任務 m 獨有的網路參數
;接下來,如左圖第二個綠色箭頭所示,基於
計算任務 m 新的損失函數,並求出損失函數在
上的梯度。我們不是用這個梯度優化
,而是優化最初的那個
,用 learning rate
乘以梯度,加到
上,得到
,如第左圖一個藍色箭頭所示,該箭頭和第二個綠色箭頭是平行的,代表
的更新方向為
處的梯度。用完了
之後,就把它扔掉了,下次再採樣到任務 m 時,再基於當時的
重新算一遍。(如果這裡的 batchsize 不是1,那麼在計算各個 task 獨有的參數時,還是單獨計算,但是計算
的更新時,需要把各個 task 在其對應的
上的梯度加在一起,再更新主任務的參數
。)
接下來和上一步一樣,再以 為新的網路初始化參數,採樣一個新的任務 n,第一步梯度下降用於更新得到 n 獨有的參數
,第二步梯度下降,用任務 n 的損失函數在
上的梯度,更新主網路參數
,從而得到
。以此類推,最終得到網路的初始化參數。
那麼,為什麼要先在採樣出來的任務上更新一步參數,然後再計算梯度呢?數學推導請參考李宏毅老師的PPT和教學視頻,這裡只說一下我的直觀理解:由於 learning rate 很小,其實在任務 m 上更新完了參數後,得到的 和
的差別也不大,因此這個梯度一定程度可以讓總任務的損失函數降低。但同時,畢竟
和
還是不一樣的,這避免了最終求到的
變成取得所有訓練任務損失函數之和的 global minima,從而變成了 model pre-training 那一套。
這個 jupyter notebook 用 pytorch 實現了 MAML 論文中的 toy example
https://github.com/AdrienLE/ANIML/blob/master/ANIML.ipynb :
該 toy example 的目標是擬合正弦曲線: ,其中 a、b 都是隨機數,每一組 a、b 對應一條正弦曲線,從該正弦曲線採樣 K 個點,用它們的橫縱坐標作為一組 task,橫坐標為神經網路的輸入,縱坐標為神經網路的輸出。我們希望通過在很多 task 上的學習,學到一組神經網路的初始化參數,再輸入測試 task 的 K 個點時,經過快速學習,神經網路能夠擬合測試 task 對應的正弦曲線。

左側是用常規的 fine-tune 演算法初始化神經網路參數。我們觀察發現,當把所有訓練 task 的損失函數之和作為總損失函數,來直接更新網路參數 ,會導致無論測試 task 輸入什麼坐標,預測的曲線始終是 0 附近的曲線,因為 a 和 b 可以任意設置,所以所有可能的正弦函數加起來,它們的期望值為 0,因此為了獲得所有訓練 task 損失函數之和的 global minima,不論什麼輸入坐標,神經網路都將輸出 0。
右側是通過 MAML 訓練的網路,把它作為起點,在新任務上訓練時,可以很好地擬合目標函數的曲線。
Reptile
Reptile 和 MAML 很像,它的演算法如下圖所示:

Reptile 中,每更新一次 ,需要 sample 一個 batch 的 task(圖中 batchsize=1),並在各個 task 上施加多次梯度下降,得到各個 task 對應的
。然後計算
和主任務的參數
的差向量,作為更新
的方向。這樣反覆迭代,最終得到全局的初始化參數。
總結
我們跳出複雜的公式,可以看到 MAML 和 Reptile 都做了一件事:在避免得到訓練 task global minima 的同時,降低訓練 task 的總損失函數。對總的損失函數,有點羞噠噠欲拒還迎的感覺。直觀可以這樣理解:如果直接用所有訓練 task 的損失函數之和對 的梯度更新它,讓損失函數取得 global minima(並不是真正的全局最優,而是「相對來說」的全局最優),會造成一種 task 層面的 overfitting,使得網路只能在訓練 task 上取得好成績,一旦換一個沒見過的 task,它就表現得不太好了。反之,用單一 task 的參數的梯度,作為模糊的梯度下降的大方向,大差不差、馬馬虎虎地降低訓練 task 總的損失函數,往往能獲得一個較好的初始化參數。以這個初始化參數為起點,在新的任務上稍加訓練,它反而能學得又快又好。
至於和 fine-tuning 預訓練網路的技術的區別,個人認為最明顯的是:預訓練神經網路時,往往有它自己的任務,並沒有考慮到將來別人拿這個網路做別的用途。所以它一心一意把自己的 task 搞好就萬事大吉了,至於它能用在別的任務中,純粹是無心插柳。而 Meta Learning 從一開始,就不以降低訓練 task 損失函數為唯一目標,它的目標是各個 task 都學一點,然後學一組有潛力的初始化參數,以備將來新的訓練任務。
推薦閱讀:
TAG:深度學習(DeepLearning) |