DARNN:一種新的時間序列預測方法——基於雙階段注意力機制的循環神經網路
來自專欄 Machine Failure32 人贊了文章
RNN及seq2seq model的基礎知識可以參見詳解RNN及其變種:GRU,LSTM、Seq2seq模型及注意力機制
論文參見A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction
論文題目為《基於雙階段注意力機制的循環神經網路》。本文介紹了一種基於seq2seq模型(encoder decoder 模型),並結合注意力機制的一種時間序列預測方法。與傳統的注意力機制只用在解碼器的輸入階段,即對不同時刻產生不同的context vector不同,該文還在編碼器的輸入階段引入了注意力機制,從而同時實現了選取特徵因子(feature selection)和把握長期時序依賴關係(long-term temporal dependencies)。
第一階段,使用注意力機制自適應地提取每個時刻的相關feature;第二階段,使用另一個注意力機制選取與之相關的encoder hidden states。
使用input attention的編碼器
我們定義為encoder在時刻的hidden state, 其中是hidden state的size。
第一階段,使用當前時刻的輸入,以及上一個時刻編碼器的hidden state,來計算當前時刻編碼器的hidden state,其中m是編碼器的size。更新公式可寫為:
對於這個問題,我們可以使用通常的循環神經網路vanilla RNN或LSTM以及GRU作為。但為了自適應地選取相關feature,作者在此處引入了注意力機制。簡單來說,即對每個時刻的輸入,為其中的每個影響因子賦予一定的注意力權重(attention weight)。衡量了時刻的第個feature的重要性。更新後的為
如何計算?
可以根據上一個時刻編碼器的hidden state和cell state計算得到:
其中是hidden state與cell state的連接(concatenation)。我的理解是與類似,只不過少了一個需要訓練的參數。 該式即把第個driving series與前一個時刻的hidden state和cell state線性組合,再用tanh激活得到。
得到後,再用softmax函數將其歸一化:
使用更新後的作為編碼器的輸入
得到了更新後的,作者又選取了LSTM作為編碼器
通過上述的input attention機制,編碼器能夠focus on其中重要的驅動因子,而不是對所有因子一視同仁。
使用temporal attention的解碼器
為了區別起見,與論文中公式略有不同的是,我將解碼器中的時間序列下標標註為,以與編碼器種的下標區分。
第二階段的解碼器注意力機制設計類似於傳統的attention based seq2seq model。基本的出發點為,傳統的seq2seq模型中,編碼器輸出的context vector基於最後時刻的hidden state或對所有hidden state取平均。這樣輸出的context vector對所有時刻均相同,無法起到只選取相關時刻編碼器hidden state的功能。我們自然地想到可以在不同時刻採用不同的context vector。類似於seq2seq,最簡單的辦法是對所有時刻的取加權平均,即:
的設計類似於Bahanau的工作,基於前一個時刻解碼器的hidden state和cell state計算得到:
解碼器的輸入是上一個時刻的目標序列和hidden state以及context vector,即
作者在這裡設計了來combie與的信息,即
然後
類似於編碼器的最後一個公式,這裡仍舊使用LSTM作為。
Final prediction
回顧一下非線性自回歸(Nonlinear autoregressive exogenous, NARX)模型的最終目標,我們需要建立當前輸入與所有時刻的輸入以及之前時刻的輸出之間的關係,即:
通過之前編碼器解碼器模型的學習,我們已經得到了解碼器的hidden state 和context vector,與。我們再使用一個全連接層對做回歸,即
這樣可以得到最終的預測
Following
作為一個非常新的時間序列模型,作者在Nasdaq100數據集上實現了很好的預測效果
我在網上找到了該篇論文的PyTorch實現,準備試著跑一下http://chandlerzuo.github.io/blog/2017/11/darnn
推薦閱讀:
TAG:神經網路 | 深度學習DeepLearning | 時間序列分析 |