【pytorch】模型的搭建保存載入

使用pytorch進行網路模型的搭建、保存與載入,是非常快速、方便的、妙不可言的。

搭建ConvNet

所有的網路都要繼承torch.nn.Module,然後在構造函數中使用torch.nn中的提供的介面定義layer的屬性,最後,在forward函數中將各個layer連接起來。

下面,以LeNet為例:

class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16*5*5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2) x = x.view(x.size(0), -1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) out = self.fc3(x) return out

這樣一來,我們就搭建好了網路模型,是不是很簡潔明了呢?此外,還可以使用torch.nn.Sequential,更方便進行模塊化的定義,如下:

class LeNetSeq(nn.Module): def __init__(self): super(LeNetSeq, self).__init__() self.conv = nn.Sequential( nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2), ) self.fc = nn.Sequential( nn.Linear(16*5*5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10) ) def forward(self, x): x = self.conv(x) x = out.view(x.size(0), -1) out = self.fc(x) return out

Module有很多屬性,可以查看權重、參數等等;如下:

net = lenet.LeNet()print(net)for param in net.parameters(): print(type(param.data), param.size()) print(list(param.data)) print(net.state_dict().keys())#參數的keysfor key in net.state_dict():#模型參數 print key, corresponds to, list(net.state_dict()[key])

那麼,如何進行參數初始化呢?使用 torch.nn.init ,如下:

def initNetParams(net): Init net parameters. for m in net.modules(): if isinstance(m, nn.Conv2d): init.xavier_uniform(m.weight) if m.bias: init.constant(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): init.constant(m.weight, 1) init.constant(m.bias, 0) elif isinstance(m, nn.Linear): init.normal(m.weight, std=1e-3) if m.bias: init.constant(m.bias, 0)initNetParams(net)

保存ConvNet

使用torch.save()對網路結構和模型參數的保存,有兩種保存方式:

  • 保存整個神經網路的的結構信息和模型參數信息,save的對象是網路net;
  • 保存神經網路的訓練模型參數,save的對象是net.state_dict()。

torch.save(net1, net.pkl) # 保存整個神經網路的結構和模型參數 torch.save(net1.state_dict(), net_params.pkl) # 只保存神經網路的模型參數

載入ConvNet

對應上面兩種保存方式,重載方式也有兩種。

  • 對應第一種完整網路結構信息,重載的時候通過torch.load(『.pth』)直接初始化新的神經網路對象即可。
  • 對應第二種只保存模型參數信息,需要首先導入對應的網路,通過net.load_state_dict(torch.load(.pth))完成模型參數的重載。

在網路比較大的時候,第一種方法會花費較多的時間,所佔的存儲空間也比較大。

# 保存和載入整個模型 torch.save(model_object, model.pth) model = torch.load(model.pth) # 僅保存和載入模型參數 torch.save(model_object.state_dict(), params.pth) model_object.load_state_dict(torch.load(params.pth))

相關代碼可以查看:tfygg/pytorch-tutorials

---------------------------------------------------------------------------------------------------------------------------

在各方小夥伴的努力和支持下,pytorch中文文檔第一版終於上線啦!!!(鼓掌)文檔還有很多小瑕疵,但是大體可以放心使用了~我們遵循快速迭代的原則,所以趕緊上線第一版來接受廣大開源社區的意見和建議。歡迎加入我們!

pytorch中文文檔:pytorch-cn.readthedocs.io

github項目地址:awfssv/pytorch-cn

中文翻譯組QQ群:628478868

還有pytorch項目交流群:613523596

歡迎關注!


推薦閱讀:

TAG:深度學習DeepLearning | PyTorch |