DeepLearning-風格遷移

背景介紹

n

不知道大家是否用過prisma,就算沒有用過,也一定看見別人用過這個軟體,下面是一張這個軟體得到的一個效果圖

官方宣傳的賣點是一秒鐘讓你的作品擁有名家風格,什麼畢加索,梵高,都不在話下。通過這個效果再將你的照片發到朋友圈,是不是效果爆棚,簡直是各種裝逼界的一股清流,秒殺各種修圖ps好嗎。而且可以完美的掩飾掉一些瑕疵,又比ps更自然,更有逼格,是不是很棒。

n

這個軟體將使用的方法發了一篇論文,並且這個軟體在發布的時候就取得了上千萬的融資,是不是瞬間感覺現在學習了知識也能成為千萬富豪了。如今在這個高速發展的時代,知識付費的時代確實已經到來了,所以我們現在努力學習各種知識就是在賺錢啊,有木有。這樣大家的學習的時候就能夠有著更大的動力了。

n

這篇論文感興趣的同學可以去查看一下,裡面主要涉及的是卷積神經網路。今天這篇文章要做的是什麼呢?我們希望自己能夠簡單的實現這個風格遷移演算法,並且用自己的演算法來得到新的風格圖片。一想到我們放到朋友圈的照片是自己寫的演算法來實現的就感覺成就感爆棚,有沒有。

n

環境配置

n

廢話不多說,我們先來看看需要的基本配置。首先需要python環境,建議使用anaconda;然後我們使用的深度學習框架是pytorch,當然你也可以用tensorflow,具體框架的介紹可以去看看之前寫的文章,需要安裝pytorch和torchvision,這裡查看安裝幫助;同時需要一些其他的包,如果缺什麼就pip安裝就好。

n

這篇文章主要參考於pytorch的官方tutorial,感興趣的同學可以直接移步至官方教程的地方,這篇文章我會說一些自己的理解,代碼部分基本都是參考這個教程,但是我會做一些說明,力求更加清楚。

n

原理分析

n

其實要實現的東西很清晰,就是需要將兩張圖片融合在一起,這個時候就需要定義怎麼才算融合在一起。首先需要的就是內容上是相近的,然後風格上是相似的。這樣來我們就知道我們需要做的事情是什麼了,我們需要計算融合圖片和內容圖片的相似度,或者說差異性,然後儘可能降低這個差異性;同時我們也需要計算融合圖片和風格圖片在風格上的差異性,然後也降低這個差異性就可以了。這樣我們就能夠量化我們的目標了。

n

對於內容的差異性我們該如何定義呢?其實我們能夠很簡答的想到就是兩張圖片每個像素點進行比較,也就是求一下差,因為簡單的計算他們之間的差會有正負,所以我們可以加一個平方,使得差全部是正的,也可以加絕對值,但是數學上絕對值會破壞函數的可微性,所以大家都用平方,這個地方不理解也沒關係,記住普遍都是使用平方就行了。

n

對於風格的差異性我們該如何定義呢?這才是一個難點。這也是這篇文章提出的創新點,引入了Gram矩陣計算風格的差異。我盡量不使用數學的語言來解釋,而使用通俗的語言。

首先需要的預先知識是卷積網路的知識,這裡不細講了,不了解的同學可以看之前的卷積網路的文章。我們知道一張圖片通過卷積網路之後可以的到一個特徵圖,Gram矩陣就是在這個特徵圖上面定義出來的。每個特徵圖的大小一般是 MxNxC 或者是 CxMxN 這種大小,這裡C表示的時候厚度,放在前面和後面都可以,MxN 表示的是一個矩陣的大小,其實就是有 C 個 MxN 這樣的矩陣疊在一起。

n

Gram矩陣是如何定義的呢?首先Gram矩陣的大小是有特徵圖的厚度決定的,等於 CxC,那麼每一個Gram矩陣的元素,也就是 Gram(i, j) 等於多少呢?先把特徵圖中第 i 層和第 j 層取出來,這樣就得到了兩個 MxN的矩陣,然後將這兩個矩陣對應元素相乘然後求和就得到了 Gram(i, j),同理 Gram 的所有元素都可以通過這個方式得到。這樣 Gram 中每個元素都可以表示兩層特徵圖的一種組合,就可以定義為它的風格。

n

然後風格的差異就是兩幅圖的 Gram 矩陣的差異,就像內容的差異的計算方法一樣,計算一下這兩個矩陣的差就可以量化風格的差異。

n

實現

n

以下的內容都是用pytorch實現的,如果對pytorch不熟悉的同學可以看一下我之前的pytorch介紹文章,看看官方教程,如果不想了解pytorch的同學可以用自己熟悉的框架實現這個演算法,理論部分前面已經講完了。

n

內容差異的loss定義

n

class Content_Loss(nn.Module):n def __init__(self, target, weight):n super(Content_Loss, self).__init__()n self.weight = weightn self.target = target.detach() * self.weightn # 必須要用detach來分離出target,這時候target不再是一個Variable,這是為了動態計算梯度,否則forward會出錯,不能向前傳播n self.criterion = nn.MSELoss()nn def forward(self, input):n self.loss = self.criterion(input * self.weight, self.target)n out = input.clone()n return outnn def backward(self, retain_variabels=True):n self.loss.backward(retain_variables=retain_variabels)n return self.lossn

