tensorflow sess.run()越來越慢的原因分析及其解決方法
4 人贊了文章
最近在訓練一個檢測器,由於訓練數據不足因此需要做數據增強,那麼我這邊寫了代碼去做數據增強(這部分將會在下一篇進行介紹),其中使用到了tensorflow會話獲取數據,可是問題出現了!gtx 1080ti顯卡11G的內存,總共才處理六萬張圖像,但是運行速度越來越慢,眼看吃完晚飯散步回來幾個小時了還沒見有處理到一半,這就讓我不得不探索個究竟並期待解決這類問題了!
此篇文章,我們只談思路不談代碼,但會給一個基本的代碼框架以便更好說明原因!
首先,使用如下命令2S間隔監測一次gtx 1080ti的使用情況!
watch -n 2 nvidia-smi
得到的信息情況如下:
內存已經差不多吃完了,沒辦法,tensorflow運行起來就是這麼霸道,GPU顯存有多少就基本吃多少,當然我們可以調整其使用內存量,但在這裡不做討論。再仔細看看,是不是發現什麼了,沒錯啦,中部右側的0%顯示GPU使用率為0,what the fuck,什麼情況啊,難道不是用GPU來計算處理的嗎?可是代碼跑起來確實有列印如下使用GPU信息的,而且驅動也沒裝錯啊,之前都有正確運行了的!
那什麼原因呢?我也不知道,但是總得想個辦法來看看到底卡在哪裡了,到底是讀進圖像時卡住了,還是tensorflow某個操作耗時太久了,還是?然後,我想到了個方法,使用datetime.datetime.now()獲取時間,並列印某一個操作前後的時間差,哈哈,總算得到一些有效信息了!
看看一開始的列印信息:
起初列印出來的總耗時還是蠻低的,大概也就0.01~0.04s的範圍,且分析到在23節點列印出來的耗時是最大的,而23節點是列印sess.run()處理前後的間隔耗時!
那麼運行十來分鐘後,此時的時間列印信息如下:
這時候列印出來的總耗時已經增大到0.7s左右了,足足增大了15~70倍的耗時時間,而這才運行了十來分鐘而已!那麼我們看到23節點列印出來的耗時此時也是最大的,而且整個耗時的增大也基本來自這個23節點即sess.run()產生的耗時!
問題直指sess.run()隨著時間的拉長其運行速度越來越慢!
那麼我嘗試到google上搜索sess.run()運行越來越慢的原因,有找到如下類似問題:
why tensorflow run slow in loop · Issue #1439 · tensorflow/tensorflow
這裡提到如下在某一個循環里,不斷建立tensorflow圖節點再運行的話,會導致tensorflow運行越來越慢,有問題的代碼結構大概長這樣子:
for step in range(total_step): tfops = tf add Ops ... sess.run(tf.ops)
看到github上的分析,突然豁然開朗,tensorflow都是符號型結構的,它是在運行之前先建立好一張圖並確認好張量的流向,再在迭代中不斷喂數據進行訓練的,如果我們在循環里不斷的添加節點就導致tensorflow耗時在維護圖結構上了。
github提供的解決思路是在sess.run()之前建立好圖再運行,可是,不巧,我要做的事情,就是要運行動態的tensorflow圖,即每一次運行的圖結構都可能不一樣,並非是固定圖結構,且圖像size不一,也沒辦法進行placeholder放置管道,這可沒有任何答案告訴我怎麼辦啊!
怎麼辦呢?
我一開始調了代碼結構,比如將多個需要session運行的操作放在同一個session里運行,可是實驗反饋無效;另外一個,我主動在運行完圖後銷毀會話即使用sess.close()可是運行起來還是越來越慢;再然後,我試驗在每一次調用sess.run()之前,調用tf.graph()並tf.graph.as_default()或tf.graph.as_graph_def()以為可以每一次都重新建立一張圖來運行並讓tensorflow自己銷毀掉之前的圖,可是,好事多磨啊,最終還是不行啊!
不過,我思路也是清晰的,我就是想要動態建立一張新圖並銷毀掉舊的圖,那麼我一直研究tensorflow Ops.py的源碼,裡面有定義了tensorflow圖介面,最後查到裡面有reset和finalize的圖介面,想必就是我要的介面了,馬上按照如下代碼框架進行試驗!
for step in range(total_step): tf.reset_default_graph() with tf.Session() as sess: tfops = tf add Ops ... sess.run(tfops) tf.get_default_graph().finalize()
哇塞,驗證結果顯示,代碼跑得飛快,嘩啦啦的一直飛速運行就沒有卡過,真是「有心人,天不負」!此時再看gtx 1080ti顯卡的使用率已經基本維持在3%左右了,每一次運行耗時都基本在0.03s左右,六萬張圖像最終處理時間也不過三十分鐘左右,速度提升還是很大的!
推薦閱讀:
※Autoencoder及tensorflow實現
※tensorflow訓練自己製作的tfrecords文件時出現ValueError?
※TensorFlow學習筆記之四——源碼分析之基本操作
※【深度學習不是犯罪】歐盟祭出最嚴數據保護法:專家解讀 GDPR
※貓狗大戰(1)處理數據
TAG:TensorFlow | 深度學習DeepLearning | 機器學習 |