你必須要知道CNN模型:ResNet

歡迎交流與轉載,文章會同步發布在公眾號:機器學習演算法全棧工程師(Jeemy110)

引言

深度殘差網路(Deep residual network, ResNet)的提出是CNN圖像史上的一件里程碑事件,讓我們先看一下ResNet在ILSVRC和COCO 2015上的戰績:

圖1 ResNet在ILSVRC和COCO 2015上的戰績

ResNet取得了5項第一,並又一次刷新了CNN模型在ImageNet上的歷史:

圖2 ImageNet分類Top-5誤差

ResNet的作者何凱明也因此摘得CVPR2016最佳論文獎,當然何博士的成就遠不止於此,感興趣的可以去搜一下他後來的輝煌戰績。那麼ResNet為什麼會有如此優異的表現呢?其實ResNet是解決了深度CNN模型難訓練的問題,從圖2中可以看到14年的VGG才19層,而15年的ResNet多達152層,這在網路深度完全不是一個量級上,所以如果是第一眼看這個圖的話,肯定會覺得ResNet是靠深度取勝。事實當然是這樣,但是ResNet還有架構上的trick,這才使得網路的深度發揮出作用,這個trick就是殘差學習(Residual learning)。下面詳細講述ResNet的理論及實現。

深度網路的退化問題

從經驗來看,網路的深度對模型的性能至關重要,當增加網路層數後,網路可以進行更加複雜的特徵模式的提取,所以當模型更深時理論上可以取得更好的結果,從圖2中也可以看出網路越深而效果越好的一個實踐證據。但是更深的網路其性能一定會更好嗎?實驗發現深度網路出現了退化問題(Degradation problem):網路深度增加時,網路準確度出現飽和,甚至出現下降。這個現象可以在圖3中直觀看出來:56層的網路比20層網路效果還要差。這不會是過擬合問題,因為56層網路的訓練誤差同樣高。我們知道深層網路存在著梯度消失或者爆炸的問題,這使得深度學習模型很難訓練。但是現在已經存在一些技術手段如BatchNorm來緩解這個問題。因此,出現深度網路的退化問題是非常令人詫異的。

圖3 20層與56層網路在CIFAR-10上的誤差

殘差學習

深度網路的退化問題至少說明深度網路不容易訓練。但是我們考慮這樣一個事實:現在你有一個淺層網路,你想通過向上堆積新層來建立深層網路,一個極端情況是這些增加的層什麼也不學習,僅僅複製淺層網路的特徵,即這樣新層是恆等映射(Identity mapping)。在這種情況下,深層網路應該至少和淺層網路性能一樣,也不應該出現退化現象。好吧,你不得不承認肯定是目前的訓練方法有問題,才使得深層網路很難去找到一個好的參數。

這個有趣的假設讓何博士靈感爆發,他提出了殘差學習來解決退化問題。對於一個堆積層結構(幾層堆積而成)當輸入為 x 時其學習到的特徵記為 H(x) ,現在我們希望其可以學習到殘差 F(x)=H(x)-x ,這樣其實原始的學習特徵是 F(x)+x 。之所以這樣是因為殘差學習相比原始特徵直接學習更容易。當殘差為0時,此時堆積層僅僅做了恆等映射,至少網路性能不會下降,實際上殘差不會為0,這也會使得堆積層在輸入特徵基礎上學習到新的特徵,從而擁有更好的性能。殘差學習的結構如圖4所示。這有點類似與電路中的「短路」,所以是一種短路連接(shortcut connection)。

圖4 殘差學習單元

為什麼殘差學習相對更容易,從直觀上看殘差學習需要學習的內容少,因為殘差一般會比較小,學習難度小點。不過我們可以從數學的角度來分析這個問題,首先殘差單元可以表示為:

begin{align} & {{y}_{l}}=h({{x}_{l}})+F({{x}_{l}},{{W}_{l}})  & {{x}_{l+1}}=f({{y}_{l}})  end{align}

