標籤:

黑猿大叔-譯文 | TensorFlow實現Batch Normalization

作者:黑猿大叔

原文鏈接:jianshu.com/p/b2d2f3c7b

查看更多的專業文章、課程信息、產品信息,請移步至「人工智慧LeadAI」公眾號,或請移步至全新打造的官網:www.leadai,org.

正文共10537個字,8張圖,預計閱讀時間27分鐘。

原文:Implementing Batch Normalization in Tensorflow(r2rt.com/implementing-b

來源:R2RT

譯者註:本文基於一個最基礎的全連接網路,演示如何構建Batch Norm層、如何訓練以及如何正確進行測試,玩轉這份示例代碼是理解Batch Norm的最好方式。

文中代碼可在jupyter notebook環境下運行:

  • nn_withBN.ipynb(github.com/EthanYuan/Te),
  • nn_withBN_ok.ipynb(github.com/EthanYuan/Te

批標準化,是Sergey Ioffe和Christian Szegedy在2015年3月的論文BN2015(arxiv.org/pdf/1502.0316)中提出的一種簡單、高效的改善神經網路性能的方法。論文BN2015中,Ioffe和Szegedy指出批標準化不僅能應用更高的學習率、具有正則化器的效用,還能將訓練速度提升14倍之多。本文將基於TensorFlow來實現批標準化。

問題的提出

批標準化所要解決的問題是:模型參數在學習階段的變化,會使每個隱藏層輸出的分布也發生改變。這意味著靠後的層要在訓練過程中去適應這些變化。

批標準化的概念

為了解決這個問題,論文BN2015提出了批標準化,即在訓練時作用於每個神經元激活函數(比如sigmoid或者ReLU函數)的輸入,使得基於每個批次的訓練樣本,激活函數的輸入都能滿足均值為0,方差為1的分布。對於激活函數σ(Wx+b),應用批標準化後變為σ(BN(Wx+b)),其中BN代表批標準化。

批標準化公式

對一批數據中的某個數值進行標準化,做法是先減去整批數據的均值,然後除以整批數據的標準差√(σ2+ε)。注意小的常量ε加到方差中是為了防止除零。給定一個數值xi,一個初始的批標準化公式如下:

上面的公式中,批標準化對激活函數的輸入約束為正態分布,但是這樣一來限制了網路層的表達能力。為此,可以通過乘以一個新的比例參數γ,並加上一個新的位移參數β,來讓網路撤銷批標準化變換。γ和β都是可學習參數。

加入γ和β後得到下面最終的批標準化公式:

基於TensorFlow實現批標準化

我們將把批標準化加進一個有兩個隱藏層、每層包含100個神經元的全連接神經網路,並展示與論文BN2015中圖1(b)和(c)類似的實驗結果。

需要注意,此時該網路還不適合在測試期使用。後面的「模型預測」一節中將會闡釋其中的原因,並給出修復版本。

Imports,configimport numpy as np, tensorflow as tf, tqdmfrom tensorflow.examples.tutorials.mnist import input_dataimport matplotlib.pyplot as plt%matplotlib inlinemnist = input_data.read_data_sets(MNIST_data, one_hot=True)

# Generate predetermined random weights so the networks are similarly initializedw1_initial = np.random.normal(size=(784,100)).astype(np.float32)w2_initial = np.random.normal(size=(100,100)).astype(np.float32)w3_initial = np.random.normal(size=(100,10)).astype(np.float32)# Small epsilon value for the BN transformepsilon = 1e-3Building the graph# Placeholdersx = tf.placeholder(tf.float32, shape=[None, 784]) y_ = tf.placeholder(tf.float32, shape=[None, 10])

# Layer 1 without BNw1 = tf.Variable(w1_initial)b1 = tf.Variable(tf.zeros([100]))z1 = tf.matmul(x,w1)+b1l1 = tf.nn.sigmoid(z1)

下面是經過批標準化的第一層:

# Layer 1 with BNw1_BN = tf.Variable(w1_initial)# Note that pre-batch normalization bias is ommitted. The effect of this bias would be# eliminated when subtracting the batch mean. Instead, the role of the bias is performed# by the new beta variable. See Section 3.2 of the BN2015 paper.z1_BN = tf.matmul(x,w1_BN)# Calculate batch mean and variancebatch_mean1, batch_var1 = tf.nn.moments(z1_BN,[0])# Apply the initial batch normalizing transformz1_hat = (z1_BN - batch_mean1) / tf.sqrt(batch_var1 + epsilon)# Create two new parameters, scale and beta (shift)scale1 = tf.Variable(tf.ones([100]))beta1 = tf.Variable(tf.zeros([100]))# Scale and shift to obtain the final output of the batch normalization# this value is fed into the activation function (here a sigmoid)BN1 = scale1 * z1_hat + beta1l1_BN = tf.nn.sigmoid(BN1)

# Layer 2 without BNw2 = tf.Variable(w2_initial)b2 = tf.Variable(tf.zeros([100]))z2 = tf.matmul(l1,w2)+b2l2 = tf.nn.sigmoid(z2)

TensorFlow提供了tf.nn.batch_normalization,我用它定義了下面的第二層。這與上面第一層的代碼行為是一樣的。查閱開源代碼在這裡(github.com/tensorflow/t)。

# Layer 2 with BN, using Tensorflows built-in BN functionw2_BN = tf.Variable(w2_initial)z2_BN = tf.matmul(l1_BN,w2_BN)batch_mean2, batch_var2 = tf.nn.moments(z2_BN,[0])scale2 = tf.Variable(tf.ones([100]))beta2 = tf.Variable(tf.zeros([100]))BN2 = tf.nn.batch_normalization(z2_BN,batch_mean2,batch_var2,beta2,scale2,epsilon)l2_BN = tf.nn.sigmoid(BN2)

# Softmaxw3 = tf.Variable(w3_initial)b3 = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(l2,w3)+b3)w3_BN = tf.Variable(w3_initial)b3_BN = tf.Variable(tf.zeros([10]))y_BN = tf.nn.softmax(tf.matmul(l2_BN,w3_BN)+b3_BN)

# Loss, optimizer and predictions cross_entropy = -tf.reduce_sum(y_*tf.log(y)) cross_entropy_BN = -tf.reduce_sum(y_*tf.log(y_BN)) train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) train_step_BN = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy_BN) correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) correct_prediction_BN = tf.equal(tf.arg_max(y_BN,1),tf.arg_max(y_,1)) accuracy_BN = tf.reduce_mean(tf.cast(correct_prediction_BN,tf.float32))training the networkzs, BNs, acc, acc_BN = [], [], [], []sess = tf.InteractiveSession()sess.run(tf.global_variables_initializer())for i in tqdm.tqdm(range(40000)): batch = mnist.train.next_batch(60) train_step.run(feed_dict={x: batch[0], y_: batch[1]}) train_step_BN.run(feed_dict={x: batch[0], y_: batch[1]}) if i % 50 is 0: res = sess.run([accuracy,accuracy_BN,z2,BN2],feed_dict={x: mnist.test.images, y_: mnist.test.labels}) acc.append(res[0]) acc_BN.append(res[1]) zs.append(np.mean(res[2],axis=0)) # record the mean value of z2 over the entire test set BNs.append(np.mean(res[3],axis=0)) # record the mean value of BN2 over the entire test setzs, BNs, acc, acc_BN = np.array(zs), np.array(BNs), np.array(acc), np.array(acc_BN)

速度和精度的提升

如下所示,應用批標準化後,精度和訓練速度均有可觀的改善。論文BN2015中的圖2顯示,批標準化對於其他網路架構也同樣具有重要作用。

fig, ax = plt.subplots()ax.plot(range(0,len(acc)*50,50),acc, label=Without BN)ax.plot(range(0,len(acc)*50,50),acc_BN, label=With BN)ax.set_xlabel(Training steps)ax.set_ylabel(Accuracy)ax.set_ylim([0.8,1])ax.set_title(Batch Normalization Accuracy)ax.legend(loc=4)plt.show()

激活函數輸入的時間序列圖示

下面是網路第2層的前5個神經元的sigmoid激活函數輸入隨時間的分布情況。批標準化在消除輸入的方差/雜訊上具有顯著的效果。

fig, axes = plt.subplots(5, 2, figsize=(6,12))fig.tight_layout()for i, ax in enumerate(axes): ax[0].set_title("Without BN") ax[1].set_title("With BN") ax[0].plot(zs[:,i]) ax[1].plot(BNs[:,i])

模型預測

使用批標準化模型進行預測時,使用批量樣本自身的均值和方差會適得其反。想像一下單個樣本進入我們訓練的模型會發生什麼?激活函數的輸入將永遠為零(因為我們做的是均值為0的標準化),而且無論輸入是什麼,我們總得到相同的結果。

驗證如下:

predictions = []correct = 0for i in range(100): pred, corr = sess.run([tf.arg_max(y_BN,1), accuracy_BN], feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]}) correct += corr predictions.append(pred[0])print("PREDICTIONS:", predictions)print("ACCURACY:", correct/100)

