
PyTorch Example: Implementing a Convolutional Network for Image Classification

Introduction

This example uses the CIFAR10 dataset. The steps are:
  • load, normalize, and split the dataset with torchvision
  • define a convolutional network
  • define a loss function
  • train the network
  • test the network

1. Load and Normalize the CIFAR10 Data

%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms

Normalize the data from the [0, 1] range to [-1, 1]. With mean 0.5 and std 0.5 per channel, Normalize computes (x - 0.5) / 0.5, so 0 maps to -1 and 1 maps to 1.

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the training set
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# Load the test set
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Files already downloaded and verified
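As a quick sanity check (an illustrative sketch, not part of the original example), the objects created above can be inspected directly; trainset supports len() and indexing, and each item is a (tensor, label) pair:

# Illustrative sanity check on the loaded data (assumes trainset/testset from above)
print(len(trainset), len(testset))   # 50000 10000
img, label = trainset[0]
print(img.size(), label)             # torch.Size([3, 32, 32]) and an integer class index
print(img.min(), img.max())          # roughly -1 and 1 after Normalize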

Define the class names

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Show some of the images

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5     # unnormalize back to [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Get a random batch of training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# Show the images
imshow(torchvision.utils.make_grid(images))

# Print the labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

car deer truck frog

2. Define a Convolutional Neural Network

  • Define a network that accepts 3-channel (RGB) images as input
  • torch.view: returns a new tensor with the same data but a different shape; here it flattens the pooled feature maps into a vector for the fully connected layers (see the shape check after the code below)

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten the 2-D feature maps into a 1-D vector
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
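To see where the 16 * 5 * 5 input size of fc1 comes from, the following sketch (not part of the original code) pushes a single dummy 32x32 image through the convolution and pooling layers and prints the shape before and after the view call:

# Illustrative shape check with a dummy input, using the layers of net defined above
x = Variable(torch.randn(1, 3, 32, 32))   # one fake 3-channel 32x32 image
x = net.pool(F.relu(net.conv1(x)))        # 5x5 conv: 32 -> 28, 2x2 pool: 28 -> 14
x = net.pool(F.relu(net.conv2(x)))        # 5x5 conv: 14 -> 10, 2x2 pool: 10 -> 5
print(x.size())                           # torch.Size([1, 16, 5, 5])
print(x.view(-1, 16 * 5 * 5).size())      # torch.Size([1, 400])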

3. Define the Loss Function and Optimizer

  • loss: cross-entropy
  • optimizer: SGD with momentum

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
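One point worth noting: nn.CrossEntropyLoss combines LogSoftmax and NLLLoss, so it expects raw scores from fc3 and integer class indices rather than one-hot vectors, which is why forward() ends without a softmax. A small illustrative snippet (not part of the original example):

# Illustrative: CrossEntropyLoss takes raw scores and integer class targets
dummy_scores = Variable(torch.randn(4, 10))               # a batch of 4 samples, 10 classes
dummy_targets = Variable(torch.LongTensor([3, 0, 9, 1]))  # class indices, not one-hot
print(criterion(dummy_scores, dummy_targets))             # a single scalar loss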

4. Train the Network

Loop over the data iterator and feed the batches to the network.

# Train for two epochs
for epoch in range(2):
    running_loss = 0.
    for i, data in enumerate(trainloader, 0):
        # Get the inputs
        inputs, labels = data
        # Wrap them in Variables
        inputs, labels = Variable(inputs), Variable(labels)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass, backward pass, parameter update
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Print statistics
        running_loss += loss.data[0]  # PyTorch 0.3 idiom; use loss.item() on newer versions
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.

[1,  2000] loss: 2.214
[1,  4000] loss: 1.864
[1,  6000] loss: 1.665
[1,  8000] loss: 1.575
[1, 10000] loss: 1.516
[1, 12000] loss: 1.482
[2,  2000] loss: 1.421
[2,  4000] loss: 1.369
[2,  6000] loss: 1.354
[2,  8000] loss: 1.318
[2, 10000] loss: 1.304
[2, 12000] loss: 1.294

5. Test the Network's Performance

(1) Pick a few test samples at random

dataiter = iter(testloader)
images, labels = dataiter.next()

# Show the images
imshow(torchvision.utils.make_grid(images))
print('True labels: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

True labels:    cat  ship  ship plane

outputs = net(Variable(images))
# Take the index of the highest score in each row (along dim 1) as the predicted class
_, predicted = torch.max(outputs.data, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

Predicted:   ship  ship   car plane
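For reference, torch.max(outputs.data, 1) returns both the maximum score in each row and the column index where it occurs; the index is what serves as the predicted class. A small illustrative example (not from the original post):

# Illustrative: torch.max along dim 1 returns (values, indices)
scores = torch.Tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.5, 0.2]])
values, indices = torch.max(scores, 1)
print(values)    # 2.0 and 3.0: the highest score in each row
print(indices)   # 1 and 0: the column (class) index of each maximum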

(2) Evaluate performance on the full test set

correct = 0
total = 0
for data in testloader:
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

Accuracy of the network on the 10000 test images: 55 %

(3) Check the accuracy for each class

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
for data in testloader:
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    c = (predicted == labels).squeeze()
    for i in range(4):
        label = labels[i]
        class_correct[label] += c[i]
        class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s: %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of plane: 54 %
Accuracy of   car: 65 %
Accuracy of  bird: 31 %
Accuracy of   cat: 33 %
Accuracy of  deer: 50 %
Accuracy of   dog: 46 %
Accuracy of  frog: 63 %
Accuracy of horse: 70 %
Accuracy of  ship: 60 %
Accuracy of truck: 77 %