其中 x_{l}x_{l+1} 分別表示的是第 l 個殘差單元的輸入和輸出,注意每個殘差單元一般包含多層結構。 F 是殘差函數,表示學習到的殘差,而 h(x_{l})=x_{l} 表示恆等映射, f 是ReLU激活函數。基於上式,我們求得從淺層 l 到深層 L 的學習特徵為:

{{x}_{L}}={{x}_{l}}+sumlimits_{i-l}^{L-1}{F({{x}_{i}}},{{W}_{i}})

利用鏈式規則,可以求得反向過程的梯度:

frac{partial loss}{partial {{x}_{l}}}=frac{partial loss}{partial {{x}_{L}}}cdot frac{partial {{x}_{L}}}{partial {{x}_{l}}}=frac{partial loss}{partial {{x}_{L}}}cdot left( 1+frac{partial }{partial {{x}_{L}}}sumlimits_{i=l}^{L-1}{F({{x}_{i}},{{W}_{i}})} right)

式子的第一個因子 frac{partial loss}{partial {{x}_{L}}} 表示的損失函數到達 L 的梯度,小括弧中的1表明短路機制可以無損地傳播梯度,而另外一項殘差梯度則需要經過帶有weights的層,梯度不是直接傳遞過來的。殘差梯度不會那麼巧全為-1,而且就算其比較小,有1的存在也不會導致梯度消失。所以殘差學習會更容易。要注意上面的推導並不是嚴格的證明。

ResNet的網路結構

ResNet網路是參考了VGG19網路,在其基礎上進行了修改,並通過短路機制加入了殘差單元,如圖5所示。變化主要體現在ResNet直接使用stride=2的卷積做下採樣,並且用global average pool層替換了全連接層。ResNet的一個重要設計原則是:當feature map大小降低一半時,feature map的數量增加一倍,這保持了網路層的複雜度。從圖5中可以看到,ResNet相比普通網路每兩層間增加了短路機制,這就形成了殘差學習,其中虛線表示feature map數量發生了改變。圖5展示的34-layer的ResNet,還可以構建更深的網路如表1所示。從表中可以看到,對於18-layer和34-layer的ResNet,其進行的兩層間的殘差學習,當網路更深時,其進行的是三層間的殘差學習,三層卷積核分別是1x1,3x3和1x1,一個值得注意的是隱含層的feature map數量是比較小的,並且是輸出feature map數量的1/4。

圖5 ResNet網路結構圖

表1 不同深度的ResNet

下面我們再分析一下殘差單元,ResNet使用兩種殘差單元,如圖6所示。左圖對應的是淺層網路,而右圖對應的是深層網路。對於短路連接,當輸入和輸出維度一致時,可以直接將輸入加到輸出上。但是當維度不一致時(對應的是維度增加一倍),這就不能直接相加。有兩種策略:(1)採用zero-padding增加維度,此時一般要先做一個downsamp,可以採用strde=2的pooling,這樣不會增加參數;(2)採用新的映射(projection shortcut),一般採用1x1的卷積,這樣會增加參數,也會增加計算量。短路連接除了直接使用恆等映射,當然都可以採用projection shortcut。

圖6 不同的殘差單元

作者對比18-layer和34-layer的網路效果,如圖7所示。可以看到普通的網路出現退化現象,但是ResNet很好的解決了退化問題。

圖7 18-layer和34-layer的網路效果

最後展示一下ResNet網路與其他網路在ImageNet上的對比結果,如表2所示。可以看到ResNet-152其誤差降到了4.49%,當採用集成模型後,誤差可以降到3.57%。

表2 ResNet與其他網路的對比結果

說一點關於殘差單元題外話,上面我們說到了短路連接的幾種處理方式,其實作者在文獻[2]中又對不同的殘差單元做了細緻的分析與實驗,這裡我們直接拋出最優的殘差結構,如圖8所示。改進前後一個明顯的變化是採用pre-activation,BN和ReLU都提前了。而且作者推薦短路連接採用恆等變換,這樣保證短路連接不會有阻礙。感興趣的可以去讀讀這篇文章。

圖8 改進後的殘差單元及效果

ResNet的TensorFlow實現