PREDICTIONS: [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]ACCURACY: 0.02

我們的模型總是輸出8,在MNIST的前100個樣本中8實際上只有2個,所以精度只有2%。

修改模型的測試期行為

為了修復這個問題,我們需要將批均值和批方差替換成全局均值和全局方差。詳見論文BN2015的3.1節。但是這會造成,上面的模型想正確的工作,就只能一次性的將測試集所有樣本進行預測,因為這樣才能算出理想的全局均值和全局方差。

為了使批標準化模型適用於測試,我們需要在測試前的每一步批標準化操作時,都對全局均值和全局方差進行估算,然後才能在做預測時使用這些值。和我們需要批標準化的原因一樣(激活輸入的均值和方差在訓練時會發生變化),估算全局均值和方差最好在其依賴的權重更新完成後,但是同時進行也不算特別糟,因為權重在訓練快結束時就收斂了。

現在,為了基於TensorFlow來實現修復,我們要寫一個batch_norm_wrapper函數,來封裝激活輸入。這個函數會將全局均值和方差作為tf.Variables來存儲,並在做標準化時決定採用批統計還是全局統計。為此,需要一個is_training標記。當is_training == True,我們就要在訓練期學習全局均值和方差。代碼骨架如下:

def batch_norm_wrapper(inputs, is_training): ... pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False) pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False) if is_training: mean, var = tf.nn.moments(inputs,[0]) ... # learn pop_mean and pop_var here ... return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon) else: return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)

