Pytorch如何自定義損失函數(Loss Function)?

在Stack Overflow中看到了類似的問題Custom loss function in PyTorch ,回答中說自定義的Loss Function 應繼承 _Loss 具體如何實現還是不太明白,知友們有沒有自定義過Loss Function呢? 如果我在loss function中要用到torch.svd(),還需要實現操作呢?謝謝!


謝邀,沒看到問題描述,那我就用自己的例子好了。

簡而言之,有三種方法。複雜程度依次遞增,計算效率也是依次遞增23333

1. 直接利用torch.Tensor提供的介面:

因為只是需要自定義loss,而loss可以看做對一個或多個Tensor的混合計算,比如計算一個三元組的Loss(Triplet Loss),我們只需要如下操作:(假設輸入的三個(anchor, positive, negative)張量維度是 batch_size * 400&<即triplet(net的輸出)&>)

import torch
import torch.nn as nn
import torch.nn.functional as func
class TripletLossFunc(nn.Module):
def __init__(self, t1, t2, beta):
super(TripletLossFunc, self).__init__()
self.t1 = t1
self.t2 = t2
self.beta = beta
return

def forward(self, anchor, positive, negative):
matched = torch.pow(func.pairwise_distance(anchor, positive), 2)
mismatched = torch.pow(func.pairwise_distance(anchor, negative), 2)
part_1 = torch.clamp(matched - mismatched, min=self.t1)
part_2 = torch.clamp(matched, min=self.t2)
dist_hinge = part_1 + self.beta * part_2
loss = torch.mean(dist_hinge)
return loss

如圖所示,在__init__()中定義超參數,在forward()中定義計算過程就可以了,全程使用torch提供的張量計算介面(道理上同樣可以使用numpy和scipy的,不過感覺效率會低一點),該方法可調用cuda(僅限僅使用了torch介面或者python內建方法),(即你可以直接使用實例化對象的.cuda()方法)

因為繼承了nn.Module,所以這個Loss類在實例化之後可以直接運行__call__()方法,也就是

a = TripletLossFunc(...)
loss = a(anchor, positive, negative)

就可以了。這是第一種方法。

2. 利用PyTorch的numpy/scipy擴展

如果你細心的話你會注意到我在上面使用了torch.nn.functional模塊的函數,那麼,問題來了,萬一需要的計算不在這個模塊中怎麼辦?

這個時候,如果這個操作可以用numpy/scipy實現,那麼,就需要寫一個numpy/scipy擴展

官網教程在此 (官網教程是自定義一個快速傅里葉變換在網路中,我們也可以定義操作然後用在loss中)

你需要做的操作其實只多了一步:

import torch
from torch.autograd import Function
from torch.autograd import Variable
class OwnOp(Function):
def forward(input_tensor):
tensor = input_tensor.numpy()
...... # 其他的numpy/scipy操作
result = ......
return torch.Tensor(result)

def backward(grad_output):
# 如果你只是需要在loss中應用這個操作的時候,這裡直接return輸入就可以了
# 如果你需要在nn中用到這個,需要寫明具體的反向傳播操作,具體跟forward的形式差不多
return grad_output

注意,你只需要定義forward()和backward()兩個方法就可以了,務必需要先調用輸入的.numpy()方法,返回需要把返回值變成torch.Tensor。

寫到這裡,基本滿足大部分需求了,但是,有了另外一個問題,如果我需要計算的東西很多(比如需要涉及到像素級別的計算)或者很複雜,或者numpy/scipy中沒有這些操作怎麼辦?

恩,那就只有最後一種方法了,不過需要你有一定的C語言基礎和會使用CUDA編程(據傳MSRA很多寫CUDA很熟練的神)

3. 寫一個PyTorch的C擴展

恩。。。。最近再被這個玩意折騰,還在學cuda23333,對於這個,我先給個官網的教程

PyTorch C擴展

以及某大神寫的一個roi_pooling的C擴展 ROI

具體的話,需要你先定義最基本的C/CUDA運算

/* triplet_cal.c */
#include &
#include & int triplet_cal_forward(...)
{
// 你的計算代碼
}
int triplet_cal_backward(...)
{
// 你的計算代碼
}

/* triplet_cal.h */
int triplet_cal_forward(...);
int triplet_cal_backward(...);

注意,這裡的文件名必須跟模塊名相同,比如你的模塊名是triplet_cal,那文件名就如上。

然後forward,backward那兩個函數名也必須遵照這個格式。

因為PyTorch自己寫了一個Parser用來解析頭文件,從而進行相關的運算

cuda同理,也需要定義triplet_cal_cuda.c和triplet_cal_cuda.h

cuda需要額外定義cuda運算

/* triplet_cal_kernel.cu */
#ifdef __cplusplus
extern "C" {
#endif

#include &
#include & #include &
#include "triplet_cal_kernel.h"
}
/*
我還不會CUDA23333
現在還在學CUDA的語法
恩所以我也不知道怎麼寫23333
*/

然後,你需要定義build.py,用來註冊這個擴展,使它被PyTorch接受(我自己的擴展還沒寫到這一步,所以我把roi_pooling的拿過來了23333,這個模塊名就叫做roi_pooling)

import os
import torch
from torch.utils.ffi import create_extension

sources = ["src/roi_pooling.c"]
headers = ["src/roi_pooling.h"]
defines = []
with_cuda = False

if torch.cuda.is_available():
print("Including CUDA code.")
sources += ["src/roi_pooling_cuda.c"]
headers += ["src/roi_pooling_cuda.h"]
defines += [("WITH_CUDA", None)]
with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))
print(this_file)
extra_objects = ["src/cuda/roi_pooling.cu.o"]
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]

ffi = create_extension(
"_ext.roi_pooling",
headers=headers,
sources=sources,
define_macros=defines,
relative_to=__file__,
with_cuda=with_cuda,
extra_objects=extra_objects
)

if __name__ == "__main__":
ffi.build()

之後,要做的跟2就差不多了,調用就可以了,之前只需要

from _ext import roi_pooling # 如果你寫的是roi_pooling

然後寫一個類(跟方法2中的一樣,forward和backward中調用roi_pooling就好)

以上,該回答之後會發在我的專欄中


從 code 裡面可以看到loss 函數部分繼承自_loss, 部分繼承自_WeightedLoss, 而_WeightedLoss繼承自_loss, _loss繼承自 nn.Module.

與定義一個新的模型類相同,定義一個新的loss function 你只需要繼承nn.Module就可以了

一個 pytorch 常見問題的 jupyter notebook 鏈接為A-Collection-of-important-tasks-in-pytorch

import torch.nn as nn
Class NewLoss(nn.Module):
def __init__():
pass

def forward():
pass


推薦閱讀:

Krizhevsky等人是怎麼想到在CNN里用Dropout和ReLu的?
卷積神經網路提取圖像特徵時具有旋轉不變性嗎?
請問各位大大現在的放療計劃系統在做自動化計劃時用的是神經網路嗎?

TAG:Python | 機器學習 | 深度學習DeepLearning | 卷積神經網路CNN | PyTorch |