這裡給出ResNet50的TensorFlow實現,模型的實現參考了Caffe版本的實現,核心代碼如下:

class ResNet50(object):n def __init__(self, inputs, num_classes=1000, is_training=True,n scope="resnet50"):n self.inputs =inputsn self.is_training = is_trainingn self.num_classes = num_classesnn with tf.variable_scope(scope):n # construct the modeln net = conv2d(inputs, 64, 7, 2, scope="conv1") # -> [batch, 112, 112, 64]n net = tf.nn.relu(batch_norm(net, is_training=self.is_training, scope="bn1"))n net = max_pool(net, 3, 2, scope="maxpool1") # -> [batch, 56, 56, 64]n net = self._block(net, 256, 3, init_stride=1, is_training=self.is_training,n scope="block2") # -> [batch, 56, 56, 256]n net = self._block(net, 512, 4, is_training=self.is_training, scope="block3")n # -> [batch, 28, 28, 512]n net = self._block(net, 1024, 6, is_training=self.is_training, scope="block4")n # -> [batch, 14, 14, 1024]n net = self._block(net, 2048, 3, is_training=self.is_training, scope="block5")n # -> [batch, 7, 7, 2048]n net = avg_pool(net, 7, scope="avgpool5") # -> [batch, 1, 1, 2048]n net = tf.squeeze(net, [1, 2], name="SpatialSqueeze") # -> [batch, 2048]n self.logits = fc(net, self.num_classes, "fc6") # -> [batch, num_classes]n self.predictions = tf.nn.softmax(self.logits)nnn def _block(self, x, n_out, n, init_stride=2, is_training=True, scope="block"):n with tf.variable_scope(scope):n h_out = n_out // 4n out = self._bottleneck(x, h_out, n_out, stride=init_stride,n is_training=is_training, scope="bottlencek1")n for i in range(1, n):n out = self._bottleneck(out, h_out, n_out, is_training=is_training,n scope=("bottlencek%s" % (i + 1)))n return outnn def _bottleneck(self, x, h_out, n_out, stride=None, is_training=True, scope="bottleneck"):n """ A residual bottleneck unit"""n n_in = x.get_shape()[-1]n if stride is None:n stride = 1 if n_in == n_out else 2nn with tf.variable_scope(scope):n h = conv2d(x, h_out, 1, stride=stride, scope="conv_1")n h = batch_norm(h, is_training=is_training, scope="bn_1")n h = tf.nn.relu(h)n h = conv2d(h, h_out, 3, stride=1, scope="conv_2")n h = batch_norm(h, is_training=is_training, scope="bn_2")n h = tf.nn.relu(h)n h = conv2d(h, n_out, 1, stride=1, scope="conv_3")n h = batch_norm(h, is_training=is_training, scope="bn_3")nn if n_in != n_out:n shortcut = conv2d(x, n_out, 1, stride=stride, scope="conv_4")n shortcut = batch_norm(shortcut, is_training=is_training, scope="bn_4")n else:n shortcut = xn return tf.nn.relu(shortcut + h)n

完整實現可以參見GitHub。

總結

ResNet通過殘差學習解決了深度網路的退化問題,讓我們可以訓練出更深的網路,這稱得上是深度網路的一個歷史大突破吧。也許不久會有更好的方式來訓練更深的網路,讓我們一起期待吧!

參考資料

  1. Deep Residual Learning for Image Recognition.
  2. Identity Mappings in Deep Residual Networks.
  3. 去膜拜一下大神.

歡迎交流與轉載,文章會同步發布在公眾號:機器學習演算法全棧工程師(Jeemy110)


推薦閱讀:

本田與商湯牽手共謀自動駕駛,為何相中這家中國AI獨角獸?
[讀論文] Single-Image Depth Perception in the Wild
論文推薦:GAN,信息抽取,機器閱讀理解,對話系統 | 本周值得讀 #35
模型壓縮那些事(一)
深度學習實踐:使用Tensorflow實現快速風格遷移

TAG:卷积神经网络CNN | 深度学习DeepLearning | 计算机视觉 |