DenseNet:比ResNet更優的CNN模型

DenseNet:比ResNet更優的CNN模型

來自專欄 機器學習演算法工程師

碼字不易,歡迎給個贊!

歡迎交流與轉載,文章會同步發布在公眾號:機器學習演算法全棧工程師(Jeemy110)

歷史文章:

小白將:你必須要知道CNN模型:ResNet?

zhuanlan.zhihu.com圖標


前言

在計算機視覺領域,卷積神經網路(CNN)已經成為最主流的方法,比如最近的GoogLenet,VGG-19,Incepetion等模型。CNN史上的一個里程碑事件是ResNet模型的出現,ResNet可以訓練出更深的CNN模型,從而實現更高的準確度。ResNet模型的核心是通過建立前面層與後面層之間的「短路連接」(shortcuts,skip connection),這有助於訓練過程中梯度的反向傳播,從而能訓練出更深的CNN網路。今天我們要介紹的是DenseNet模型,它的基本思路與ResNet一致,但是它建立的是前面所有層與後面層的密集連接(dense connection),它的名稱也是由此而來。DenseNet的另一大特色是通過特徵在channel上的連接來實現特徵重用(feature reuse)。這些特點讓DenseNet在參數和計算成本更少的情形下實現比ResNet更優的性能,DenseNet也因此斬獲CVPR 2017的最佳論文獎。本篇文章首先介紹DenseNet的原理以及網路架構,然後講解DenseNet在Pytorch上的實現。

設計理念

相比ResNet,DenseNet提出了一個更激進的密集連接機制:即互相連接所有的層,具體來說就是每個層都會接受其前面所有層作為其額外的輸入。圖1為ResNet網路的連接機制,作為對比,圖2為DenseNet的密集連接機制。可以看到,ResNet是每個層與前面的某層(一般是2~3層)短路連接在一起,連接方式是通過元素級相加。而在DenseNet中,每個層都會與前面所有層在channel維度上連接(concat)在一起(這裡各個層的特徵圖大小是相同的,後面會有說明),並作為下一層的輸入。對於一個 L 層的網路,DenseNet共包含 frac{L(L+1)}{2} 個連接,相比ResNet,這是一種密集連接。而且DenseNet是直接concat來自不同層的特徵圖,這可以實現特徵重用,提升效率,這一特點是DenseNet與ResNet最主要的區別。

圖1 ResNet網路的短路連接機制(其中+代表的是元素級相加操作)

圖2 DenseNet網路的密集連接機制(其中c代表的是channel級連接操作)

如果用公式表示的話,傳統的網路在 l 層的輸出為:

\x_l = H_l(x_{l-1})

而對於ResNet,增加了來自上一層輸入的identity函數:

\x_l = H_l(x_{l-1}) + x_{l-1}

在DenseNet中,會連接前面所有層作為輸入:

\x_l = H_l([x_0, x_1, ..., x_{l-1}])

其中,上面的 H_l(cdot) 代表是非線性轉化函數(non-liear transformation),它是一個組合操作,其可能包括一系列的BN(Batch Normalization),ReLU,Pooling及Conv操作。注意這裡 l 層與 l-1 層之間可能實際上包含多個卷積層。

DenseNet的前向過程如圖3所示,可以更直觀地理解其密集連接方式,比如 h_3 的輸入不僅包括來自 h_2x_2 ,還包括前面兩層的 x_1x_2 ,它們是在channel維度上連接在一起的。

圖3 DenseNet的前向過程

CNN網路一般要經過Pooling或者stride>1的Conv來降低特徵圖的大小,而DenseNet的密集連接方式需要特徵圖大小保持一致。為了解決這個問題,DenseNet網路中使用DenseBlock+Transition的結構,其中DenseBlock是包含很多層的模塊,每個層的特徵圖大小相同,層與層之間採用密集連接方式。而Transition模塊是連接兩個相鄰的DenseBlock,並且通過Pooling使特徵圖大小降低。圖4給出了DenseNet的網路結構,它共包含4個DenseBlock,各個DenseBlock之間通過Transition連接在一起。

