pytorch的計算圖
12 人贊了文章
前言:
接觸pytorch這幾個月來,一開始就對計算圖的奧妙模糊不清,不知道其內部如何傳播。這幾天有點時間,就去翻閱了Github,pytorch Forum,還有很多個人博客(後面會給出鏈接),再加上自己的原本一些見解,現在對它的計算圖有了更深層次的理解。
pytorch是非常好用和容易上手的深度學習框架,因為它所構建的是動態圖,極大的方便了coding and debug。可是對於初學者而言,計算圖是一個需要深刻理解的概念,在後期的搭建的神經網路都是基於計算圖而設計的。
一、構建計算圖
pytorch是動態圖機制,所以在訓練模型時候,每迭代一次都會構建一個新的計算圖。而計算圖其實就是代表程序中變數之間的關係。舉個列子: 在這個運算過程就會建立一個如下的計算圖:
在這個計算圖中,節點就是參與運算的變數,在pytorch中是用Variable()變數來包裝的,而圖中的邊就是變數之間的運算關係,比如:torch.mul(),torch.mm(),torch.div() 等等。
注意圖中的 leaf_node,葉子結點就是由用戶自己創建的Variable變數,在這個圖中僅有a,b,c 是 leaf_node。為什麼要關注leaf_node?因為在網路backward時候,需要用鏈式求導法則求出網路最後輸出的梯度,然後再對網路進行優化,如下就是網路的求導過程。
二、圖的細節。
pytoch構建的計算圖是動態圖,為了節約內存,所以每次一輪迭代完之後計算圖就被在內存釋放,所以當你想要多次backward時候就會報如下錯:
net = nn.Linear(3, 4) # 一層的網路,也可以算是一個計算圖就構建好了input = Variable(torch.randn(2, 3), requires_grad=True) # 定義一個圖的輸入變數output = net(input) # 最後的輸出loss = torch.sum(output) # 這邊加了一個sum() ,因為被backward只能是標量loss.backward() # 到這計算圖已經結束,計算圖被釋放了
上面這個程序是能夠正常運行的,但是下面就會報錯
net = nn.Linear(3, 4)input = Variable(torch.randn(2, 3), requires_grad=True)output = net(input)loss = torch.sum(output)loss.backward()loss.backward()RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.
之所以會報這個錯,因為計算圖在內存中已經被釋放。但是,如果你需要多次backward只需要在第一次反向傳播時候添加一個標識,如下:
net = nn.Linear(3, 4)input = Variable(torch.randn(2, 3), requires_grad=True)output = net(input)loss = torch.sum(output)loss.backward(retain_graph=True) # 添加retain_graph=True標識,讓計算圖不被立即釋放loss.backward()
這樣在第一次backward之後,計算圖並不會被立即釋放。
讀到這裡,可能你對計算圖中的backward還是一知半解。例如上面提過backward只能是標量。那麼在實際運用中,如果我們只需要求圖中某一節點的梯度,而不是整個圖的,又該如何做呢?下面舉個例子,列子下面會給出解釋。
x = Variable(torch.FloatTensor([[1, 2]]), requires_grad=True) # 定義一個輸入變數y = Variable(torch.FloatTensor([[3, 4], [5, 6]]))loss = torch.mm(x, y) # 變數之間的運算loss.backward(torch.FloatTensor([[1, 0]]), retain_graph=True) # 求梯度,保留圖 print(x.grad.data) # 求出 x_1 的梯度x.grad.data.zero_() # 最後的梯度會累加到葉節點,所以葉節點清零loss.backward(torch.FloatTensor([[0, 1]])) # 求出 x_2的梯度print(x.grad.data) # 求出 x_2的梯度
結果如下:
3 5[torch.FloatTensor of size 1x2] 4 6[torch.FloatTensor of size 1x2]
可能看到上面例子有點懵,用數學表達式形式解釋一下,上面程序等價於下面的數學表達式:
這樣我們就很容易利用backward得到一個雅克比行列式:
到這裡應該對pytorch的計算圖和backward有一定了解了吧。
如有錯誤,歡迎指正。
References:
Calculus on Computational Graphs: Backpropagation
Computational Graphs in PyTorch
PyTorch中的backward
推薦閱讀:
※層次化是長序列的未來 Hierarchical Attention Networks for Document Classification
※深度神經網路(DNN)的訓練過程
※復盤 AI Challenger 場景分類
※神經網路正則化(1):L1/L2正則化
※<EYD與機器學習>十三 CNN
TAG:PyTorch | 神經網路 | 深度學習DeepLearning |