PyTorch源碼淺析(四)

自動微分的算符庫THNN和THCUNN

在不考慮自動微分的引擎的情況下,實際上想要實現簡單的自動微分很簡單。只要將不同的算符實現為具有forward方法,和backward方法的類型就可以了,然後用某個引擎來控制調用的順序(或者說遍歷計算圖的順序)。

在這裡我先簡要地講一些計算圖(以下來自之前在安慶的統計物理workshop的slide,相關圖片參考了Cornell的cs5740,但是是我自己畫的,應該不用侵刪...)

首先一個計算圖定義是,一個具有如下性質的有向圖:

  • 邊:某個函數變數,或者說某個函數的依賴
  • 具有輸入邊的點:某個函數(或者說算符)
  • 有輸出邊的點:某個變數

舉個例子的話,對於一個簡單的表達式 x^T A x 可以表達為如下的計算圖

計算圖的求值分為前向傳播( forward propagation )和後向傳播( backward propagation ),分別用來求輸出值和梯度。大致的過程就是對葉子節點賦值,然後將節點上函數計算的結果作為值傳到下一個節點上,直到整個圖的節點都被遍歷過。這其中根據圖中的結構(比如有圈,loop)可以進行優化,tensorflow等框架就會對這些情況優化。這裡略去,下面根據幾個圖大致看一下這個過程:

(其實這個放slide時候是有動畫效果的,anyways,有空再做一個gif)

所以以上的結構使得我們有兩種方式去建立一個計算圖:

  • 靜態圖方案:
    • 先定義圖的結構,然後給葉子節點賦值(這也是tensorflow中placeholder的由來)
    • 然後根據葉子節點的賦值進行forward
  • 動態圖方案,在forward的同時建立圖的結構(也就是所謂的動態圖)

然後類似forward,可以用後向傳播算梯度

以上的過程,使得我們實際上只需要在代碼里對每個節點(算符)實現兩個方法forward和backward就能夠實現前向傳播和後向傳播。舉個THNN里的例子

void THNN_(Sigmoid_updateOutput)( THNNState *state, THTensor *input, THTensor *output){ THTensor_(sigmoid)(output, input);}

前面我們已經講過了,PyTorch的C代碼中,下劃線前面是類型名稱,後面是方法名稱,這裡updateOutput就是forward方法,而backward方法如下

void THNN_(Sigmoid_updateGradInput)( THNNState *state, THTensor *gradOutput, THTensor *gradInput, THTensor *output){ THNN_CHECK_NELEMENT(output, gradOutput); THTensor_(resizeAs)(gradInput, output); TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output, real z = *output_data; *gradInput_data = *gradOutput_data * (1. - z) * z; );}

但是對於有參數的節點例如linear層,convolution層,使用C語言管理參數是很麻煩的,我們希望這裡的實現更加乾淨,所以還需要另外一個方法單獨計算參數的梯度,在linear層等有參數的層里就額外有一個accGradParameters方法

void THNN_(Linear_accGradParameters)( THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *weight, THTensor *bias, THTensor *gradWeight, THTensor *gradBias, THTensor *addBuffer, accreal scale_)

這樣等到我們將其封裝到Python之後,就可以用Python的垃圾回收等功能去管理參數了。THCUNN中的實現基本一致,就是使用的是CUDA,會帶__host__, __device__之類的標示符。

下一篇講稀疏格式(CSC等)。

羅秀哲:PyTorch源碼淺析(目錄)?

zhuanlan.zhihu.com圖標
推薦閱讀:

PyTorch框架快速增長,兩個月來GitHub新標星數居第三
Pytorch筆記02-torch.nn以及torch.optim
PyTorch教程+代碼:色塊秒變風景油畫
起名字這個技術活,終於用Pytorch找到解決辦法了!

TAG:PyTorch | 深度學習DeepLearning |