pytorch,mxnet的模型有200M,但為什麼做推理時顯存會佔用800M左右?
02-27
pytorch,mxnet的模型有200M,但為什麼做推理時顯存會佔用800M左右?
網路在訓練時,除了本身定義的可訓練參數要佔用顯存,還有中間變數(包含梯度)也要佔用顯存。
因為有輸入和中間變數啊。
前向傳播的時候會為反向傳播緩存求導相關中間變數。
如果你只是想做前向傳播,你可以這麼做。
model.eval()
with torch.no_grad():
result = model(input)
你說的那個是參數量,還有輸入輸出呢
考慮方程y=ax+b,a和b是模型,計算時x和y也要佔用顯存。
模型推理過程中顯存的佔用主要是兩方面,一個是模型參數,另一個是模型推理中的中間變數。比如說dense block,下圖中的跨層鏈接,在推理的過程中都會作為中間變數而佔用顯存
&-->
推薦閱讀:
※又輸了!中國五位大神吊打,Open AI揮別Dota2賽場!
※平安智慧城董事長俞太尉出席「全球智能經濟峰會暨第八屆智博會」
※呂律:人工智慧與法制——KI in der Juristik
※人工智慧與愛無能
※2018蘋果發布會的重點是:AI人工智慧
TAG:人工智慧 | 神經網路 | 深度學習DeepLearning | MXNet | PyTorch |