PyTorch實戰指南

這不是一篇PyTorch的入門教程!

本文較長,你可能需要花費20分鐘才能看懂大部分內容

建議在電腦,結合代碼閱讀本文

本指南的配套代碼地址: chenyuntc/pytorch-best-practice

在學習某個深度學習框架時,掌握其基本知識和介面固然重要,但如何合理組織代碼,使得代碼具有良好的可讀性和可擴展性也必不可少。本文不會深入講解過多知識性的東西,更多的則是傳授一些經驗,關於如何使得自己的程序更pythonic,更符合pytorch的設計理念。這些內容可能有些爭議,因其受我個人喜好和coding風格影響較大,你可以將這部分當成是一種參考或提議,而不是作為必須遵循的準則。歸根到底,都是希望你能以一種更為合理的方式組織自己的程序。

在做深度學習實驗或項目時,為了得到最優的模型結果,中間往往需要很多次的嘗試和修改。根據我的個人經驗,在從事大多數深度學習研究時,程序都需要實現以下幾個功能:

  • 模型定義
  • 數據處理和載入
  • 訓練模型(Train&Validate)
  • 訓練過程的可視化
  • 測試(Test/Inference)

另外程序還應該滿足以下幾個要求:

  • 模型需具有高度可配置性,便於修改參數、修改模型,反覆實驗
  • 代碼應具有良好的組織結構,使人一目了然
  • 代碼應具有良好的說明,使其他人能夠理解

在本文我將應用這些內容,並結合實際的例子,來講解如何用PyTorch完成Kaggle上的經典比賽:Dogs vs. Cats。本文所有示常式序均在github上開源 github.com/chenyuntc/py

目錄

1 比賽介紹

2 文件組織架構

3 關於__init__.py

4 數據載入

5 模型定義

6 工具函數

7 配置文件

8 main.py

8.1 訓練

8.2 驗證

8.3 測試

8.4 幫助函數

9 使用

10 爭議

1 比賽介紹

Dogs vs. Cats是一個傳統的二分類問題,其訓練集包含25000張圖片,均放置在同一文件夾下,命名格式為<category>.<num>.jpg, 如cat.10000.jpgdog.100.jpg,測試集包含12500張圖片,命名為<num>.jpg,如1000.jpg。參賽者需根據訓練集的圖片訓練模型,並在測試集上進行預測,輸出它是狗的概率。最後提交的csv文件如下,第一列是圖片的<num>,第二列是圖片為狗的概率。

id,labeln10001,0.889n10002,0.01n...n

2 文件組織架構

首先來看程序文件的組織結構:

├── checkpoints/n├── data/n│ ├── __init__.pyn│ ├── dataset.pyn│ └── get_data.shn├── models/n│ ├── __init__.pyn│ ├── AlexNet.pyn│ ├── BasicModule.pyn│ └── ResNet34.pyn└── utils/n│ ├── __init__.pyn│ └── visualize.pyn├── config.pyn├── main.pyn├── requirements.txtn├── README.mdn

其中:

  • checkpoints/: 用於保存訓練好的模型,可使程序在異常退出後仍能重新載入模型,恢復訓練
  • data/:數據相關操作,包括數據預處理、dataset實現等
  • models/:模型定義,可以有多個模型,例如上面的AlexNet和ResNet34,一個模型對應一個文件
  • utils/:可能用到的工具函數,在本次實驗中主要是封裝了可視化工具
  • config.py:配置文件,所有可配置的變數都集中在此,並提供默認值
  • main.py:主文件,訓練和測試程序的入口,可通過不同的命令來指定不同的操作和參數
  • requirements.txt:程序依賴的第三方庫
  • README.md:提供程序的必要說明

3 關於__init__.py

可以看到,幾乎每個文件夾下都有__init__.py,一個目錄如果包含了__init__.py 文件,那麼它就變成了一個包(package)。__init__.py可以為空,也可以定義包的屬性和方法,但其必須存在,其它程序才能從這個目錄中導入相應的模塊或函數。例如在data/文件夾下有__init__.py,則在main.py 中就可以

