pytorch中nn和nn.functional有什麼區別?

如題,二者有很多相同函數,希望說明時能用代碼舉例。

謝謝!


其實這兩個是差不多的,不過一個包裝好的類,一個是可以直接調用的函數。我們可以去翻這兩個模塊的具體實現代碼,我下面以卷積Conv1d為例。

首先是torch.nn下的Conv1d:

class Conv1d(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
dilation = _single(dilation)
super(Conv1d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias)

def forward(self, input):
return F.conv1d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)

這是torch.nn.functional下的conv1d:

def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1,
groups=1):
if input is not None and input.dim() != 3:
raise ValueError("Expected 3D tensor as input, got {}D tensor instead.".format(input.dim()))

f = ConvNd(_single(stride), _single(padding), _single(dilation), False,
_single(0), groups, torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
return f(input, weight, bias)

可以看到torch.nn下的Conv1d類在forward時調用了nn.functional下的conv1d,當然最終的計算是通過C++編寫的THNN庫中的ConvNd進行計算的,因此這兩個其實是互相調用的關係。

你可能會疑惑為什麼需要這兩個功能如此相近的模塊,其實這麼設計是有其原因的。如果我們只保留nn.functional下的函數的話,在訓練或者使用時,我們就要手動去維護weight, bias, stride這些中間量的值,這顯然是給用戶帶來了不便。而如果我們只保留nn下的類的話,其實就犧牲了一部分靈活性,因為做一些簡單的計算都需要創造一個類,這也與PyTorch的風格不符。

下面我舉幾個例子方便各位理解:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(512, 256)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 2)

def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(F.dropout(self.fc2(x), 0.5))
x = F.dropout(self.fc3(x), 0.5)
return x

以一個最簡單的三層網路為例。需要維持狀態的,主要是三個線性變換,所以在構造Module是,定義了三個nn.Linear對象,而在計算時,relu,dropout之類不需要保存狀態的可以直接使用。


nn.functional中的都是沒有副作用的函數,也就是說function內部一定是沒有Variable的,只是純粹從輸入到輸出的一個變換

nn下面的可能有可能沒有,一般都是nn.Module的子類,可以藉助nn.Module的父方法方便的管理各種需要的變數


補充一點。nn.functional中的函數僅僅定義了一些具體的基本操作,不能構成PyTorch中的一個layer。當你需要自定義一些非標準layer時,可以在其中調用nn.functional中的操作。


推薦閱讀:

2017 年 8 月 6 日發布的 pytorch 0.2.0 哪個特性最吸引你?
如何有效地閱讀PyTorch的源代碼?

TAG:深度學習DeepLearning | PyTorch |