數字串識別:基於位置的硬性注意力機制

圖像識別是計算機視覺最經典和基礎的應用,現在最常用的模型是CNN。但是CNN的輸入尺寸固定,有一定的局限性:

  • 當輸入圖片比CNN的輸入尺寸大的時候,需要把圖片縮放至規定尺寸,或者切成一些子區域分別處理;但是這樣要麼使得圖片失去一些細節,要麼計算量會大增。
  • 雖然CNN具有一定的空間變換不變性,但是當圖片中只有一小部分區域和標籤相關時,CNN很難保證滿意的效果。

而基於位置的硬性注意力機制選取圖像中一個合適的子區域,交給後續模型進行處理,能夠一定程度克服上述問題。


我們以自然場景數字串識別為例來講基於位置的硬性注意力機制。數字串識別任務是輸入一張圖片,生成對該圖像包含的數字串,可以表示為序列

y={bm{y}_1,cdots,bm{y}_C}, bm{y}_iinmathbb{R}^{10},

其中 C 是數字串的長度。

數字串識別任務,類似於圖像描述,是一個序列預測過程。在每個時刻$t$,注意力機制要從原圖中選出中心位置為 s_t ,高為 h ,寬為 w 的一個子區域,送給編碼器計算中間特徵,然後用RNN解碼器進行解碼輸出預測,如下圖所示。

使用基於位置的硬性注意力機制的數字串識別方法。

注意力機制以原圖 bm{x} 和解碼器上個時刻的狀態 bm{h}_{t-1} 為輸入,利用如下的高斯分布輸出中心位置 s_t

s_t sim mathcal{N}left( f_{text{att}}(bm{x},bm{h}_{t-1}), d right),

其中標準差 d 是一個超參數, f_{text{att}} 通常是一個神經網路。子區域的高 h 和寬 w 也可以通過類似的方式產生,或者根據經驗設置成超參數。

那麼,注意力機制選出的子區域可以表示為 x_{s_t} ,經過編碼器計算出中間特徵:

bm{c}_t = phi_{W_{text{enc}}}(bm{x}_{s_t}),

然後通過解碼器進行解碼

{hat{bm{y}}_t choose bm{h}_t} = phi_{W_{text{dec}}}(bm{c}_t,hat{bm y}_{t-1}, bm{h}_{t-1}),

其中 phi_{W_{text{enc}}}phi_{W_{text{dec}}} 分別是編碼器和解碼器的函數表示。

訓練時使用對數似然函數作為目標函數,訓練過程和基於項的硬性注意力類似,參考圖像描述:基於項的注意力機制。

原文給出了在街景房屋數字串是被數據集SVNH實驗中的注意力機制選擇過程視頻:psi.toronto.edu/~jimmy/,(我上傳到這裡方便觀看)從中可以看出注意力模型能夠從左到右選擇數字所在區域。

https://www.zhihu.com/video/938543299280261120


注意力機制全集:

  • 計算機視覺中的注意力機制
  • 圖像描述:基於項的注意力機制
  • 數字串識別:基於位置的硬性注意力機制
  • 圖像識別:基於位置的柔性注意力機制

參考文獻:

Wang F, Tax D M J. Survey on the attention based RNN model and its applications in computer vision[J]. arXiv preprint arXiv:1601.06823, 2016.

Ba J, Mnih V, Kavukcuoglu K. Multiple object recognition with visual attention[C]. ICLR, 2015.

推薦閱讀:

線性方程組之一:列向量觀點
lightGBM
搞機器學習/AI有什麼必備的數學基礎?|經驗之談+資源大全

TAG:深度学习DeepLearning | 计算机视觉 | 机器学习 |