from data.dataset import DogCatn

而如果在data/__init__.py中寫入

from .dataset import DogCatn

則在main.py中就可以直接寫為:

from data import DogCatn

或者

import data;ndataset = data.DogCatn

相比於from data.dataset import DogCat更加便捷。

4 數據載入

數據的相關處理主要保存在data/dataset.py中。關於數據載入的相關操作,其基本原理就是使用Dataset進行數據集的封裝,再使用Dataloader實現數據並行載入。

Kaggle提供的數據包括訓練集和測試集,而我們在實際使用中,還需專門從訓練集中取出一部分作為驗證集。對於這三類數據集,其相應操作也不太一樣,而如果專門寫三個Dataset,則稍顯複雜和冗餘,因此這裡通過加一些判斷來區分。對於訓練集,我們希望做一些數據增強處理,如隨機裁剪、隨機翻轉、加雜訊等,而驗證集和測試集則不需要。下面看dataset.py的代碼:

import osnfrom PIL import Imagenfrom torch.utils import datanimport numpy as npnfrom torchvision import transforms as Tnnclass DogCat(data.Dataset):n n def __init__(self, root, transforms=None, train=True, test=False):n n 目標:獲取所有圖片路徑,並根據訓練、驗證、測試劃分數據n n self.test = testn imgs = [os.path.join(root, img) for img in os.listdir(root)] n # 訓練集和驗證集的文件命名不一樣n # test1: data/test1/8973.jpgn # train: data/train/cat.10004.jpg n if self.test:n imgs = sorted(imgs, key=lambda x: int(x.split(.)[-2].split(/)[-1]))n else:n imgs = sorted(imgs, key=lambda x: int(x.split(.)[-2]))n n imgs_num = len(imgs)n n # shuffle imgsn np.random.seed(100)n imgs = np.random.permutation(imgs)nn # 劃分訓練、驗證集,驗證:訓練 = 3:7n if self.test:n self.imgs = imgsn elif train:n self.imgs = imgs[:int(0.7*imgs_num)]n else :n self.imgs = imgs[int(0.7*imgs_num):] n n if transforms is None:n n # 數據轉換操作,測試驗證和訓練的數據轉換有所區別 n normalize = T.Normalize(mean = [0.485, 0.456, 0.406], n std = [0.229, 0.224, 0.225])nn # 測試集和驗證集不用數據增強n if self.test or not train: n self.transforms = T.Compose([n T.Scale(224),n T.CenterCrop(224),n T.ToTensor(),n normalizen ]) n # 訓練集需要數據增強n else :n self.transforms = T.Compose([n T.Scale(256),n T.RandomSizedCrop(224),n T.RandomHorizontalFlip(),n T.ToTensor(),n normalizen ]) n n n def __getitem__(self, index):n n 返回一張圖片的數據n 對於測試集,沒有label,返回圖片id,如1000.jpg返回1000n n img_path = self.imgs[index]n if self.test: n label = int(self.imgs[index].split(.)[-2].split(/)[-1])n else: n label = 1 if dog in img_path.split(/)[-1] else 0n data = Image.open(img_path)n data = self.transforms(data)n return data, labeln n def __len__(self):n n 返回數據集中所有圖片的個數n n return len(self.imgs)n

關於數據集使用的注意事項,在上一章中已經提到,將文件讀取等費時操作放在__getitem__函數中,利用多進程加速。避免一次性將所有圖片都讀進內存,不僅費時也會佔用較大內存,而且不易進行數據增強等操作。另外在這裡,我們將訓練集中的30%作為驗證集,可用來檢查模型的訓練效果,避免過擬合。

在使用時,我們可通過dataloader載入數據。

train_dataset = DogCat(opt.train_data_root, train=True)ntrainloader = DataLoader(train_dataset,n batch_size = opt.batch_size,n shuffle = True,n num_workers = opt.num_workers)n nfor ii, (data, label) in enumerate(trainloader):nttrain()n

5 模型定義

