batch normalization的multi-GPU版本該怎麼實現?

batch normalization中 的multi-GPU實現涉及到大量GPU間通信,這時效率就會很慢。當前各個平台(caffe, torch)的batch normalization的實現在multi-GPU情況下是否只考慮了單個GPU上的均值與方差?


我們維護的Caffe版本里有可以參考的實現

https://github.com/yjxiong/caffe/blob/action_recog/src/caffe/layers/sync_bn_layer.cpp

這個東西主要用在單個樣本較大導致每個GPU的batchsize特別小的時候。這時用SyncBN有可能使訓練更穩定。


終於搞定了Synchronized BatchNorm,來回答一下這個問題。

首先針對問題本身,目前所有的framework,包括Caffe,Torch,TF,PyTroch等等,BatchNorm的實現都是只考慮了single gpu。也就是說BN使用的均值和標準差是單個gpu算的,相當於縮小了mini-batch size。至於為什麼這樣實現,1)因為沒有sync的需求,因為對於大多數vision問題,單gpu上的mini-batch已經夠大了,完全不會影響結果。2)影響訓練速度,BN layer通常是在網路結構裡面廣泛使用的,這樣每次都同步一下GPUs,十分影響訓練速度。

關於如何實現,這裡不得不寫一下心得,從發現問題到實現synchronize,花了將近一周的時間,一開始把這個問題想簡單了,希望對大家有所幫助吧。偶然中發現自己做的task十分依賴synchronzie BN (因為在公司實習, 不方便說project)。然後開始實現:

1。 因為通常訓練過程是將網路複製到不同的gpu上,然後進行forward和backward,之後只需要collect gradient,再更新主gpu上的網路,然後下一個iteration再複製一遍。在這種方式下,不同gpu的BN layer是沒辦法通信的。所以第一步是重做了DataParallel架構,可以enable 每層的通信。

2。 第二步就是推推公式,相信各位大神花點兒時間動動筆就搞定,我就不贅述了。然後開始編程寫kernel,核心思想是每次forward 時候sync一下均值和方差,然後backward的時候sync一些grad of 均值和方差,再繼續backward。

3。本身前兩步搞定之後,應該一切順利了,可是訓練的時候總是跑幾十個iteration就卡死。這個其實是pytorch autograd engine 的問題, 因為每個BN layer的均值和方差都是cross gpu 的grad graph,而我們又是大量使用BN,所以成個back-prop的graph破壞了pytorch grad engine。解決方案是寫一個cross gpu的autograd function來handle。

大體思路是這樣的,可能發paper的時候再release。不過大部分task應該不涉及這個題。


低調地update一下,我已經公開了這個PyTorch implementation, 大家可以先用,我過幾天放個arxiv paper,鏈接: Synchronized BatchNorm


暫時主流的做法不是 Per-Node / Per-Card 的做么?畢竟這裡 Sync 的 Cost 太大了,最後可以 Post BN 的算回。據我所知,CNTK,Torch 和 Tensorflow,Neon 也都類似。我覺得還有部分原因是因為大多現在主流框架都利用 cuDNN 去計算了,相對會更不好處理。

這樣也能保證每次 BN 都是 32 BatchSize 的算,跟原 Paper 一樣,以我個人有限的實驗經驗來看,Per-Node 的 BN Batch-Size 更大(總 Batch-Size 不變)結果不小的變化。

當然我主要工作是做驗證的。這部分訓練的時候,該 Sync 不該 Sync 我覺得還是結果說話的,我覺得值得一試,如果確實結果提升大,那確實要考慮改了,如果不大,畢竟現在也 Work,效率也高得多。

當然對於部分 Task,比如 Action Recognition,這麼做確實有特殊的必要,畢竟如果 3D-Conv 的網路,Per-Node 也沒多少個 Sample,受限於顯存大小。

我覺得非得這麼樣 Sync 的算,從優化的角度主要也還是節點通信的 Cost,這個又回到了老問題上面,顯卡間就是 NCCL 那種策略,或者最簡單的 Tree-based 相鄰節點的 Sync (先同一個 PCI-E Switch 的,再同一個 CPU Node 的,然後 QPI 的,再 IB 的)

附上一些鏈接:

[1] Batch Normalization for Multi-GPU / Data Parallelism · Issue #7439 · tensorflow/tensorflow

[2] Distributed Implementations using multiple GPUs


用SERelu


推薦閱讀:

做圖像檢索,圖像庫從哪兒能下載到?
沐神的第三代parameter server的worker節點只需要保存部分參數,怎麼理解?
目前世界上有對強人工智慧的嘗試嗎?具體瓶頸是什麼?
目前火熱的Deep Learning會滅絕傳統的SIFT/ SURF的特徵提取的演算法嗎?
目前主流的attention方法都有哪些?

TAG:機器學習 | 深度學習DeepLearning | Caffe深度學習框架 | TensorFlow | Torch深度學習框架 |