TensorFlow分散式訓練加速之梯度壓縮

今年的NIPS出現「Imagenet is the new MNIST「口號,宣告使用MNIST數據集檢驗網路模型性能已經成為過去式。演算法工程師們早就意識到訓練數據集大小的重要性,並且進一步發現,針對特定的模型大小,訓練數據集的大小和泛化誤差之間存在一下的關係[1]:

訓練數據集大小必須跨越Power-law Region,才能得到網路模型的實際性能。

網路模型的大小(表現為網路模型的參數數量)越來越大,處理單個sample需要消耗更多的單精度運算。例如,Resnet-50處理一張225x225的圖片需要消耗7.72	imes10^{9} 次單精度運算,如果在Imagenet上訓練90epochs,需要消耗 90	imes1.28	imes 10^{6} 	imes 7.72	imes 10^{9}approx 10^{18} 次單精度運算。以Tesla K80的峰值單精度浮點性能5.6 Tflops估計,理想情況下也要訓練50多個小時。

雖然網路模型參數有很多冗餘,可以使用模型壓縮演算法減少冗餘,提高推理速度,但目前還沒有演算法能減少訓練期間的網路模型參數,無法提高訓練速度。使用較少模型參數的網路訓練,只能得到性能差很多的模型。

目前唯一可以顯著提升訓練效率的方式是使用Large Batch Size進行大規模分散式訓練。

許多實驗表明,通訊開銷是大規模分散式訓練的瓶頸。以Data-Parallize訓練Restnet-50為例,Resnet-50參數大小約100MB(25.6million * 4B),在Tesla K80上用TensorFlow訓練,每次迭代需要181.404ms[2],這意味著必須要有550MB/s以上的帶寬,才能避免帶寬飽和。如果考慮VGG-16(138million * 4B),需要2965MB/s以上帶寬,才能避免帶寬飽和。

幸運的是只需要訓練3~4epoches,梯度的稀疏度就能達到99.9%,就能將梯度壓縮270x到600x而不損失精度[3]。例如將Resnet-50梯度從97MB壓縮到0.35MB,即使在最普通的1Gbps乙太網上跑大規模分散式訓練,也不會出現帶寬飽和。

梯度壓縮減少通訊開銷[3]

梯度壓縮原理

當梯度達到99.9%稀疏度,只有絕對值最大的0.1%梯度發送到參數伺服器。如何找出絕對值最大的0.1%梯度呢?可以使用top-k selection演算法找到threshold。top-k selection的時間複雜度是O(N)。為了提高速度,可以先random sample少量梯度,再top-k selection演算法找一個近似的threshold。如果大於threshold的梯度超過0.1%,可以在結果集中再次使用top-k selection演算法。未超過threshold的梯度會累積到下次迭代的梯度中。

梯度壓縮能極大減少網路開銷,但會影響收斂,導致模型精度降低。以下四個方面可以減少這種影響,達到不損失精度。

一,如圖是在使用Momentum優化器訓練中,使用梯度壓縮(C)導致和原優化演算法(B)路徑分離。

本地梯度累積導致精度降低[3]

通過如下修改Momentum優化演算法:

即給本地累積梯度也apply momentum,可以使梯度壓縮(C)和原優化演算法(B)路徑一致。

這裡有點難理解,推導如下:

假設第i個梯度從直到t-1次迭代才超過threshold,觸發更新,此時

u^{(i)}_{t-1} = m^{t-2} g^{(i)}_{1}+ … + m g^{(i)}_{t-2} + g^{(i)}_{t-1}

v^{(i)}_{t-1} = (1+…+m^{t-2}) g^{(i)}_{1} + … + (1+m) g^{(i)}_{t-2} + g^{(i)}_{t-1}

更新權重 w^{(i)}_{t} = w^{(i)}_{1} – lr 	imes v^{(i)}_{t-1}

更新後需要執行 v^{(i)}_{t-1} = 0

如果第t次迭代,觸發更新,此時

u^{(i)}_{t} = m^{t-1} g^{(i)}_{1} + … + m g^{(i)}_{t-1} + g^{(i)}_{t}

v^{(i)}_{t} = m^{t-1} g^{(i)}_{1} + … + m g^{(i)}_{t-1} + g^{(i)}_{t} w^{(i)}_{t+1} = w^{(i)}_{t} – lr 	imes v^{(i)}_{t}

= w^{(i)}_{1} – lr 	imes (v^{(i)}_{t-1} + v^{(i)}_{t} )

= w^{(i)}_{1} - lr 	imes[ (1+…+m^{t-1}) g^{(i)}_{1} + … + (1+m) g^{(i)}_{t-1} + g^{(i)}_{t} ]

所以本地梯度累積跟原優化演算法一致。

二,梯度壓縮可能導致梯度爆炸問題,可以在每個節點對本地梯度使用梯度裁剪演算法。

三,對大於threshold的梯度,將Momentum清零。

根據[4]中結論,async SGD會產生一個implicit momentum,導致收斂變慢。本地梯度累計跟async SGD存在相似性:未能及時更新梯度產生staleness。[4]中通過grid search的方法發現negative momentum能一定程度抵消implicit momentum效果,提高收斂速度。