模型的定義主要保存在models/目錄下,其中BasicModule是對nn.Module的簡易封裝,提供快速載入和保存模型的介面。

class BasicModule(t.nn.Module):n n 封裝了nn.Module,主要提供save和load兩個方法n nn def __init__(self,opt=None):n super(BasicModule,self).__init__()n self.model_name = str(type(self)) # 模型的默認名字nn def load(self, path):n n 可載入指定路徑的模型n n self.load_state_dict(t.load(path))nn def save(self, name=None):n n 保存模型,默認使用「模型名字+時間」作為文件名,n 如AlexNet_0710_23:57:29.pthn n if name is None:n prefix = checkpoints/ + self.model_name + _n name = time.strftime(prefix + %m%d_%H:%M:%S.pth)n t.save(self.state_dict(), name)n return namen

在實際使用中,直接調用model.save()及model.load(opt.load_path)即可。

其它自定義模型一般繼承BasicModule,然後實現自己的模型。其中AlexNet.py實現了AlexNet,ResNet34實現了ResNet34。在models/__init__py中,代碼如下:

from .AlexNet import AlexNetnfrom .ResNet34 import ResNet34n

這樣在主函數中就可以寫成:

from models import AlexNetn

import modelsnmodel = models.AlexNet()n

import modelsnmodel = getattr(models, AlexNet)()n

其中最後一種寫法最為關鍵,這意味著我們可以通過字元串直接指定使用的模型,而不必使用判斷語句,也不必在每次新增加模型後都修改代碼。新增模型後只需要在models/__init__.py中加上

from .new_module import NewModulen

即可。

其它關於模型定義的注意事項,在上一章中已詳細講解,這裡就不再贅述,總結起來就是:

  • 盡量使用nn.Sequential(比如AlexNet)
  • 將經常使用的結構封裝成子Module(比如GoogLeNet的Inception結構,ResNet的Residual Block結構)
  • 將重複且有規律性的結構,用函數生成(比如VGG的多種變體,ResNet多種變體都是由多個重複卷積層組成)

感興趣的 讀者可以看看在`models/resnet34.py`如何用不到80行的代碼(包括空行和注釋)實現resnet34。當然這些模型在torchvision中有實現,而且還提供了預訓練的權重,讀者可以很方便的使用:

import torchvision as tvnresnet34 = tv.models.resnet34(pretrained=True)n

6 工具函數

在項目中,我們可能會用到一些helper方法,這些方法可以統一放在utils/文件夾下,需要使用時再引入。在本例中主要是封裝了可視化工具visdom的一些操作,其代碼如下,在本次實驗中只會用到plot方法,用來統計損失信息。

#coding:utf8nimport visdomnimport timenimport numpy as npnnclass Visualizer(object):n n 封裝了visdom的基本操作,但是你仍然可以通過`self.vis.function`n 或者`self.function`調用原生的visdom介面n 比如 n self.text(hello visdom)n self.histogram(t.randn(1000))n self.line(t.arange(0, 10),t.arange(1, 11))n nn def __init__(self, env=default, **kwargs):n self.vis = visdom.Visdom(env=env, **kwargs)n n # 畫的第幾個數,相當於橫坐標n # 比如(』loss,23) 即loss的第23個點n self.index = {} n self.log_text = n def reinit(self, env=default, **kwargs):n n 修改visdom的配置n n self.vis = visdom.Visdom(env=env, **kwargs)n return selfnn def plot_many(self, d):n n 一次plot多個n @params d: dict (name, value) i.e. (loss, 0.11)n n for k, v in d.iteritems():n self.plot(k, v)nn def img_many(self, d):n for k, v in d.iteritems():n self.img(k, v)nn def plot(self, name, y, **kwargs):n n self.plot(loss, 1.00)n n x = self.index.get(name, 0)n self.vis.line(Y=np.array([y]), X=np.array([x]),n win=unicode(name),n opts=dict(title=name),n update=None if x == 0 else append,n **kwargsn )n self.index[name] = x + 1nn def img(self, name, img_, **kwargs):n n self.img(input_img, t.Tensor(64, 64))n self.img(input_imgs, t.Tensor(3, 64, 64))n self.img(input_imgs, t.Tensor(100, 1, 64, 64))n self.img(input_imgs, t.Tensor(100, 3, 64, 64), nrows=10)n n self.vis.images(img_.cpu().numpy(),n win=unicode(name),n opts=dict(title=name),n **kwargsn )nn def log(self, info, win=log_text):n n self.log({loss:1, lr:0.0001})n nn self.log_text += ([{time}] {info} <br>.format(n time=time.strftime(%m%d_%H%M%S),n info=info)) n self.vis.text(self.log_text, win) nn def __getattr__(self, name):n n self.function 等價於self.vis.functionn 自定義的plot,image,log,plot_many等除外n n return getattr(self.vis, name)n

