PyTorch 的 backward 為什麼有一個 grad_variables 參數?

TL; DR; 假設 x 經過一番計算得到 y,那麼 y.backward(w) 求的不是 y 對 x 的導數,而是 l = torch.sum(y*w) 對 x 的導數。w 可以視為 y 的各分量的權重,也可以視為遙遠的損失函數 l 對 y 的偏導數(這正是函數說明文檔的含義)。特別地,若 y 為標量,w 取默認值 1.0,才是按照我們通常理解的那樣,求 y 對 x 的導數。

自從 autograd 模塊被合併到 PyTorch master 分支以後,PyTorch 變得越來越得心應手,成為很多人的首選。在 PyTorch 教程 Autograd: automatic differentiation 里提到,torch.autograd.backward() 函數需要一個 grad_output 參數(此處疑為筆誤,根據文檔描述,torch.autograd.backward() 的參數應該是 grad_variables,函數 torch.autograd.grad() 的參數才是 grad_output)。如果是對一個標量進行反向傳播,那麼這個參數可以省略(預設值為 1.0)。

所以,這個參數到底是幹嘛的呢?

這個函數的原型是:

torch.autograd.backward(variables, grad_variables=None, retain_graph=None, create_graph=None, retain_variables=None)

文檔里的介紹是:

The graph is differentiated using the chain rule. If any of variables are non-scalar (i.e. their data has more than one element) and require gradient, the function additionally requires specifying grad_variables. It should be a sequence of matching length, that contains gradient of the differentiated function w.r.t. corresponding variables (None is an acceptable value for all variables that don』t need gradient tensors).

variables 和 grad_variables 都可以是 sequence,不過平常也不太有一串變數對另一串變數求導這種需求:如果有這種需求的話,自己寫個循環就行了;像 PyTorch 的這個介面,以及 TensorFlow 里提供的求導介面,雖然可以傳一堆 x 和一堆 y 進去,但是返回的都是一堆 y 的和對各個 x 的導數,這樣一來這個介面的用法就顯得很奇怪。反倒不如定義這一堆 y 的和為 z,然後求 z 對各個 x 的導數更加自然。

事實上,TF 和 PyTorch 這麼設計不是沒有原因的。原因就是: Tensor 沒法對 Tensor 求導!舉一個簡單的例子,如果要求一個 Tensor 對另一個 Tensor 的導數,先考慮矩陣對矩陣這種情形:假設 m*n 的矩陣 x 經過運算得到了 p*q 的矩陣 y,y 又經過運算得到了 s*t 的矩陣 z,那麼:dz/dy 是啥?假設可以求導,那麼得到的應該是四階張量吧,形狀是 s*t*p*q?好的,dy/dx 再算一下,得到一個四階張量 p*q*m*n。然後怎麼反向傳播?dz/dx = dz/dy * dy/dx = 兩個四階張量相乘???這要怎麼乘???當然,也不是說絕對不行,仔細思考一下可以把這個問題解決掉,在長度為 p 和 q 的那兩個維度上求個和就行,但是想一想無窮無盡的運算組合方式,怎麼寫一個足夠 robust 的反向傳播?就算你能解決兩個四維 Tensor 怎麼乘的問題,Tensor 對標量 Scalar 的導數又是啥?四維和三維的 Tensor 又怎麼乘?導數的導數又怎麼求,搞一個六階還是八階張量做中間結果?這一連串的問題,感覺要瘋掉……

一個簡單的解決方案就是:

1、不允許 Tensor 對 Tensor 求導,只允許標量 Scalar 對張量 Tensor 求導,求導結果是和自變數同型的 Tensor。

2、在求 dl/dx 的時候(l 是標量,x 是張量),假設有一個中間結果為張量 y,即 x->y->l,那麼先求 dl/dy(結果是良定義的、和 y 同型的 Tensor),然後根據 x 和 dl/dy 想辦法直接算出 dl/dx,跳過 dy/dx 是啥這種玄學問題!(這種問題在推 MLP 的反向傳播時也能遇到,解決辦法就是跳過它!)

然後再回到 PyTorch 的設計上來, backward() 為啥還需要一個額外的參數?就是為了避免 Tensor 對 Tensor 求導結果是啥這種玄學問題!torch.autograd.backward(y, w), 或者說 y.backward(w) 的含義是:先計算 l = torch.sum(y * w),然後求 l 對(能夠影響到 y 的)所有變數 x 的導數。這裡,y 和 w 是同型 Tensor。也就是說,可以理解成先按照 w 對 y 的各個分量加權,加權求和之後得到真正的 loss,再計算這個 loss 對於所有相關變數的導數。

