Recurrent Models of Visual Attention 源碼解析

RAM github地址,網路架構與另一篇論文可以說一模一樣,請移步Deep Reinforcement Learning for Visual Object Tracking in Videos 論文解讀。

該論文主旨在找尋mnist字母的中心點,以分類的成功率作為reward function。換句話說,找到的字母中心點越近,分類成功率越高,獎勵越高。

首先使用tensorflow 的 rnn_encoder得到每一個輸出

outputs, _ = seq2seq.rnn_decoder( inputs, init_state, lstm_cell, loop_function=get_next_input)

再將每個輸出都過一個全鏈接層,這一步的目的是訓練reward function,

for t, output in enumerate(outputs[1:]): baseline_t = tf.nn.xw_plus_b(output, w_baseline, b_baseline) baseline_t = tf.squeeze(baseline_t) baselines.append(baseline_t)baselines = tf.stack(baselines) # [timesteps, batch_sz]baselines = tf.transpose(baselines) # [batch_sz, timesteps]

獎勵的目標為分類的成功率。

advs = rewards - tf.stop_gradient(baselines)logllratio = tf.reduce_mean(logll * advs)reward = tf.reduce_mean(reward)baselines_mse = tf.reduce_mean(tf.square((rewards - baselines)))

這一步小弟不太理解,老師的解釋是:期望每一個輸出都趨於高斯分布,logll就是與期望的高斯分布的差值。

logll = loglikelihood(loc_mean_arr, sampled_loc_arr, config.loc_std)

最後,由loss, reward,以及上一個logll進行梯度更新。

loss = -logllratio + xent + baselines_mse # `-` for minimizegrads = tf.gradients(loss, var_list)

以上就是源碼的大概解析,僅提供一個參考。


推薦閱讀:

機器學習中的優化方法
機器學習萌新必學的Top10演算法
機器學習基本套路
機器學習入門基礎——邏輯回歸
ML6-Keras1 "hello world of deep learning "(李宏毅筆記)

TAG:計算機視覺 | 深度學習DeepLearning | 機器學習 |