Recurrent Models of Visual Attention 源碼解析
RAM github地址,網路架構與另一篇論文可以說一模一樣,請移步Deep Reinforcement Learning for Visual Object Tracking in Videos 論文解讀。
該論文主旨在找尋mnist字母的中心點,以分類的成功率作為reward function。換句話說,找到的字母中心點越近,分類成功率越高,獎勵越高。
首先使用tensorflow 的 rnn_encoder得到每一個輸出
outputs, _ = seq2seq.rnn_decoder( inputs, init_state, lstm_cell, loop_function=get_next_input)
再將每個輸出都過一個全鏈接層,這一步的目的是訓練reward function,
for t, output in enumerate(outputs[1:]): baseline_t = tf.nn.xw_plus_b(output, w_baseline, b_baseline) baseline_t = tf.squeeze(baseline_t) baselines.append(baseline_t)baselines = tf.stack(baselines) # [timesteps, batch_sz]baselines = tf.transpose(baselines) # [batch_sz, timesteps]
獎勵的目標為分類的成功率。
advs = rewards - tf.stop_gradient(baselines)logllratio = tf.reduce_mean(logll * advs)reward = tf.reduce_mean(reward)baselines_mse = tf.reduce_mean(tf.square((rewards - baselines)))
這一步小弟不太理解,老師的解釋是:期望每一個輸出都趨於高斯分布,logll就是與期望的高斯分布的差值。
logll = loglikelihood(loc_mean_arr, sampled_loc_arr, config.loc_std)
最後,由loss, reward,以及上一個logll進行梯度更新。
loss = -logllratio + xent + baselines_mse # `-` for minimizegrads = tf.gradients(loss, var_list)
以上就是源碼的大概解析,僅提供一個參考。
推薦閱讀:
※機器學習中的優化方法
※機器學習萌新必學的Top10演算法
※機器學習基本套路
※機器學習入門基礎——邏輯回歸
※ML6-Keras1 "hello world of deep learning "(李宏毅筆記)
TAG:計算機視覺 | 深度學習DeepLearning | 機器學習 |