圖像識別:基於位置的柔性注意力機制

第一個基於位置的柔性注意力機制是NIPS 2015上DeepMind團隊提出的空間變換網路(Spatial Transformer Networks,STN),用在CNN模型當中幫助圖像識別。對於輸入的特徵圖,對識別圖像最有幫助的其中一部分可能不是規則(如矩形)的形狀,所以需要做不規則採樣。基於此,STN以整個特徵圖為輸入,特徵圖經過變換後生成新的特徵圖。典型的STN包含三部分:定位網路、網格生成器、採樣器,如圖所示。

空間變換網路。

令輸入 X_{	ext{in}} 的尺寸為 U_{	ext{in}}	imes V_{	ext{in}}	imes Q_{	ext{in}} ,輸出 X_{	ext{out}} 的尺寸為 U_{	ext{out}}	imes V_{	ext{out}}	imes Q_{	ext{out}} ,其中 UVQ 分別表示特徵圖的高度、寬度和通道數。比如 X_{	ext{in}} 是彩色圖片時,它的尺寸就是 U_{	ext{in}}	imes V_{	ext{in}}	imes 3 。STN同步在每個通道上起作用,所以通道數不會改變, Q_{	ext{out}}=Q_{	ext{in}} ,方便起見,後面的敘述會省略掉通道這個維度。

定位網路 phi_{W_{	ext{loc}}}X_{	ext{in}} 為輸入,生成特徵圖變換參數 A

A= phi_{W_{	ext{loc}}}(X_{	ext{in}}).

令網格 G 表示輸出特徵圖 X_{	ext{out}} 的所有像素的位置,也就是

G = {G_i} = left{ (x_{1,1}^{X_{	ext{out}}}, y_{1,1}^{X_{	ext{out}}}), (x_{1,2}^{X_{	ext{out}}}, y_{1,2}^{X_{	ext{out}}}), cdots, (x_{U_{	ext{out}},V_{	ext{out}}}^{X_{	ext{out}}}, y_{U_{	ext{out}},V_{	ext{out}}}^{X_{	ext{out}}}) 
ight},

其中 (x,y) 表示像素的坐標。

那麼,由參數 A 定義的變換 	au 會應用到網格 G 上面,生成變換後的網格 SS 表示輸入 X_{	ext{in}} 中被選擇像素的位置,被選中的像素會放到 G 對應的位置上作為輸出。我們有

S_i = 	au_{A}(G_i),

S = {S_i} = left{ (x_{1,1}^{S}, y_{1,1}^{S}), (x_{1,2}^{S}, y_{1,2}^{S}), cdots, (x_{U_{	ext{out}},V_{	ext{out}}}^{S}, y_{U_{	ext{out}},V_{	ext{out}}}^{S}) 
ight},

最後,採樣器根據$S$的位置在輸入特徵圖 X_{	ext{in}} 上採樣,比如雙線性插值採樣,生成 X_{	ext{out}}


上面說的比較籠統,下面以仿射變換為具體例子進行說明,如圖所示。

空間變換網路的變換過程。

圖中,右邊紅色的點表示網格 G ,變換 	au (圖中綠色虛線)應用到 G 上生成 S ,即左邊紅(藍)色的點。 S 表示採樣時的位置。

	au 可以是任意變換,比如仿射變換、平面射影變換、薄板樣條插值等。仿射變換包含平移、旋轉、縮放、偏斜、裁剪等操作,能夠滿足大部分圖像相關的任務,所以這裡我們令 	au 為二維仿射變換,所以 A 可以表示為如下矩陣

A = egin{bmatrix} a_{1,1} & a_{1,2} & a_{1,3} \ a_{2,1} & a_{2,2} & a_{2,3} end{bmatrix}.

仿射變換可以寫成

S_i = {x_i^S choose x_i^S} = 	au_{A}(G_i) = Aegin{pmatrix} x_i^{X_{	ext{out}}} \ y_i^{X_{	ext{out}}} \ 1end{pmatrix} = egin{bmatrix} a_{1,1} & a_{1,2} & a_{1,3} \ a_{2,1} & a_{2,2} & a_{2,3} end{bmatrix} egin{pmatrix} x_i^{X_{	ext{out}}} \ y_i^{X_{	ext{out}}} end{pmatrix} .

由於 S 是經過變換計算出來的,不一定能精確對應到 X_{	ext{in}} 中的像素,所以需要使用採樣器:

X_{	ext{out},i} = sum_u^{U_	ext{in}}sum_v^{V_	ext{in}}X_{	ext{in},u,v}k(x_i^S-v)k(y_i^S-u), forall iin[1,2,cdots,U_	ext{out}V_	ext{out}],

其中 k 以是任意對 x_i^Sy_i^S 可導的採樣器,最常用的就是雙線性插值:

X_{	ext{out},i} = sum_u^{U_	ext{in}}sum_v^{V_	ext{in}}X_{	ext{in},u,v}max(0,1-|x_i^S-v|)max(0,1-|y_i^S-u|), forall iin[1,2,cdots,U_	ext{out}V_	ext{out}],

這裡 X_{	ext{in}}X_{	ext{out}} 的坐標是歸一化的,也就是 (x_{1,1}^{X_{	ext{in}}}, y_{1,1}^{X_{	ext{in}}})=(-1,-1)(x_{U_	ext{in},V_	ext{in}}^{X_{	ext{in}}}, y_{U_	ext{in},V_	ext{in}}^{X_{	ext{in}}})=(+1,+1)

這樣,對於一個輸入特徵圖 X_{	ext{in}} ,注意力機制生成了一個能夠關注 X_{	ext{in}} 中感興趣區域的輸出 X_{	ext{out}}X_{	ext{out}} 送給後續模型進行處理。


  • 訓練

如果STN的三部分定位網路 phi_{W_{	ext{loc}}} 、網格生成器 	au 、採樣器 k 對於它們各自的輸入都是可微的,那麼STN也是一個可微的模型。舉個例子, phi_{W_{	ext{loc}}} 是一個可微的神經網路, 	au 是仿射變換, k 是雙線性插值,我們就可以得到

frac{partial X_{	ext{out},i}}{partial X_{	ext{in},u,v}} = sum_u^{U_	ext{in}}sum_v^{V_	ext{in}}max(0,1-|x_i^S-v|)max(0,1-|y_i^S-u|),

frac{partial X_{	ext{out},i}}{partial x_i^S} = sum_u^{U_	ext{in}}sum_v^{V_	ext{in}}X_{	ext{in},u,v}max(0,1-|y_i^S-u|) ,

frac{partial X_{	ext{out},i}}{partial y_i^S} = sum_u^{U_	ext{in}}sum_v^{V_	ext{in}}X_{	ext{in},u,v}max(0,1-|x_i^S-v|) ,

frac{partial x_i^S}{partial a_{1,1}} = x_i^{X_{	ext{out}}} ,

frac{partial x_i^S}{partial a_{1,2}} = y_i^{X_{	ext{out}}} ,

frac{partial x_i^S}{partial a_{1,3}} = 1 ,

frac{partial y_i^S}{partial a_{2,1}} = x_i^{X_{	ext{out}}} ,

frac{partial y_i^S}{partial a_{2,2}} = y_i^{X_{	ext{out}}} ,

frac{partial y_i^S}{partial a_{2,3}} = 1 ,

而定位網路參數的梯度 frac{partial A}{partial W_{	ext{loc}}} 就可以通過標準的反向傳播來求。這樣,整個STN網路就是可微的,可以通過梯度下降法來訓練。


原文在MNIST數據集上做的模擬實驗結果如圖所示,(a)列為輸入 X_{	ext{in}} ,(b)列展示了在輸入中的採樣區域 S ,(c)列是輸出 X_{	ext{out}} ,(d)是最終分類結果,過程是比較可解釋、有說服力的。

空間變換網路在MNIST數據集的模擬實驗。


注意力機制全集:

  • 計算機視覺中的注意力機制
  • 圖像描述:基於項的注意力機制
  • 數字串識別:基於位置的硬性注意力機制
  • 圖像識別:基於位置的柔性注意力機制

參考文獻:

Wang F, Tax D M J. Survey on the attention based RNN model and its applications in computer vision[J]. arXiv preprint arXiv:1601.06823, 2016.

Jaderberg M, Simonyan K, Zisserman A. Spatial transformer networks[C]//Advances in Neural Information Processing Systems. 2015: 2017-2025.

推薦閱讀:

Kaggle HousePrice : LB 0.11666(前15%), 用搭積木的方式(3.實踐-訓練、調參和Stacking)
softmax函數計算時候為什麼要減去一個最大值?

TAG:深度学习DeepLearning | 计算机视觉 | 机器学习 |