pytorch中的鉤子(Hook)有何作用?
最新開始用pytorch,感覺搭建模型和調試都比較方便,看了看文檔,對hook不是很理解,請教大神這個hook設計初衷是啥,一般在什麼場景下應用?
正好最近也在看,就做一迴文檔搬運工吧。
首先明確一點,有哪些hook?
我看到的有3個:
1. torch.autograd.Variable.register_hook (Python method, in Automatic differentiation package
2. torch.nn.Module.register_backward_hook (Python method, in torch.nn)
3. torch.nn.Module.register_forward_hook
第一個是register_hook,是針對Variable對象的,後面的兩個:register_backward_hook和register_forward_hook是針對nn.Module這個對象的。
其次,明確一下,為什麼需要用hook
打個比方,有這麼個函數, , , 你想通過梯度下降法求最小值。在PyTorch裡面很容易實現,你只需要:
import torch
from torch.autograd import Variable
x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
z.backward()
x.data -= lr*x.grad.data
但問題是,如果我想要求中間變數 的梯度,系統會返回錯誤。
事實上,如果你輸入:
type(y.grad)
系統會告訴你:NoneType
這個問題在PyTorch的論壇上有人提問過,開發者說是因為當初開發時設計的是,對於中間變數,一旦它們完成了自身反傳的使命,就會被釋放掉。
因此,hook就派上用場了。簡而言之,register_hook的作用是,當反傳時,除了完成原有的反傳,額外多完成一些任務。你可以定義一個中間變數的hook,將它的grad值列印出來,當然你也可以定義一個全局列表,將每次的grad值添加到裡面去。
import torch
from torch.autograd import Variable
grad_list = []
def print_grad(grad):
grad_list.append(grad)
x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data
需要注意的是,register_hook函數接收的是一個函數,這個函數有如下的形式:
hook(grad) -&> Variable or None
也就是說,這個函數是擁有改變梯度值的威力的!
至於register_forward_hook和register_backward_hook的用法和這個大同小異。只不過對象從Variable改成了你自己定義的nn.Module。
當你訓練一個網路,想要提取中間層的參數、或者特徵圖的時候,使用hook就能派上用場了。
參考資料:
1. Why cant I see .grad of an intermediate variable?
2. Extract feature maps from intermediate layers without modifying forward()
相當於插件。可以實現一些額外的功能,而又不用修改主體代碼。把這些額外功能實現了掛在主代碼上,所以叫鉤子,很形象。Emacs也有同樣的概念。
你不理解的是hook對吧 你可以這麼理解 在一個完整的業務流程當中不修改代碼而能插入業務的一種方式
推薦閱讀:
TAG:機器學習 | 深度學習DeepLearning | Torch深度學習框架 | PyTorch |