圖4 使用DenseBlock+Transition的DenseNet網路

網路結構

如前所示,DenseNet的網路結構主要由DenseBlock和Transition組成,如圖5所示。下面具體介紹網路的具體實現細節。

圖6 DenseNet的網路結構

在DenseBlock中,各個層的特徵圖大小一致,可以在channel維度上連接。DenseBlock中的非線性組合函數 H(cdot) 採用的是BN+ReLU+3x3 Conv的結構,如圖6所示。另外值得注意的一點是,與ResNet不同,所有DenseBlock中各個層卷積之後均輸出 k 個特徵圖,即得到的特徵圖的channel數為 k ,或者說採用 k 個卷積核。 k 在DenseNet稱為growth rate,這是一個超參數。一般情況下使用較小的 k (比如12),就可以得到較佳的性能。假定輸入層的特徵圖的channel數為 k_0 ,那麼 l 層輸入的channel數為 k_0+k(l-1) ,因此隨著層數增加,儘管 k 設定得較小,DenseBlock的輸入會非常多,不過這是由於特徵重用所造成的,每個層僅有 k 個特徵是自己獨有的。

圖6 DenseBlock中的非線性轉換結構

由於後面層的輸入會非常大,DenseBlock內部可以採用bottleneck層來減少計算量,主要是原有的結構中增加1x1 Conv,如圖7所示,即BN+ReLU+1x1 Conv+BN+ReLU+3x3 Conv,稱為DenseNet-B結構。其中1x1 Conv得到 4k 個特徵圖它起到的作用是降低特徵數量,從而提升計算效率。

圖7 使用bottleneck層的DenseBlock結構

對於Transition層,它主要是連接兩個相鄰的DenseBlock,並且降低特徵圖大小。Transition層包括一個1x1的卷積和2x2的AvgPooling,結構為BN+ReLU+1x1 Conv+2x2 AvgPooling。另外,Transition層可以起到壓縮模型的作用。假定Transition的上接DenseBlock得到的特徵圖channels數為 m ,Transition層可以產生 lfloor	heta m
floor 個特徵(通過卷積層),其中 	heta in (0,1] 是壓縮係數(compression rate)。當 	heta=1 時,特徵個數經過Transition層沒有變化,即無壓縮,而當壓縮係數小於1時,這種結構稱為DenseNet-C,文中使用 	heta=0.5 。對於使用bottleneck層的DenseBlock結構和壓縮係數小於1的Transition組合結構稱為DenseNet-BC。

DenseNet共在三個圖像分類數據集(CIFAR,SVHN和ImageNet)上進行測試。對於前兩個數據集,其輸入圖片大小為 32	imes 32 ,所使用的DenseNet在進入第一個DenseBlock之前,首先進行進行一次3x3卷積(stride=1),卷積核數為16(對於DenseNet-BC為 2k )。DenseNet共包含三個DenseBlock,各個模塊的特徵圖大小分別為 32	imes 3216	imes 168	imes 8 ,每個DenseBlock裡面的層數相同。最後的DenseBlock之後是一個global AvgPooling層,然後送入一個softmax分類器。注意,在DenseNet中,所有的3x3卷積均採用padding=1的方式以保證特徵圖大小維持不變。對於基本的DenseNet,使用如下三種網路配置: {L=40, k=12} , {L=100, k=12}{L=40, k=24} 。而對於DenseNet-BC結構,使用如下三種網路配置: {L=100, k=12} , {L=250, k=24}{L=190, k=40} 。這裡的 L 指的是網路總層數(網路深度),一般情況下,我們只把帶有訓練參數的層算入其中,而像Pooling這樣的無參數層不納入統計中,此外BN層儘管包含參數但是也不單獨統計,而是可以計入它所附屬的卷積層。對於普通的 {L=40, k=12} 網路,除去第一個卷積層、2個Transition中卷積層以及最後的Linear層,共剩餘36層,均分到三個DenseBlock可知每個DenseBlock包含12層。其它的網路配置同樣可以算出各個DenseBlock所含層數。