注意變數節點聲明了 trainable = False,因為我們將要自行更新它們,而不是讓最優化器來更新。

在訓練期間,一個計算全局均值和方差的方法是指數平滑法,它很簡單,且避免了額外的工作,我們應用如下:

decay = 0.999 # use numbers closer to 1 if you have more datatrain_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

最後,我們需要解決如何調用這些訓練期操作。為了完全可控,你可以把它們加入到一個graph collection(可以看看下面鏈接的TensorFlow源碼),但是簡單起見,我們將會在每次計算批均值和批方差時都調用它們。為此,當is_training為True時,我們把它們作為依賴加入了batch_norm_wrapper的返回值中。最終的batch_norm_wrapper函數如下:

# this is a simpler version of Tensorflows official version. See:# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/layers.py#L102def batch_norm_wrapper(inputs, is_training, decay = 0.999): scale = tf.Variable(tf.ones([inputs.get_shape()[-1]])) beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]])) pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False) pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False) if is_training: batch_mean, batch_var = tf.nn.moments(inputs,[0]) train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay)) train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay)) with tf.control_dependencies([train_mean, train_var]): return tf.nn.batch_normalization(inputs,batch_mean, batch_var, beta, scale, epsilon) else: return tf.nn.batch_normalization(inputs,pop_mean, pop_var, beta, scale, epsilon)

實現正常測試

現在為了證明修復後的代碼可以正常測試,我們使用batch_norm_wrapper重新構建模型。注意,我們不僅要在訓練時做一次構建,在測試時還要重新做一次構建,所以我們寫了一個build_graph函數(實際的模型對象往往也是這麼封裝的):

def build_graph(is_training): # Placeholders x = tf.placeholder(tf.float32, shape=[None, 784]) y_ = tf.placeholder(tf.float32, shape=[None, 10]) # Layer 1 w1 = tf.Variable(w1_initial) z1 = tf.matmul(x,w1) bn1 = batch_norm_wrapper(z1, is_training) l1 = tf.nn.sigmoid(bn1) #Layer 2 w2 = tf.Variable(w2_initial) z2 = tf.matmul(l1,w2) bn2 = batch_norm_wrapper(z2, is_training) l2 = tf.nn.sigmoid(bn2) # Softmax w3 = tf.Variable(w3_initial) b3 = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(l2, w3)) # Loss, Optimizer and Predictions cross_entropy = -tf.reduce_sum(y_*tf.log(y)) train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) return (x, y_), train_step, accuracy, y, tf.train.Saver()

#Build training graph, train and save the trained modelsess.close()tf.reset_default_graph()(x, y_), train_step, accuracy, _, saver = build_graph(is_training=True)acc = []with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in tqdm.tqdm(range(10000)): batch = mnist.train.next_batch(60) train_step.run(feed_dict={x: batch[0], y_: batch[1]}) if i % 50 is 0: res = sess.run([accuracy],feed_dict={x: mnist.test.images, y_: mnist.test.labels}) acc.append(res[0]) saved_model = saver.save(sess, ./temp-bn-save)print("Final accuracy:", acc[-1])Final accuracy: 0.9721

現在應該一切正常了,我們重複上面的實驗:

tf.reset_default_graph()(x, y_), _, accuracy, y, saver = build_graph(is_training=False)predictions = []correct = 0with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.restore(sess, ./temp-bn-save) for i in range(100): pred, corr = sess.run([tf.arg_max(y,1), accuracy], feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]}) correct += corr predictions.append(pred[0])print("PREDICTIONS:", predictions)print("ACCURACY:", correct/100)

PREDICTIONS: [7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5, 4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2, 4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4, 6, 4, 3, 0, 7, 0, 2, 9, 1, 7, 3, 2, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3, 6, 1, 3, 6, 9, 3, 1, 4, 1, 7, 6, 9]ACCURACY: 0.99

推薦閱讀:

如何用 TensorFlow 打造 Not Hotdog 的移動應用
手把手教你用TensorFlow實現看圖說話|教程+代碼
識別漢字圖像的數據集

TAG:TensorFlow |