如何利用TFlearn搭建LeNet-5

Lenet-5是一個經典的CNN卷積神經網路模型,由紐約大學Yan LeCun 提出,用於識別手寫字體。該模型在手寫體識別領域取得極大成功,曾被廣泛應用於美國銀行支票手寫體識別。

TFlearn是 TensorFlow 的高層次 API,目的是便於快速搭建試驗環境,同時保持對 TensorFlow 的完全透明和兼容性。順便說一下TensorFlow。Tensorflow是谷歌推出的機器學習開源系統,支持DNN、CNN、RNN和LSTM演算法,可以用於語音識別或圖像識別等多項機器深度學習領域,這都是目前在Image,Speech和NLP最流行的深度神經網路模型。對於開發者來說,我們接觸到的深度學習的所有模型都能通過Tensorflow來實現。

在深度學習中,手寫字體數據訓練是所有教學資料中被反覆用到的範例。可以通過Tensorflow實現,也可以通過TFlearn實現。今天就來對比一下這兩者實現的代碼,揭示我們在選用工具時需要考慮的問題。

卷積神經網路一個關鍵的理解點在於對 CNNs 維度的理解,理解維度可以幫你在模型大小和模型質量上,做精確的權衡。下圖是LeNet-5的模型圖:

  1. 從一張32x32的手寫字體圖片輸入開始,送入第一層卷積層成為28x28x6的數據矩陣,然後降採樣成14x14x6的數據矩陣;
  2. 進入第二層卷積層成為10x10x16,然後降採樣為5x5x16的矩陣;
  3. 隨後送入第三層尺寸為120的全連接,
  4. 第四層尺寸為84的全連接,和
  5. 第五層尺寸為10的全連接。而第五層就已經是整個網路的輸出層了,這10個輸出就是0-9這10個數字中網路給出的學習結果。

通過調整網路的維度,即中間這些卷積層、全連接層的大小,我們可以取得不同的識別精度。

用Tensorflow來實現這些維度的調整,可以讓開發者做精確地權衡。例如每一層的輸入輸出尺寸我們都可以通過下面的算式計算得到:

TensorFlow 使用如下等式計算 SAME 、VALID PADDING下的維度:

SAME Padding, 輸出的高和寬,計算如下:

out_height = ceil(float(in_height) / float(strides1))

out_width = ceil(float(in_width) / float(strides[2]))

VALID Padding, 輸出的高和寬,計算如下:

out_height = ceil(float(in_height - filter_height + 1) / float(strides1))

out_width = ceil(float(in_width - filter_width + 1) / float(strides[2]))

要實現上述LeNet-5的模型,Tensorflow的代碼為:

def LeNet(x): n # Arguments used for tf.truncated_normal, randomly defines variables for the weights and biases for each layern mu = 0n sigma = 0.1n n # SOLUTION: Layer 1: Convolutional. Input = 32x32x1. Output = 28x28x6.n conv1_W = tf.Variable(tf.truncated_normal(shape=(5, 5, 1, 6), mean = mu, stddev = sigma))n conv1_b = tf.Variable(tf.zeros(6))n conv1 = tf.nn.conv2d(x, conv1_W, strides=[1, 1, 1, 1], padding=VALID) + conv1_bnn # SOLUTION: Activation.n conv1 = tf.nn.relu(conv1)nn # SOLUTION: Pooling. Input = 28x28x6. Output = 14x14x6.n conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding=VALID)nn # SOLUTION: Layer 2: Convolutional. Output = 10x10x16.n conv2_W = tf.Variable(tf.truncated_normal(shape=(5, 5, 6, 16), mean = mu, stddev = sigma))n conv2_b = tf.Variable(tf.zeros(16))n conv2 = tf.nn.conv2d(conv1, conv2_W, strides=[1, 1, 1, 1], padding=VALID) + conv2_bn n # SOLUTION: Activation.n conv2 = tf.nn.relu(conv2)nn # SOLUTION: Pooling. Input = 10x10x16. Output = 5x5x16.n conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding=VALID)nn # SOLUTION: Flatten. Input = 5x5x16. Output = 400.n fc0 = flatten(conv2)n n # SOLUTION: Layer 3: Fully Connected. Input = 400. Output = 120.n fc1_W = tf.Variable(tf.truncated_normal(shape=(400, 120), mean = mu, stddev = sigma))n fc1_b = tf.Variable(tf.zeros(120))n fc1 = tf.matmul(fc0, fc1_W) + fc1_bn n # SOLUTION: Activation.n fc1 = tf.nn.relu(fc1)nn # SOLUTION: Layer 4: Fully Connected. Input = 120. Output = 84.n fc2_W = tf.Variable(tf.truncated_normal(shape=(120, 84), mean = mu, stddev = sigma))n fc2_b = tf.Variable(tf.zeros(84))n fc2 = tf.matmul(fc1, fc2_W) + fc2_bn n # SOLUTION: Activation.n fc2 = tf.nn.relu(fc2)nn # SOLUTION: Layer 5: Fully Connected. Input = 84. Output = 10.n fc3_W = tf.Variable(tf.truncated_normal(shape=(84, 10), mean = mu, stddev = sigma))n fc3_b = tf.Variable(tf.zeros(10))n logits = tf.matmul(fc2, fc3_W) + fc3_bn n return logitsn

而我們都通常希望對一個問題快速選定工具,然後上手測試。實現同樣的模型,TFlearn就省去了很多行代碼,僅通過以下幾行就完全達到了上面的相同的效果。

conv1 = tflearn.conv_2d(input_data, 6, 5, activation=relu, name=conv1)nconv1 = max_pool_2d(conv1, 2, 2) nn nconv2 = tflearn.conv_2d(conv1, 16, 5, activation=relu, name=conv1)nconv2 = max_pool_2d(conv2, 2, 2) nnfc0 = flatten(conv2)n nfc1 = fully_connected(fc0, 120, activation=relu) nfc2 = fully_connected(fc1, 84, activation=relu) nfc3 = fully_connected(fc2, 10)n

對比兩者很容易發現TFlearn省去了weights和bias變數的初始化定義,也就省去了每一層輸入輸出的維度計算,只要關注輸出的尺寸就行了。通常開發者調整網路的性能通過調整輸出尺寸就行了。輸出的尺寸意味著神經元的數量,神經元越多運算的參數就越多,神經網路學習到的東西也就越多。

那麼硬幣的另一面,「省去了weights和bias的變數初始化定義,我們失去了什麼靈活性呢?」

合適的初始權重可以使神經網路更接近最優解,也能夠使神經網路更快地得到最優結果。簡單的1和0初始化權重會使我們的網路性能非常糟糕,迭代很多次也難以提高精度。常用的權重初始化有均勻分布、正態分布和截斷正態分布。而且,這其中如果隨機分布範圍選得不合適也會造成網路性能無法達到最優。

所以在大型網路機器中,我們必須使用Tensorflow才行,而小數據量、考察一下數據趨勢等小量學習任務TFlearn就能勝任了。


推薦閱讀:

對深度學習的理解達到什麼水平才能應聘大型互聯網公司的機器學習相關崗位?
PyTorch到底好用在哪裡?
為什麼說雲計算、大數據、機器學習、深度學習被並稱為當今計算機界四大俗?
顯卡、顯卡驅動、cuda 之間的關係是什麼?

TAG:TensorFlow | 深度学习DeepLearning |