對於ImageNet數據集,圖片輸入大小為 224	imes 224 ,網路結構採用包含4個DenseBlock的DenseNet-BC,其首先是一個stride=2的7x7卷積層(卷積核數為 2k ),然後是一個stride=2的3x3 MaxPooling層,後面才進入DenseBlock。ImageNet數據集所採用的網路配置如表1所示:

表1 ImageNet數據集上所採用的DenseNet結構

實驗結果及討論

這裡給出DenseNet在CIFAR-100和ImageNet數據集上與ResNet的對比結果,如圖8和9所示。從圖8中可以看到,只有0.8M的DenseNet-100性能已經超越ResNet-1001,並且後者參數大小為10.2M。而從圖9中可以看出,同等參數大小時,DenseNet也優於ResNet網路。其它實驗結果見原論文。

圖8 在CIFAR-100數據集上ResNet vs DenseNet

圖9 在ImageNet數據集上ResNet vs DenseNet

綜合來看,DenseNet的優勢主要體現在以下幾個方面:

  • 由於密集連接方式,DenseNet提升了梯度的反向傳播,使得網路更容易訓練。由於每層可以直達最後的誤差信號,實現了隱式的「deep supervision」;
  • 參數更小且計算更高效,這有點違反直覺,由於DenseNet是通過concat特徵來實現短路連接,實現了特徵重用,並且採用較小的growth rate,每個層所獨有的特徵圖是比較小的;
  • 由於特徵復用,最後的分類器使用了低級特徵。

要注意的一點是,如果實現方式不當的話,DenseNet可能耗費很多GPU顯存,一種高效的實現如圖10所示,更多細節可以見這篇論文Memory-Efficient Implementation of DenseNets。不過我們下面使用Pytorch框架可以自動實現這種優化。

圖10 DenseNet的更高效實現方式

使用Pytorch實現DenseNet

這裡我們採用Pytorch框架來實現DenseNet,目前它已經支持Windows系統。對於DenseNet,Pytorch在torchvision.models模塊里給出了官方實現,這個DenseNet版本是用於ImageNet數據集的DenseNet-BC模型,下面簡單介紹實現過程。

首先實現DenseBlock中的內部結構,這裡是BN+ReLU+1x1 Conv+BN+ReLU+3x3 Conv結構,最後也加入dropout層以用於訓練過程。

class _DenseLayer(nn.Sequential): """Basic unit of DenseBlock (using bottleneck layer) """ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): super(_DenseLayer, self).__init__() self.add_module("norm1", nn.BatchNorm2d(num_input_features)) self.add_module("relu1", nn.ReLU(inplace=True)) self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate, kernel_size=1, stride=1, bias=False)) self.add_module("norm2", nn.BatchNorm2d(bn_size*growth_rate)) self.add_module("relu2", nn.ReLU(inplace=True)) self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)) self.drop_rate = drop_rate def forward(self, x): new_features = super(_DenseLayer, self).forward(x) if self.drop_rate > 0: new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) return torch.cat([x, new_features], 1)

據此,實現DenseBlock模塊,內部是密集連接方式(輸入特徵數線性增長):

class _DenseBlock(nn.Sequential): """DenseBlock""" def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): super(_DenseBlock, self).__init__() for i in range(num_layers): layer = _DenseLayer(num_input_features+i*growth_rate, growth_rate, bn_size, drop_rate) self.add_module("denselayer%d" % (i+1,), layer)

此外,我們實現Transition層,它主要是一個卷積層和一個池化層:

class _Transition(nn.Sequential): """Transition layer between two adjacent DenseBlock""" def __init__(self, num_input_feature, num_output_features): super(_Transition, self).__init__() self.add_module("norm", nn.BatchNorm2d(num_input_feature)) self.add_module("relu", nn.ReLU(inplace=True)) self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features, kernel_size=1, stride=1, bias=False)) self.add_module("pool", nn.AvgPool2d(2, stride=2))