7 配置文件

在模型定義、數據處理和訓練等過程都有很多變數,這些變數應提供默認值,並統一放置在配置文件中,這樣在後期調試、修改代碼或遷移程序時會比較方便,在這裡我們將所有可配置項放在config.py中。

class DefaultConfig(object):n env = default # visdom 環境n model = AlexNet # 使用的模型,名字必須與models/__init__.py中的名字一致n n train_data_root = ./data/train/ # 訓練集存放路徑n test_data_root = ./data/test1 # 測試集存放路徑n load_model_path = checkpoints/model.pth # 載入預訓練的模型的路徑,為None代表不載入nn batch_size = 128 # batch sizen use_gpu = True # use GPU or notn num_workers = 4 # how many workers for loading datan print_freq = 20 # print info every N batchnn debug_file = /tmp/debug # if os.path.exists(debug_file): enter ipdbn result_file = result.csvn n max_epoch = 10n lr = 0.1 # initial learning raten lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decayn weight_decay = 1e-4 # 損失函數n

可配置的參數主要包括:

  • 數據集參數(文件路徑、batch_size等)
  • 訓練參數(學習率、訓練epoch等)
  • 模型參數

這樣我們在程序中就可以這樣使用:

import modelsnfrom config import DefaultConfignnopt = DefaultConfig()nlr = opt.lrnmodel = getattr(models, opt.model)ndataset = DogCat(opt.train_data_root)n

這些都只是默認參數,在這裡還提供了更新函數,根據字典更新配置參數。

def parse(self, kwargs):n n 根據字典kwargs 更新 config參數n n # 更新配置參數n for k, v in kwargs.iteritems():n if not hasattr(self, k):n # 警告還是報錯,取決於你個人的喜好n warnings.warn("Warning: opt has not attribut %s" %k)n setattr(self, k, v)n n # 列印配置信息tn print(user config:)n for k, v in self.__class__.__dict__.iteritems():n if not k.startswith(__):n print(k, getattr(self, k))n

這樣我們在實際使用時,並不需要每次都修改config.py,只需要通過命令行傳入所需參數,覆蓋默認配置即可。

例如:

opt = DefaultConfig()nnew_config = {lr:0.1,use_gpu:False}nopt.parse(new_config)nopt.lr == 0.1n

8 main.py

在講解主程序main.py之前,我們先來看看2017年3月谷歌開源的一個命令行工具fire ,通過pip install fire即可安裝。下面來看看fire的基礎用法,假設example.py文件內容如下:

import firendef add(x, y):n return x + yn ndef mul(**kwargs):n a = kwargs[a]n b = kwargs[b]n return a * bnnif __name__ == __main__:n fire.Fire()n

那麼我們可以使用:

python example.py add 1 2 # 執行add(1, 2)npython example.py mul --a=1 --b=2 # 執行mul(a=1, b=2),kwargs={a:1, b:2}npython example.py add --x=1 --y=2 # 執行add(x=1, y=2)n

可見,只要在程序中運行fire.Fire(),即可使用命令行參數python file <function> [args,] {--kwargs,}。fire還支持更多的高級功能,具體請參考官方指南 。

