pytorch中的鉤子(Hook)有何作用?

最新開始用pytorch,感覺搭建模型和調試都比較方便,看了看文檔,對hook不是很理解,請教大神這個hook設計初衷是啥,一般在什麼場景下應用?


正好最近也在看,就做一迴文檔搬運工吧。

首先明確一點,有哪些hook?

我看到的有3個:

1. torch.autograd.Variable.register_hook (Python method, in Automatic differentiation package

2. torch.nn.Module.register_backward_hook (Python method, in torch.nn)

3. torch.nn.Module.register_forward_hook

第一個是register_hook,是針對Variable對象的,後面的兩個:register_backward_hook和register_forward_hook是針對nn.Module這個對象的。

其次,明確一下,為什麼需要用hook

打個比方,有這麼個函數, xin mathbb{R}^2y=x+2z = frac{1}{2}(y_1^2+y_2^2) 你想通過梯度下降法求最小值。在PyTorch裡面很容易實現,你只需要:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
z.backward()
x.data -= lr*x.grad.data

但問題是,如果我想要求中間變數 y的梯度,系統會返回錯誤。

事實上,如果你輸入:

type(y.grad)

系統會告訴你:NoneType

這個問題在PyTorch的論壇上有人提問過,開發者說是因為當初開發時設計的是,對於中間變數,一旦它們完成了自身反傳的使命,就會被釋放掉。

因此,hook就派上用場了。簡而言之,register_hook的作用是,當反傳時,除了完成原有的反傳,額外多完成一些任務。你可以定義一個中間變數的hook,將它的grad值列印出來,當然你也可以定義一個全局列表,將每次的grad值添加到裡面去。

import torch
from torch.autograd import Variable

grad_list = []

def print_grad(grad):
grad_list.append(grad)

x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data

需要注意的是,register_hook函數接收的是一個函數,這個函數有如下的形式:

hook(grad) -&> Variable or None

也就是說,這個函數是擁有改變梯度值的威力的!

至於register_forward_hook和register_backward_hook的用法和這個大同小異。只不過對象從Variable改成了你自己定義的nn.Module。

當你訓練一個網路,想要提取中間層的參數、或者特徵圖的時候,使用hook就能派上用場了。

參考資料:

1. Why cant I see .grad of an intermediate variable?

2. Extract feature maps from intermediate layers without modifying forward()


相當於插件。可以實現一些額外的功能,而又不用修改主體代碼。把這些額外功能實現了掛在主代碼上,所以叫鉤子,很形象。

Emacs也有同樣的概念。


你不理解的是hook對吧 你可以這麼理解 在一個完整的業務流程當中不修改代碼而能插入業務的一種方式


推薦閱讀:

TAG:機器學習 | 深度學習DeepLearning | Torch深度學習框架 | PyTorch |