pytorch例子-實現卷積網路進行圖像分類
前言
本案例採用CIFAR10數據集。 - 採用torchvision載入、標準化、切分數據集 - 定義一個卷積網路 - 定義損失函數 - 訓練網路 - 測試網路
1. 載入並標準化CIFAR10數據
%matplotlib inlineimport torchimport torchvisionimport torchvision.transforms as transforms
將[0, 1]範圍的數據歸一化為[-1, 1]
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 載入訓練數據trainset = torchvision.datasets.CIFAR10(root=./data, train=True, download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)# 載入測試數據集testset = torchvision.datasets.CIFAR10(root=./data, train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gzFiles already downloaded and verified
定義class
classes = (plane, car, bird, cat, deer, dog, frog, horse, ship, truck)
展示一些數據
import matplotlib.pyplot as pltimport numpy as npdef imshow(img): img = img / 2 + 0.5 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0)))# 隨機獲取一些圖像dataiter = iter(trainloader)images, labels = dataiter.next()# 展示圖像imshow(torchvision.utils.make_grid(images))# 列印出標籤print( .join(%5s % classes[labels[j]] for j in range(4)))
car deer truck frog
2. 定義一個卷積神經網路
- 定義神經網路支持3個信道
- torch.view: >返回一個新的tensor,但是不同的shape,
from torch.autograd import Variableimport torch.nn as nnimport torch.nn.functional as Fclass Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) # 二維轉一維 x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return xnet = Net()
3. 定義損失函數和優化函數
- loss: cross-entropy
- optimize: SGD
import torch.optim as optimcriterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4. 訓練網路
循環迭代數據,向網路中喂數據
# 迭代兩次for epoch in range(2): running_loss = 0. for i, data in enumerate(trainloader, 0): # 獲取輸入數據 inputs, labels = data # 綁定變數 inputs, labels = Variable(inputs), Variable(labels) # 初始化梯度參數 optimizer.zero_grad() # 向前傳播以及向後傳播 outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 列印出統計變數 running_loss += loss.data[0] if i % 2000 == 1999: print([%d, %5d] loss: %.3f%(epoch + 1, i+1, running_loss / 2000)) running_loss = 0.
[1, 2000] loss: 2.214[1, 4000] loss: 1.864[1, 6000] loss: 1.665[1, 8000] loss: 1.575[1, 10000] loss: 1.516[1, 12000] loss: 1.482[2, 2000] loss: 1.421[2, 4000] loss: 1.369[2, 6000] loss: 1.354[2, 8000] loss: 1.318[2, 10000] loss: 1.304[2, 12000] loss: 1.294
5. 測試網路的表現
(1)隨機選擇一些測試樣本
dataiter = iter(testloader)images, labels = dataiter.next()# 列印出圖像imshow(torchvision.utils.make_grid(images))print(真實標籤: , .join(%5s % classes[labels[j]] for j in range(4)))
真實標籤: cat ship ship plane
outputs = net(Variable(images))_, predicted = torch.max(outputs.data, 1) # 按照列的方向求每一行的最大值print(預測標籤: , .join(%5s % classes[predicted[j]] for j in range(4)))
預測標籤: ship ship car plane
(2)測試一下在全測試樣本上的表現
correct = 0total = 0for data in testloader: images, labels = data outputs = net(Variable(images)) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum()print(Accuracy of the network on the 10000 test images: %d %% % (100 * correct / total))
Accuracy of the network on the 10000 test images: 55 %
(3)看一下每個類別的分類準確率
class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))for data in testloader: images, labels = data outputs = net(Variable(images)) _, predicted = torch.max(outputs.data, 1) c = (predicted == labels).squeeze() for i in range(4): label = labels[i] class_correct[label] += c[i] class_total[label] += 1for i in range(10): print(Accuracy of %5s: %2d %% % (classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane: 54 %Accuracy of car: 65 %Accuracy of bird: 31 %Accuracy of cat: 33 %Accuracy of deer: 50 %Accuracy of dog: 46 %Accuracy of frog: 63 %Accuracy of horse: 70 %Accuracy of ship: 60 %Accuracy of truck: 77 %
推薦閱讀:
※最近一點微小的工作
※關於 PyTorch 0.3.0 在Windows下的安裝和使用
※關於Windows PRs併入PyTorch的master分支
※PyTorch教程+代碼:色塊秒變風景油畫
※知乎「看山杯」 奪冠記
TAG:PyTorch |