其中有個變數weight,這個是表示權重,內容和風格你可以選擇一個權重,比如你想風格上更像,內容上多一點差別沒關係,那麼內容的權重你可以定義小一點,風格的權重可以定義大一點;反之你可以把風格的權重定義小一點,內容的權重定義大一點。

n

風格差異的loss定義

n

Gram 矩陣的定義

n

class Gram(nn.Module):n def __init__(self):n super(Gram, self).__init__()nn def forward(self, input):n a, b, c, d = input.size()n feature = input.view(a * b, c * d)n gram = torch.mm(feature, feature.t())n gram /= (a * b * c * d)n return gramn

n

style loss定義

n

class Style_Loss(nn.Module):n def __init__(self, target, weight):n super(Style_Loss, self).__init__()n self.weight = weightn self.target = target.detach() * self.weightn self.gram = Gram()n self.criterion = nn.MSELoss()nn def forward(self, input):n G = self.gram(input) * self.weightn self.loss = self.criterion(G, self.target)n out = input.clone()n return outnn def backward(self, retain_variabels=True):n self.loss.backward(retain_variables=retain_variabels)n return self.lossn

n

建立模型

n

使用19層的 vgg 作為提取特徵的卷積網路,並且定義哪幾層為需要的特徵。

n

vgg = models.vgg19(pretrained=True).featuresnvgg = vgg.cuda()nncontent_layers_default = [conv_4]nstyle_layers_default = [conv_1, conv_2, conv_3, conv_4, conv_5]nndef get_style_model_and_loss(style_img, content_img, cnn=vgg,n style_weight=1000,n content_weight=1,n content_layers=content_layers_default,n style_layers=style_layers_default):nn content_loss_list = []n style_loss_list = []nn model = nn.Sequential()n model = model.cuda()n gram = loss.Gram()n gram = gram.cuda()nn i = 1n for layer in cnn:n if isinstance(layer, nn.Conv2d):n name = conv_ + str(i)n model.add_module(name, layer)nn if name in content_layers_default:n target = model(content_img)n content_loss = loss.Content_Loss(target, content_weight)n model.add_module(content_loss_ + str(i), content_loss)n content_loss_list.append(content_loss)nn if name in style_layers_default:n target = model(style_img)n target = gram(target)n style_loss = loss.Style_Loss(target, style_weight)n model.add_module(style_loss_ + str(i), style_loss)n style_loss_list.append(style_loss)nn i += 1n if isinstance(layer, nn.MaxPool2d):n name = pool_ + str(i)n model.add_module(name, layer)nn if isinstance(layer, nn.ReLU):n name = relu + str(i)n model.add_module(name, layer)nn return model, style_loss_list, content_loss_listn

n

訓練模型

n

def get_input_param_optimier(input_img):n """n input_img is a Variablen """n input_param = nn.Parameter(input_img.data)n optimizer = optim.LBFGS([input_param])n return input_param, optimizernndef run_style_transfer(content_img, style_img, input_img,n num_epoches=300):n print(Building the style transfer model..)n model, style_loss_list, content_loss_list = get_style_model_and_loss(n style_img, content_imgn )n input_param, optimizer = get_input_param_optimier(input_img)nn print(Opimizing...)n epoch = [0]n while epoch[0] < num_epoches:nn def closure():n input_param.data.clamp_(0, 1)nn model(input_param)n style_score = 0n content_score = 0nn optimizer.zero_grad()n for sl in style_loss_list:n style_score += sl.backward()n for cl in content_loss_list:n content_score += cl.backward()nn epoch[0] += 1n if epoch[0] % 50 == 0:n print(run {}.format(epoch))n print(Style Loss: {:.4f} Content Loss: {:.4f}.format(n style_score.data[0], content_score.data[0]n ))n print()nn return style_score + content_scorenn optimizer.step(closure)nn input_param.data.clamp_(0, 1)nn return input_param.datan

n

需要特別注意的是這個模型裡面參數不再是網路裡面的參數,因為網路使用的是已經預訓練好的 vgg 網路,這個演算法裡面的參數是合成圖片裡面的每個像素點,我們可以將內容圖片直接 copy 成合成圖片,然後訓練使得他的風格和我們的風格圖片相似,同時也可以隨機化一張圖片作為合成圖片,然後訓練他使得他與內容圖片以及風格圖片具有相似性。

n

實驗結果

n

我們使用的風格圖片為

內容圖片為

得到的合成效果為

結語

n

通過這篇文章,我們利用pytorch實現了基本的風格轉移演算法,得到的效果也是滿意的,所以我們可以把自己的圖片通過這個演算法做一個風格轉移,實現你想要的作品的風格,逼格滿滿,大家學習之後肯定會有特別大的成就感,在完成項目的同時也學習到了新的知識,同時也會對這個產生更濃厚的感興趣,興趣才是各種的動力,比任何雞湯都有用,希望大家都能夠找到自己的興趣,熱愛自己所做的事。

n

本文代碼已經上傳到了github上

n

歡迎查看我的知乎專欄,深度煉丹

n

歡迎訪問我的博客


推薦閱讀:

Pytorch如何更新版本與卸載
Pytorch源碼與運行原理淺析--網路篇(一)
深度學習新手必學:使用 Pytorch 搭建一個 N-Gram 模型
項目推薦:自然場景中文文字檢測和不定長中文OCR識別
PyTorch實戰指南

TAG:深度学习DeepLearning | PyTorch | 机器学习 |