圖像識別:基於位置的柔性注意力機制
第一個基於位置的柔性注意力機制是NIPS 2015上DeepMind團隊提出的空間變換網路(Spatial Transformer Networks,STN),用在CNN模型當中幫助圖像識別。對於輸入的特徵圖,對識別圖像最有幫助的其中一部分可能不是規則(如矩形)的形狀,所以需要做不規則採樣。基於此,STN以整個特徵圖為輸入,特徵圖經過變換後生成新的特徵圖。典型的STN包含三部分:定位網路、網格生成器、採樣器,如圖所示。
令輸入 的尺寸為 ,輸出 的尺寸為 ,其中 、 、 分別表示特徵圖的高度、寬度和通道數。比如 是彩色圖片時,它的尺寸就是 。STN同步在每個通道上起作用,所以通道數不會改變, ,方便起見,後面的敘述會省略掉通道這個維度。
定位網路 以 為輸入,生成特徵圖變換參數
令網格 表示輸出特徵圖 的所有像素的位置,也就是
其中 表示像素的坐標。
那麼,由參數 定義的變換 會應用到網格 上面,生成變換後的網格 , 表示輸入 中被選擇像素的位置,被選中的像素會放到 對應的位置上作為輸出。我們有
最後,採樣器根據$S$的位置在輸入特徵圖 上採樣,比如雙線性插值採樣,生成 。
上面說的比較籠統,下面以仿射變換為具體例子進行說明,如圖所示。
圖中,右邊紅色的點表示網格 ,變換 (圖中綠色虛線)應用到 上生成 ,即左邊紅(藍)色的點。 表示採樣時的位置。
可以是任意變換,比如仿射變換、平面射影變換、薄板樣條插值等。仿射變換包含平移、旋轉、縮放、偏斜、裁剪等操作,能夠滿足大部分圖像相關的任務,所以這裡我們令 為二維仿射變換,所以 可以表示為如下矩陣
仿射變換可以寫成
由於 是經過變換計算出來的,不一定能精確對應到 中的像素,所以需要使用採樣器:
其中 以是任意對 和 可導的採樣器,最常用的就是雙線性插值:
這裡 和 的坐標是歸一化的,也就是 , 。
這樣,對於一個輸入特徵圖 ,注意力機制生成了一個能夠關注 中感興趣區域的輸出 。 送給後續模型進行處理。
- 訓練
如果STN的三部分定位網路 、網格生成器 、採樣器 對於它們各自的輸入都是可微的,那麼STN也是一個可微的模型。舉個例子, 是一個可微的神經網路, 是仿射變換, 是雙線性插值,我們就可以得到
而定位網路參數的梯度 就可以通過標準的反向傳播來求。這樣,整個STN網路就是可微的,可以通過梯度下降法來訓練。
原文在MNIST數據集上做的模擬實驗結果如圖所示,(a)列為輸入 ,(b)列展示了在輸入中的採樣區域 ,(c)列是輸出 ,(d)是最終分類結果,過程是比較可解釋、有說服力的。
注意力機制全集:
- 計算機視覺中的注意力機制
- 圖像描述:基於項的注意力機制
- 數字串識別:基於位置的硬性注意力機制
- 圖像識別:基於位置的柔性注意力機制
參考文獻:
Wang F, Tax D M J. Survey on the attention based RNN model and its applications in computer vision[J]. arXiv preprint arXiv:1601.06823, 2016.
Jaderberg M, Simonyan K, Zisserman A. Spatial transformer networks[C]//Advances in Neural Information Processing Systems. 2015: 2017-2025.
推薦閱讀:
※Kaggle HousePrice : LB 0.11666(前15%), 用搭積木的方式(3.實踐-訓練、調參和Stacking)
※softmax函數計算時候為什麼要減去一個最大值?
TAG:深度学习DeepLearning | 计算机视觉 | 机器学习 |