在主程序main.py中,主要包含四個函數,其中三個需要命令行執行,main.py的代碼組織結構如下:

def train(**kwargs):n n 訓練n n passn ndef val(model, dataloader):n n 計算模型在驗證集上的準確率等信息,用以輔助訓練n n passnndef test(**kwargs):n n 測試(inference)n n passnndef help():n n 列印幫助的信息 n n print(help)nnif __name__==__main__:n import firen fire.Fire()n

根據fire的使用方法,可通過python main.py <function> --args=xx的方式來執行訓練或者測試。

8.1 訓練

訓練的主要步驟如下:

  • 定義網路
  • 定義數據
  • 定義損失函數和優化器
  • 計算重要指標
  • 開始訓練
    • 訓練網路
    • 可視化各種指標
    • 計算在驗證集上的指標

訓練函數的代碼如下:

def train(**kwargs): n # 根據命令行參數更新配置n opt.parse(kwargs)n vis = Visualizer(opt.env)n n # step1: 模型n model = getattr(models, opt.model)()n if opt.load_model_path:n model.load(opt.load_model_path)n if opt.use_gpu: model.cuda()nn # step2: 數據n train_data = DogCat(opt.train_data_root,train=True)n val_data = DogCat(opt.train_data_root,train=False)n train_dataloader = DataLoader(train_data,opt.batch_size,n shuffle=True,n num_workers=opt.num_workers)n val_dataloader = DataLoader(val_data,opt.batch_size,n shuffle=False,n num_workers=opt.num_workers)n n # step3: 目標函數和優化器n criterion = t.nn.CrossEntropyLoss()n lr = opt.lrn optimizer = t.optim.Adam(model.parameters(),n lr = lr,n weight_decay = opt.weight_decay)n n # step4: 統計指標:平滑處理之後的損失,還有混淆矩陣n loss_meter = meter.AverageValueMeter()n confusion_matrix = meter.ConfusionMeter(2)n previous_loss = 1e100nn # 訓練n for epoch in range(opt.max_epoch):n n loss_meter.reset()n confusion_matrix.reset()nn for ii,(data,label) in enumerate(train_dataloader):nn # 訓練模型n input = Variable(data)n target = Variable(label)n if opt.use_gpu:n input = input.cuda()n target = target.cuda()n optimizer.zero_grad()n score = model(input)n loss = criterion(score,target)n loss.backward()n optimizer.step()n n # 更新統計指標以及可視化n loss_meter.add(loss.data[0])n confusion_matrix.add(score.data, target.data)nn if ii%opt.print_freq==opt.print_freq-1:n vis.plot(loss, loss_meter.value()[0])n n # 如果需要的話,進入debug模式n if os.path.exists(opt.debug_file):n import ipdb;n ipdb.set_trace()nn model.save()nn # 計算驗證集上的指標及可視化n val_cm,val_accuracy = val(model,val_dataloader)n vis.plot(val_accuracy,val_accuracy)n vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}"n .format(n epoch = epoch,n loss = loss_meter.value()[0],n val_cm = str(val_cm.value()),n train_cm=str(confusion_matrix.value()),n lr=lr))n n # 如果損失不再下降,則降低學習率n if loss_meter.value()[0] > previous_loss: n lr = lr * opt.lr_decayn for param_group in optimizer.param_groups:n param_group[lr] = lrn n previous_loss = loss_meter.value()[0]n

這裡用到了PyTorchNet裡面的一個工具: metermeter提供了一些輕量級的工具,用於幫助用戶快速統計訓練過程中的一些指標。AverageValueMeter能夠計算所有數的平均值和標準差,這裡用來統計一個epoch中損失的平均值。confusionmeter用來統計分類問題中的分類情況,是一個比準確率更詳細的統計指標。例如對於表格6-1,共有50張狗的圖片,其中有35張被正確分類成了狗,還有15張被誤判成貓;共有100張貓的圖片,其中有91張被正確判為了貓,剩下9張被誤判成狗。相比於準確率等統計信息,混淆矩陣更能體現分類的結果,尤其是在樣本比例不均衡的情況下。

