pytorch,mxnet的模型有200M,但為什麼做推理時顯存會佔用800M左右?

pytorch,mxnet的模型有200M,但為什麼做推理時顯存會佔用800M左右?


網路在訓練時,除了本身定義的可訓練參數要佔用顯存,還有中間變數(包含梯度)也要佔用顯存。


因為有輸入和中間變數啊。


前向傳播的時候會為反向傳播緩存求導相關中間變數。

如果你只是想做前向傳播,你可以這麼做。

model.eval()

with torch.no_grad():

result = model(input)


你說的那個是參數量,還有輸入輸出呢


考慮方程y=ax+b,a和b是模型,計算時x和y也要佔用顯存。


模型推理過程中顯存的佔用主要是兩方面,一個是模型參數,另一個是模型推理中的中間變數。比如說dense block,下圖中的跨層鏈接,在推理的過程中都會作為中間變數而佔用顯存

&-->


推薦閱讀:

又輸了!中國五位大神吊打,Open AI揮別Dota2賽場!
平安智慧城董事長俞太尉出席「全球智能經濟峰會暨第八屆智博會」
呂律:人工智慧與法制——KI in der Juristik
人工智慧與愛無能
2018蘋果發布會的重點是:AI人工智慧

TAG:人工智慧 | 神經網路 | 深度學習DeepLearning | MXNet | PyTorch |