這麼設計有什麼好處呢?如前所述,這樣一來,所有求導操作都是求 Scalar 關於 Tensor 的導數,統一了起來,不存在 Tensor 對 Tensor 求導了。再回顧一下 PyTorch 自己的文檔,它說 torch.autograd.backward 的第二個參數 grad_variables 應該是第一個參數 variables 的對應的導數。

嗯??這是什麼情況??其實我上面的解釋是一致的。假設 y 和 w 是同型 Tensor,那麼 l = torch.sum(y*w) 對 y 的導數 dl/dy 就是 w。所以把這裡的 w 理解成 y 的各項的權重也好,或者理解成某個高高在上的虛擬 loss 對 y 的導數也好,其實是一樣的。事實上,l = torch.sum(y*w) 這個形式不正好是導數的定義么?數學分析一上來就學,微分是函數增量的線性主部,而在 l = torch.sum(y*w) 這個形式里,只有線性的項,因此 w 就是 dl/dy。

那為什麼標量就不需要這個參數呢?假設 y 是標量,然後取 w=1.0,那麼 l=torch.sum(y*w) 其實就是 y 本身。所以這時候,dl/dx = dy/dx,可以直接把 loss 和 y 混同,這也就是通常直接把損失函數 loss 當成 y 傳進去的原因。

原理大概講到這裡,寫個程序驗證一下吧?

import torchfrom torch.autograd import Variablex = Variable(torch.randn(3), requires_grad=True)y = Variable(torch.randn(3), requires_grad=True)z = Variable(torch.randn(3), requires_grad=True)print(x)print(y)print(z)t = x + yl = t.dot(z)

初始化的結果是:

# xVariable containing: 0.9168 1.3483 0.4293[torch.FloatTensor of size 3]# yVariable containing: 0.4982 0.7672 1.5884[torch.FloatTensor of size 3]# zVariable containing: 0.1352-0.4037-0.2425[torch.FloatTensor of size 3]

在調用 backward 之前,可以先手動求一下導數,應該是: l = (x+y)^Tz, dl/dx = dl/dy = z, dl/dz=x+y=t, dl/dt=z

不如求一下結果

l.backward(retain_graph=True)print(x.grad)print(y.grad) # x.grad = y.grad = zprint(z)print(z.grad) # z.grad = t = x + yprint(t)

結果確實符合預期,前三個結果一樣,後兩個結果一樣:

# 前三個 print 的結果:Variable containing: 0.1352-0.4037-0.2425[torch.FloatTensor of size 3]# 後兩個 print 的結果:Variable containing: 1.4151 2.1155 2.0177[torch.FloatTensor of size 3]

最後是本文的重點,把 z 作為 t.backward() 的參數會怎麼樣?按照前面手動求導的結果, dl/dt=z ,正好和 PyTorch 的文檔描述相符(variables 是 t,grad_variablesvariables 的導數 dl/dt),根據分析,運行結果應該是損失函數 l 對 x 和 y 的導數

x.grad.data.zero_()y.grad.data.zero_()z.grad.data.zero_()t.backward(z)print(x.grad)print(y.grad)

實際運行的結果果然和 z (也就是 dl/dx 或者 dl/dy)的值一樣:

# x.grad = dl/dx = zVariable containing: 0.1352-0.4037-0.2425[torch.FloatTensor of size 3]# y.grad = dl/dy = zVariable containing: 0.1352-0.4037-0.2425[torch.FloatTensor of size 3]

總結:

假設 x 經過一番計算得到 y,那麼 y.backward(w) 求的不是 y 對 x 的導數,而是 l = torch.sum(y*w) 對 x 的導數。w 可以視為 y 的各分量的權重,也可以視為遙遠的損失函數 l 對 y 的偏導數。也就是說,不一定需要從計算圖最後的節點 y 往前反向傳播,從中間某個節點 n 開始傳也可以,只要你能把損失函數 l 關於這個節點的導數 dl/dn 記錄下來,n.backward(dl/dn) 照樣能往前回傳,正確地計算出損失函數 l 對於節點 n 之前的節點的導數。特別地,若 y 為標量,w 取默認值 1.0,才是按照我們通常理解的那樣,求 y 對 x 的導數。


推薦閱讀:

CS 294: Deep Reinforcement Learning Note(1)
淺析感知機(一)--模型與學習策略
人工智慧公開課

TAG:PyTorch | 深度学习DeepLearning | 机器学习 |