8.2 驗證

驗證相對來說比較簡單,但要注意需將模型置於驗證模式(model.eval()),驗證完成後還需要將其置回為訓練模式(model.train()),這兩句代碼會影響BatchNorm和Dropout等層的運行模式。代碼如下。

def val(model,dataloader):n n 計算模型在驗證集上的準確率等信息n n # 把模型設為驗證模式n model.eval()n n confusion_matrix = meter.ConfusionMeter(2)n for ii, data in enumerate(dataloader):n input, label = datan val_input = Variable(input, volatile=True)n val_label = Variable(label.long(), volatile=True)n if opt.use_gpu:n val_input = val_input.cuda()n val_label = val_label.cuda()n score = model(val_input)n confusion_matrix.add(score.data.squeeze(), label.long())nn # 把模型恢復為訓練模式n model.train()n n cm_value = confusion_matrix.value()n accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) /n (cm_value.sum())n return confusion_matrix, accuracyn

8.2 測試

測試時,需要計算每個樣本屬於狗的概率,並將結果保存成csv文件。測試的代碼與驗證比較相似,但需要自己載入模型和數據。

def test(**kwargs):n opt.parse(kwargs)n # 模型n model = getattr(models, opt.model)().eval()n if opt.load_model_path:n model.load(opt.load_model_path)n if opt.use_gpu: model.cuda()nn # 數據n train_data = DogCat(opt.test_data_root,test=True)n test_dataloader = DataLoader(train_data,n batch_size=opt.batch_size,n shuffle=False,n num_workers=opt.num_workers)n n results = []n for ii,(data,path) in enumerate(test_dataloader):n input = t.autograd.Variable(data,volatile = True)n if opt.use_gpu: input = input.cuda()n score = model(input)n probability = t.nn.functional.softmaxn (score)[:,1].data.tolist() n batch_results = [(path_,probability_) n for path_,probability_ in zip(path,probability) ]n results += batch_resultsn write_csv(results,opt.result_file)n return resultsn

8.4 幫助函數

為了方便他人使用, 程序中還應當提供一個幫助函數,用於說明函數是如何使用。程序的命令行介面中有眾多參數,如果手動用字元串表示不僅複雜,而且後期修改config文件時,還需要修改對應的幫助信息,十分不便。這裡使用了Python標準庫中的inspect方法,可以自動獲取config的源代碼。help的代碼如下:

def help():n n 列印幫助的信息: python file.py helpn n n print(n usage : python {0} <function> [--args=value,]n <function> := train | test | helpn example: n python {0} train --env=env0701 --lr=0.01n python {0} test --dataset=path/to/dataset/root/n python {0} helpn avaiable args:.format(__file__))nn from inspect import getsourcen source = (getsource(opt.__class__))n print(source)n

當用戶執行python main.py help的時候,會列印如下幫助信息:

usage : python main.py <function> [--args=value,]n <function> := train | test | helpn example: n python main.py train --env=env0701 --lr=0.01n python main.py test --dataset=path/to/dataset/n python main.py helpn avaiable args:nclass DefaultConfig(object):n env = default # visdom 環境n model = AlexNet # 使用的模型n n train_data_root = ./data/train/ # 訓練集存放路徑n test_data_root = ./data/test1 # 測試集存放路徑n load_model_path = checkpoints/model.pth # 載入預訓練的模型nn batch_size = 128 # batch sizen use_gpu = True # user GPU or notn num_workers = 4 # how many workers for loading datan print_freq = 20 # print info every N batchnn debug_file = /tmp/debug n result_file = result.csv # 結果文件n n max_epoch = 10n lr = 0.1 # initial learning raten lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decayn weight_decay = 1e-4 # 損失函數n

9 使用

正如help函數的列印信息所述,可以通過命令行參數指定變數名.下面是三個使用例子,fire會將包含-的命令行參數自動轉層下劃線_,也會將非數值的值轉成字元串。所以--train-data-root=data/train和--train_data_root=data/train是等價的

