教你用PyTorch實現「看圖說話」(附代碼、學習資源)
本文共2200字,建議閱讀10分鐘。
本文用淺顯易懂的方式解釋了什麼是「看圖說話」(Image Captioning),藉助github上的PyTorch代碼帶領大家自己做一個模型,並附帶了很多相關的學習資源。
介紹
深度學習目前是一個非常活躍的領域---每天都會有許多應用出現。進一步學習Deep Learning最好的方法就是親自動手。儘可能多的接觸項目並且嘗試自己去做。這將會幫助你更深刻地掌握各個主題,成為一名更好的Deep Learning實踐者。
這篇文章將和大家一起看一個有趣的多模態主題,我們將結合圖像和文本處理技術來構建一個有用的深度學習應用,即看圖說話(Image Captioning)。看圖說話是指從一個圖像中基於其中的對象和動作生成文本描述的過程。例如:
這種過程在現實生活中有很多潛在的應用場景。一個明顯的應用比如保存圖片的描述字幕,以便該圖片隨後可以根據這個描述輕鬆地被檢索出來。
我們開始吧!
注意: 本文假定你了解深度學習的基礎知識,以前曾使用CNN處理過圖像問題。如果想複習這些概念,可以先閱讀下面的文章:
- Fundamentals of Deep Learning – Starting with Artificial Neural Network
- Architecture of Convolutional Neural Networks (CNNs) demystified
- Tutorial: Optimizing Neural Networks using Keras (with Image recognition case study)
- Essentials of Deep Learning – Sequence to Sequence modelling with Attention (using python)
目錄
- 什麼是Image Captioning問題?
- 解決任務的方法
- 應用演練
- 下一步工作
什麼是Image Captioning問題?
設想你看到了這張圖:
你首先想到的是什麼?下面是一些人們可以想到的句子:
A man and a girl sit on the ground and eat . (一個男人和一個女孩坐在地上吃東西)
A man and a little girl are sitting on a sidewalk near a blue bag eating . (一個男人和一個小女孩坐在藍色包旁邊的人行道上吃東西)A man wearing a black shirt and a little girl wearing an orange dress share a treat .(一個穿黑色襯衣的男人和一個穿橘色連衣裙的小女孩分享美食)
快速看一眼就足以讓你理解和描述圖片中發生的事情。從一個人造系統中自動生成這種文字描述就是Image Captioning的任務。
該任務很明確,即產生的輸出是用一句話來描述這幅圖片中的內容---存在的對象,屬性,正在發生的動作以及對象之間的互動等。但是與其他圖像處理問題一樣,在人造系統中再現這種行為也是一項艱巨的任務。因此需要使用像Deep Learning這樣先進複雜的技術來解決該任務。
在繼續下文之前,我想特別感謝Andrej Kartpathy等學者,他們富有洞察力的課程CS231n幫助我理解了這個主題。
解決任務的方法
可以把image captioning任務在邏輯上分為兩個模塊——一個是基於圖像的模型,從圖像中提取特徵和細微的差別, 另一個是基於語言的模型,將第一個模型給出的特徵和對象翻譯成自然的語句。
對於基於圖像的模型而言(即編碼器)我們通常依靠CNN網路。對於基於語言的模型而言(即解碼器),我們依賴RNN網路。下圖總結了前面提到的方法:
通常,一個預先訓練好的CNN網路從輸入圖像中提取特徵。特徵向量被線性轉換成與RNN/LSTM網路的輸入具有相同的維度。這個網路被訓練作為我們特徵向量的語言模型。
為了訓練LSTM模型,我們預先定義了標籤和目標文本。比如,如果字幕是A man and a girl sit on the ground and eat .(一個男人和一個女孩坐在地上吃東西),則我們的標籤和目標文本如下:
這樣做是為了讓模型理解我們標記序列的開始和結束。
具體實現案例
讓我們看一個Pytorch中image captioning的簡單實現。我們將以一幅圖作為輸入,然後使用深度學習模型來預測它的描述。
例子的代碼可以在GitHub上找到。代碼的原始作者是Yunjey Choi 向他傑出的pytorch例子致敬。
在本例中,一個預先訓練好的ResNet-152被用作編碼器,而解碼器是一個LSTM網路。
要運行本例中的代碼,你需要安裝必備軟體,確保有一個可以工作的python環境,最好使用anaconda。然後運行以下命令來安裝其他所需要的庫。
git clone https://github.com/pdollar/coco.git
cd coco/PythonAPI/makepython setup.py buildpython setup.py installcd ../../git clone https://github.com/yunjey/pytorch-tutorial.gitcd pytorch-tutorial/tutorials/03-advanced/image_captioning/
pip install -r requirements.txt
設置完系統後,就該下載所需的數據集並且訓練模型了。這裡我們使用的是MS-COCO數據集。可以運行如下命令來自動下載數據集:
chmod +x download.sh
./download.sh
現在可以繼續並開始模型的構建過程了。首先,你需要處理輸入:
# Search for all the possible words in the dataset and
# build a vocabulary listpython build_vocab.py # resize all the images to bring them to shape 224x224python resize.py
現在,運行下面的命令來訓練模型:
python train.py --num_epochs 10 --learning_rate 0.01
來看一下被封裝好的代碼中是如何定義模型的,可以在model.py文件中找到:
import torch
import torch.nn as nnimport torchvision.models as modelsfrom torch.nn.utils.rnn import pack_padded_sequencefrom torch.autograd import Variableclass EncoderCNN(nn.Module): def __init__(self, embed_size): """Load the pretrained ResNet-152 and replace top fc layer.""" super(EncoderCNN, self).__init__() resnet = models.resnet152(pretrained=True)modules = list(resnet.children())[:-1] # delete the last fc layer.
self.resnet = nn.Sequential(*modules) self.linear = nn.Linear(resnet.fc.in_features, embed_size) self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) self.init_weights() def init_weights(self): """Initialize the weights.""" self.linear.weight.data.normal_(0.0, 0.02) self.linear.bias.data.fill_(0) def forward(self, images):"""Extract the image feature vectors."""
features = self.resnet(images) features = Variable(features.data) features = features.view(features.size(0), -1) features = self.bn(self.linear(features)) return featuresclass DecoderRNN(nn.Module): def __init__(self, embed_size, hidden_size, vocab_size, num_layers): """Set the hyper-parameters and build the layers.""" super(DecoderRNN, self).__init__()self.embed = nn.Embedding(vocab_size, embed_size)
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True) self.linear = nn.Linear(hidden_size, vocab_size) self.init_weights() def init_weights(self): """Initialize weights.""" self.embed.weight.data.uniform_(-0.1, 0.1) self.linear.weight.data.uniform_(-0.1, 0.1) self.linear.bias.data.fill_(0) def forward(self, features, captions, lengths):"""Decode image feature vectors and generates captions."""
embeddings = self.embed(captions) embeddings = torch.cat((features.unsqueeze(1), embeddings), 1) packed = pack_padded_sequence(embeddings, lengths, batch_first=True) hiddens, _ = self.lstm(packed) outputs = self.linear(hiddens[0]) return outputs def sample(self, features, states=None): """Samples captions for given image features (Greedy search).""" sampled_ids = [] inputs = features.unsqueeze(1) for i in range(20): # maximum sampling length hiddens, states = self.lstm(inputs, states) # (batch_size, 1, hidden_size), outputs = self.linear(hiddens.squeeze(1)) # (batch_size, vocab_size) predicted = outputs.max(1)[1] sampled_ids.append(predicted) inputs = self.embed(predicted) inputs = inputs.unsqueeze(1) # (batch_size, 1, embed_size) sampled_ids = torch.cat(sampled_ids, 1) # (batch_size, 20) return sampled_ids.squeeze()
現在測試我們的模型:
python sample.py --image=png/example.png
對於樣例圖片,我們的模型給出了這樣的輸出:
<start> a group of giraffes standing in a grassy area . <end>
<start>一群長頸鹿站在草地上<end>
以上就是如何建立一個用於image captioning的深度學習模型。
下一步工作
以上模型只是冰山一角。關於這個主題已經有很多的研究。目前在image captioning領域最先進的模型是微軟的CaptionBot。可以在他們的官網上看一個系統的demo.
我列舉一些可以用來構建更好的image captioning模型的想法:
- 加入更多數據 當然這也是深度學習模型通常的趨勢。提供的數據越多,模型效果越好。可以在這裡找到其他的數據集: http://www.cs.toronto.edu/~fidler/slides/2017/CSC2539/Kaustav_slides.pdf
- 使用Attention模型 正如這篇文章所述(Essentials of Deep Learning – Sequence to Sequence modelling with Attention), 使用attention模型有助於微調模型的性能
- 轉向更大更好的技術 研究人員一直在研究一些技術,比如使用強化學習來構建端到端的深度學習系統,或者使用新穎的attention模型用於「視覺哨兵(visual sentinel)」。
結語
這篇文章中,我介紹了image captioning,這是一個多模態任務,它由解密圖片和用自然語句描述圖片兩部分組成。然後我解釋了解決該任務用到的方法並給出了一個應用演練。 對於好奇心強的讀者,我還列舉了幾條可以改進模型性能的方法。
希望這篇文章可以激勵你去發現更多可以用深度學習解決的任務,從而在工業中出現越來越多的突破和創新。如果有任何建議/反饋,歡迎在下面的評論中留言!
原文標題:Automatic Image Captioning using Deep Learning (CNN and LSTM) in PyTorch
原文鏈接:
https://www.analyticsvidhya.com/blog/2018/04/solving-an-image-captioning-task-using-deep-learning/作者:FAIZAN SHAIKH
翻譯:和中華
推薦閱讀:
※白話TensorFlow +實戰系列(一)詳解Tensor與Flow
※iOS開發迎來機器學習的春天--TensorFlow
※深度學習對話系統實戰篇--新版本chatbot代碼實現
※CS 20SI Lecture1: Overview of TensorFlow
※tf.set_random_seed
TAG:TensorFlow | 深度學習DeepLearning | Keras |