最後我們實現DenseNet網路:

class DenseNet(nn.Module): "DenseNet-BC model" def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000): """ :param growth_rate: (int) number of filters used in DenseLayer, `k` in the paper :param block_config: (list of 4 ints) number of layers in each DenseBlock :param num_init_features: (int) number of filters in the first Conv2d :param bn_size: (int) the factor using in the bottleneck layer :param compression_rate: (float) the compression rate used in Transition Layer :param drop_rate: (float) the drop rate after each DenseLayer :param num_classes: (int) number of classes for classification """ super(DenseNet, self).__init__() # first Conv2d self.features = nn.Sequential(OrderedDict([ ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), ("norm0", nn.BatchNorm2d(num_init_features)), ("relu0", nn.ReLU(inplace=True)), ("pool0", nn.MaxPool2d(3, stride=2, padding=1)) ])) # DenseBlock num_features = num_init_features for i, num_layers in enumerate(block_config): block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate) self.features.add_module("denseblock%d" % (i + 1), block) num_features += num_layers*growth_rate if i != len(block_config) - 1: transition = _Transition(num_features, int(num_features*compression_rate)) self.features.add_module("transition%d" % (i + 1), transition) num_features = int(num_features * compression_rate) # final bn+ReLU self.features.add_module("norm5", nn.BatchNorm2d(num_features)) self.features.add_module("relu5", nn.ReLU(inplace=True)) # classification layer self.classifier = nn.Linear(num_features, num_classes) # params initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1) elif isinstance(m, nn.Linear): nn.init.constant_(m.bias, 0) def forward(self, x): features = self.features(x) out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1) out = self.classifier(out) return out

選擇不同網路參數,就可以實現不同深度的DenseNet,這裡實現DenseNet-121網路,而且Pytorch提供了預訓練好的網路參數:

def densenet121(pretrained=False, **kwargs): """DenseNet121""" model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) if pretrained: # .s are no longer allowed in module names, but pervious _DenseLayer # has keys norm.1, relu.1, conv.1, norm.2, relu.2, conv.2. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( r^(.*denselayerd+.(?:norm|relu|conv)).((?:[12]).(?:weight|bias|running_mean|running_var))$) state_dict = model_zoo.load_url(model_urls[densenet121]) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] model.load_state_dict(state_dict) return model

下面,我們使用預訓練好的網路對圖片進行測試,這裡給出top-5預測值:

densenet = densenet121(pretrained=True)densenet.eval()img = Image.open("./images/cat.jpg")trans_ops = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])images = trans_ops(img).view(-1, 3, 224, 224)outputs = densenet(images)_, predictions = outputs.topk(5, dim=1)labels = list(map(lambda s: s.strip(), open("./data/imagenet/synset_words.txt").readlines()))for idx in predictions.numpy()[0]: print("Predicted labels:", labels[idx])

給出的預測結果為:

Predicted labels: n02123159 tiger catPredicted labels: n02123045 tabby, tabby catPredicted labels: n02127052 lynx, catamountPredicted labels: n02124075 Egyptian catPredicted labels: n02119789 kit fox, Vulpes macrotis

註:完整代碼見xiaohu2015/DeepLearning_tutorials。

小結

這篇文章詳細介紹了DenseNet的設計理念以及網路結構,並給出了如何使用Pytorch來實現。值得注意的是,DenseNet在ResNet基礎上前進了一步,相比ResNet具有一定的優勢,但是其卻並沒有像ResNet那麼出名(吃顯存問題?深度不能太大?)。期待未來有更好的網路模型出現吧!

參考文獻

  1. DenseNet-CVPR-Slides.
  2. Densely Connected Convolutional Networks.

碼字不易,歡迎給個贊!

歡迎交流與轉載,文章會同步發布在公眾號:機器學習演算法全棧工程師(Jeemy110)


推薦閱讀:

TAG:深度學習DeepLearning | 卷積神經網路CNN | 圖像識別 |