# 訓練模型npython main.py train n --train-data-root=data/train/ n --load-model-path=checkpoints/resnet34_16:53:00.pth n --lr=0.005 n --batch-size=32 n --model=ResNet34 n --max-epoch = 20nn# 測試模型npython main.py testn --test-data-root=data/test1 n --load-model-path=checkpoints/resnet34_00:23:05.pth n --batch-size=128 n --model=ResNet34 n --num-workers=12nn# 列印幫助信息npython main.py helpn

6.1.9 爭議

以上的程序設計規範帶有作者強烈的個人喜好,並不想作為一個標準,而是作為一個提議和一種參考。上述設計在很多地方還有待商榷,例如對於訓練過程是否應該封裝成一個trainer對象,或者直接封裝到BaiscModule的train方法之中。對命令行參數的處理也有不少值得討論之處。因此不要將本文中的觀點作為一個必須遵守的規範,而應該看作一個參考。

本章中的設計可能會引起不少爭議,其中比較值得商榷的部分主要有以下幾個方面:

第一個是命令行參數的設置。目前大多數程序都是使用Python標準庫中的argparse來處理命令行參數,也有些使用比較輕量級的click。這種處理相對來說對命令行的支持更完備,但根據作者的經驗來看,這種做法不夠直觀,並且代碼量相對來說也較多。比如argparse,每次增加一個命令行參數,都必須寫如下代碼

parser.add_argument(-save-interval, type=int,n default=500, n help=how many steps to wait before saving [default:500])n

在我眼中,這種實現方式遠不如一個專門的config.py來的直觀和易用。尤其是對於使用 Jupyter notebook或IPython等互動式調試的用戶來說,argparse較難使用。

第二個是模型訓練的方式。有不少人喜歡將模型的訓練過程集成於模型的定義之中,代碼結構如下所示:

class MyModel(nn.Module):n def __init__(self,opt):n self.dataloader = Dataloader(opt)n self.optimizer = optim.Adam(self.parameters(),lr=0.001)n self.lr = opt.lrn self.model = make_model()n n def forward(self,input):n passn n def train_(self):n # 訓練模型n for epoch in range(opt.max_epoch)n for ii,data in enumerate(self.dataloader):n self.train_step(data)n model.save()n n def train_step(self):n passn

抑或是專門設計一個Trainer對象,形如:

import heapqnfrom torch.autograd import Variablennclass Trainer(object):nn def __init__(self, model=None, criterion=None, optimizer=None, dataset=None):n self.model = modeln self.criterion = criterionn self.optimizer = optimizern self.dataset = datasetn self.iterations = 0nn def run(self, epochs=1):n for i in range(1, epochs + 1):n self.train()nn def train(self):n for i, data in enumerate(self.dataset, self.iterations + 1):n batch_input, batch_target = datan self.call_plugins(batch, i, batch_input, batch_target)n input_var = Variable(batch_input)n target_var = Variable(batch_target)n n plugin_data = [None, None]n n def closure():n batch_output = self.model(input_var)n loss = self.criterion(batch_output, target_var)n loss.backward()n if plugin_data[0] is None:n plugin_data[0] = batch_output.datan plugin_data[1] = loss.datan return lossn self.optimizer.zero_grad()n self.optimizer.step(closure)n n self.iterations += in

還有一些人喜歡模仿keras和scikit-learn的設計,設計一個fit介面。

對讀者來說,這些處理方式很難說哪個更好或更差,找到最適合自己的方法才是最好的。

本指南的配套代碼地址: chenyuntc/pytorch-best-practice

P.S. 知乎專欄寫作真難用,改了好幾次都沒保存成功!可能還有不少錯誤之處,多請指教!

推薦閱讀:

【筆記】Finding Tiny Faces
PyTorch中如何使用tensorboard可視化
深度學習入門該用PyTorch還是Keras?熱門公開課換框架背後的學問
python3.6.1及TensorFlow和PyTorch
知乎「看山杯」 奪冠記

TAG:PyTorch | 深度学习DeepLearning | 机器学习 |