這裡 u_{k,t} leftarrow u_{k,t}odot 
eg Mask 將Momentum清零,類似negative momentum。

四,在梯度達到99.9%稀疏度前,有個warm-up stage。

演算法步驟如下[3]

梯度壓縮的TensorFlow實現

在tensorflow/python/training/momentum.py中,為每個variable增加一個residual

def _create_slots(self, var_list): for v in var_list: self._zeros_slot(v, "momentum", self._name)+ for v in var_list:+ self._zeros_slot(v, "residual", self._name)

在tensorflow/core/kernels/training_ops.cc中,獲取residual,並apply momentum

+Tensor residual;+OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 5, use_exclusive_lock_, &residual));functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), grad.flat<T>(),+ momentum.scalar<T>(), residual.flat<T>(),+ use_nesterov_,+ ctx, steps);

在tensorflow/core/kernels/training_ops_gpu.cu.cc中,實現梯度壓縮

if (use_nesterov) { var.device(d) = grad * lr.reshape(single).broadcast(bcast) + accum * momentum.reshape(single).broadcast(bcast) *+ lr.reshape(single).broadcast(bcast) ++ residual * lr.reshape(single).broadcast(bcast); } else {+ var.device(d) = lr.reshape(single).broadcast(bcast) * accum ++ residual * lr.reshape(single).broadcast(bcast); }+ if (steps < WARMUP) {+ return ctx->ps()->update(var.name(), var.data(), d.stream());+ }+ Tensor norm(DataTypeToEnum<T>::value, {});+ norm.scalar<T>().device(d) = var.square().sum().sqrt();+ Tensor threshold(DataTypeToEnum<T>::value, {});+ threshold.scalar<T>()() = 6.0;+ if (norm.scalar<T>()() > threshold.scalar<T>()()) {+ var.device(d) = var * threshold.scalar<T>().reshape(single).broadcast(bcast) /+ norm.scalar<T>().reshape(single).broadcast(bcast);+ }+ int size = var.size();+ int capacity = size * 1.2 * (1.0 - 0.99);+ Tensor sparse_buf;+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(+ DataTypeToEnum<T>::value,+ TensorShape({static_cast<int64>(capacity * sizeof(T))}), &sparse_buf));+ T* buf = sparse_buf.template flat<T>().data();+ Tensor sparse_indice;+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(+ DataTypeToEnum<T>::value,+ TensorShape({static_cast<int64>(capacity * sizeof(T))}), &sparse_indice));+ T* indices = sparse_indice.template flat<T>().data();+ T* data = var.data();+ T* tmp = residual.data();+ int sortSize = std::min(100000, size);+ CudaLaunchConfig config2 = GetCudaLaunchConfig(sortSize, d);+ sampling<T>+ <<<config2.block_count, config2.thread_per_block, 0, d.stream()>>>+ (data, tmp, sortSize, size / sortSize, size);+ thrust::device_ptr<T> dev_data_ptr(tmp);+ thrust::sort(dev_data_ptr, dev_data_ptr + sortSize);+ float rate = 0.99;+ T threshold;+ int k_index = std::max(0, (int)(sortSize * rate) - 1);+ cudaMemcpy(&threshold, tmp + k_index, sizeof(T), cudaMemcpyDeviceToHost);+ CudaLaunchConfig config = GetCudaLaunchConfig(size, d);+ gen_mask<T>+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>+ (data, tmp, accum.data(), threshold, size);+ thrust::device_ptr<T> mask_ptr(tmp);+ thrust::inclusive_scan(mask_ptr, mask_ptr + size, mask_ptr);+ T sum;+ cudaMemcpy(&sum, tmp + size - 1, sizeof(T), cudaMemcpyDeviceToHost);+ unsigned long sparse_size = sum;+ sparsify<T>+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>+ (data, tmp, buf, indices, size);+ assign_residual<T>+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>+ (data, tmp, size);+ cudaMemcpy(&data, buf, sizeof(T) * sparse_size, cudaMemcpyDeviceToDevice);+ cudaMemcpy(&data, indices, sizeof(T) * sparse_size, cudaMemcpyDeviceToDevice);+ cudaStreamSynchronize(d.stream());+ ctx->ps()->update(var.name(), data, sparse_size, d.stream());

這是最初實現代碼,僅供參考。

We are hiring! If you want to ask questions please send your resume to zuo.wang at sky-data.cn

Reference

[1] research.baidu.com/deep

[2] Benchmarking State-of-the-Art Deep Learning Software Tools

[3] Deep Gradient Compression:Reducing the Communication Bandwidth for Distributed Training

[4] Asynchrony begets momentum, with an application to deep learning


推薦閱讀:

識別漢字圖像的數據集
學習tensorflow庫
為什麼在windows下用不了tensorflow?
深度學習對話系統實戰篇--新版本chatbot代碼實現
tensorflow的自動求導具體是在哪部分代碼里實現的?

TAG:深度學習DeepLearning | TensorFlow |