TensorFlow 教程 #04 - 保存 & 恢復
本篇主要介紹如何保存和恢復神經網路變數以及Early-Stopping優化策略。
其中有大段之前教程的文字及代碼,如果看過的朋友可以快速翻到下文Saver相關的部分。
01 - 簡單線性模型 | 02 - 卷積神經網路 | 03 - PrettyTensor
by Magnus Erik Hvass Pedersen / GitHub / Videos on YouTube
中文翻譯 thrillerist/Github
如有轉載,請附上本文鏈接。
_______________________________________________________________________________________
簡介
這篇教程展示了如何保存以及恢復神經網路中的變數。在優化的過程中,當驗證集上分類準確率提高時,保存神經網路的變數。如果經過1000次迭代還不能提升性能時,就終止優化。然後我們重新載入在驗證集上表現最好的變數。
這種策略稱為Early-Stopping。它用來避免神經網路的過擬合。(過擬合)會在神經網路訓練時間太長時出現,此時神經網路開始學習訓練集中的雜訊,將導致它誤分類新的圖像。
這篇教程主要是用神經網路來識別MNIST數據集中的手寫數字,過擬合在這裡並不是什麼大問題。但本教程展示了Early Stopping的思想。
本文基於上一篇教程,你需要了解基本的TensorFlow和附加包Pretty Tensor。其中大量代碼和文字與之前教程相似,如果你已經看過就可以快速地瀏覽本文。
流程圖
下面的圖表直接顯示了之後實現的卷積神經網路中數據的傳遞。網路有兩個卷積層和兩個全連接層,最後一層是用來給輸入圖像分類的。關於網路和卷積的更多細節描述見教程#02 。
from IPython.display import ImagenImage(images/02_network_flowchart.png)n
導入
%matplotlib inlinenimport matplotlib.pyplot as pltnimport tensorflow as tfnimport numpy as npnfrom sklearn.metrics import confusion_matrixnimport timenfrom datetime import timedeltanimport mathnimport osnn# Use PrettyTensor to simplify Neural Network construction.nimport prettytensor as ptn
使用Python3.5.2(Anaconda)開發,TensorFlow版本是:
tf.__version__n
0.12.0-rc0
PrettyTensor 版本:
pt.__version__n
0.7.1
載入數據
MNIST數據集大約12MB,如果沒在給定路徑中找到就會自動下載。
from tensorflow.examples.tutorials.mnist import input_datandata = input_data.read_data_sets(data/MNIST/, one_hot=True)n
Extracting data/MNIST/train-images-idx3-ubyte.gz
Extracting data/MNIST/train-labels-idx1-ubyte.gzExtracting data/MNIST/t10k-images-idx3-ubyte.gzExtracting data/MNIST/t10k-labels-idx1-ubyte.gz
現在已經載入了MNIST數據集,它由70,000張圖像和對應的標籤(比如圖像的類別)組成。數據集分成三份互相獨立的子集。我們在教程中只用訓練集和測試集。
print("Size of:")nprint("- Training-set:tt{}".format(len(data.train.labels)))nprint("- Test-set:tt{}".format(len(data.test.labels)))nprint("- Validation-set:t{}".format(len(data.validation.labels)))n
Size of:
- Training-set:tt55000- Test-set:tt10000- Validation-set:t5000
類型標籤使用One-Hot編碼,這意外每個標籤是長為10的向量,除了一個元素之外,其他的都為零。這個元素的索引就是類別的數字,即相應圖片中畫的數字。我們也需要測試數據集類別數字的整型值,用下面的方法來計算。
data.test.cls = np.argmax(data.test.labels, axis=1)ndata.validation.cls = np.argmax(data.validation.labels, axis=1)n
數據維度
在下面的源碼中,有很多地方用到了數據維度。它們只在一個地方定義,因此我們可以在代碼中使用這些數字而不是直接寫數字。
# We know that MNIST images are 28 pixels in each dimension.nimg_size = 28nn# Images are stored in one-dimensional arrays of this length.nimg_size_flat = img_size * img_sizenn# Tuple with height and width of images used to reshape arrays.nimg_shape = (img_size, img_size)nn# Number of colour channels for the images: 1 channel for gray-scale.nnum_channels = 1nn# Number of classes, one class for each of 10 digits.nnum_classes = 10n
用來繪製圖片的幫助函數
這個函數用來在3x3的柵格中畫9張圖像,然後在每張圖像下面寫出真實類別和預測類別。
def plot_images(images, cls_true, cls_pred=None):n assert len(images) == len(cls_true) == 9n n # Create figure with 3x3 sub-plots.n fig, axes = plt.subplots(3, 3)n fig.subplots_adjust(hspace=0.3, wspace=0.3)nn for i, ax in enumerate(axes.flat):n # Plot image.n ax.imshow(images[i].reshape(img_shape), cmap=binary)nn # Show true and predicted classes.n if cls_pred is None:n xlabel = "True: {0}".format(cls_true[i])n else:n xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])nn # Show the classes as the label on the x-axis.n ax.set_xlabel(xlabel)n n # Remove ticks from the plot.n ax.set_xticks([])n ax.set_yticks([])n n # Ensure the plot is shown correctly with multiple plotsn # in a single Notebook cell.n plt.show()n
繪製幾張圖像來看看數據是否正確
# Get the first images from the test-set.nimages = data.test.images[0:9]nn# Get the true classes for those images.ncls_true = data.test.cls[0:9]nn# Plot the images and labels using our helper-function above.nplot_images(images=images, cls_true=cls_true)n
TensorFlow圖
TensorFlow的全部目的就是使用一個稱之為計算圖(computational graph)的東西,它會比直接在Python中進行相同計算量要高效得多。TensorFlow比Numpy更高效,因為TensorFlow了解整個需要運行的計算圖,然而Numpy只知道某個時間點上唯一的數學運算。
TensorFlow也能夠自動地計算需要優化的變數的梯度,使得模型有更好的表現。這是由於圖是簡單數學表達式的結合,因此整個圖的梯度可以用鏈式法則推導出來。
TensorFlow還能利用多核CPU和GPU,Google也為TensorFlow製造了稱為TPUs(Tensor Processing Units)的特殊晶元,它比GPU更快。
A TensorFlow graph consists of the following parts which will be detailed below:
- 佔位符變數(Placeholder)用來改變圖的輸入。
- 模型變數(Model)將會被優化,使得模型表現得更好。
- 模型本質上就是一些數學函數,它根據Placeholder和模型的輸入變數來計算一些輸出。
- 一個cost度量用來指導變數的優化。
- 一個優化策略會更新模型的變數。
另外,TensorFlow圖也包含了一些調試狀態,比如用TensorBoard列印log數據,本教程不涉及這些。
佔位符 (Placeholder)變數
Placeholder是作為圖的輸入,我們每次運行圖的時候都可能改變它們。將這個過程稱為feeding placeholder變數,後面將會描述這個。
首先我們為輸入圖像定義placeholder變數。這讓我們可以改變輸入到TensorFlow圖中的圖像。這也是一個張量(tensor),代表一個多維向量或矩陣。數據類型設置為float32,形狀設為[None, img_size_flat],None代表tensor可能保存著任意數量的圖像,每張圖象是一個長度為img_size_flat的向量。
x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name=x)n
卷積層希望x被編碼為4維張量,因此我們需要將它的形狀轉換至[num_images, img_height, img_width, num_channels]。注意img_height == img_width == img_size,如果第一維的大小設為-1, num_images的大小也會被自動推導出來。轉換運算如下:
x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])n
接下來我們為輸入變數x中的圖像所對應的真實標籤定義placeholder變數。變數的形狀是[None, num_classes],這代表著它保存了任意數量的標籤,每個標籤是長度為num_classes的向量,本例中長度為10。
y_true = tf.placeholder(tf.float32, shape=[None, 10], name=y_true)n
我們也可以為class-number提供一個placeholder,但這裡用argmax來計算它。這裡只是TensorFlow中的一些操作,沒有執行什麼運算。
y_true_cls = tf.argmax(y_true, dimension=1)n
神經網路
這一節用PrettyTensor實現卷積神經網路,這要比直接在TensorFlow中實現來得簡單,詳見教程 #03。
基本思想就是用一個Pretty Tensor object封裝輸入張量x_image,它有一個添加新卷積層的幫助函數,以此來創建整個神經網路。Pretty Tensor負責變數分配等等。
x_pretty = pt.wrap(x_image)n
現在我們已經將輸入圖像裝到一個PrettyTensor的object中,再用幾行代碼就可以添加卷積層和全連接層。
注意,在with代碼塊中,pt.defaults_scope(activation_fn=tf.nn.relu) 把 activation_fn=tf.nn.relu當作每個的層參數,因此這些層都用到了 Rectified Linear Units (ReLU) 。defaults_scope使我們能更方便地修改所有層的參數。
with pt.defaults_scope(activation_fn=tf.nn.relu):n y_pred, loss = x_pretty.n conv2d(kernel=5, depth=16, name=layer_conv1).n max_pool(kernel=2, stride=2).n conv2d(kernel=5, depth=36, name=layer_conv2).n max_pool(kernel=2, stride=2).n flatten().n fully_connected(size=128, name=layer_fc1).n softmax_classifier(num_classes=num_classes, labels=y_true)n
獲取權重
下面,我們要繪製神經網路的權重。當使用Pretty Tensor來創建網路時,層的所有變數都是由Pretty Tensoe間接創建的。因此我們要從TensorFlow中獲取變數。
我們用layer_conv1 和 layer_conv2代表兩個卷積層。這也叫變數作用域(不要與上面描述的defaults_scope混淆了)。PrettyTensor會自動給它為每個層創建的變數命名,因此我們可以通過層的作用域名稱和變數名來取得某一層的權重。
函數實現有點笨拙,因為我們不得不用TensorFlow函數get_variable(),它是設計給其他用途的,創建新的變數或重用現有變數。創建下面的幫助函數很簡單。
def get_weights_variable(layer_name):n # Retrieve an existing variable named weights in the scopen # with the given layer_name.n # This is awkward because the TensorFlow function wasn # really intended for another purpose.nn with tf.variable_scope(layer_name, reuse=True):n variable = tf.get_variable(weights)nn return variablen
藉助這個幫助函數我們可以獲取變數。這些是TensorFlow的objects。你需要類似的操作來獲取變數的內容: contents = session.run(weights_conv1) ,下面會提到這個。
weights_conv1 = get_weights_variable(layer_name=layer_conv1)nweights_conv2 = get_weights_variable(layer_name=layer_conv2)n
優化方法
PrettyTensor給我們提供了預測類型標籤(y_pred)以及一個需要最小化的損失度量,用來提升神經網路分類圖片的能力。
PrettyTensor的文檔並沒有說明它的損失度量是用cross-entropy還是其他的。但現在我們用AdamOptimizer來最小化損失。
優化過程並不是在這裡執行。實際上,還沒計算任何東西,我們只是往TensorFlow圖中添加了優化器,以便後續操作。
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)n
性能度量
我們需要另外一些性能度量,來向用戶展示這個過程。
首先我們從神經網路輸出的y_pred中計算出預測的類別,它是一個包含10個元素的向量。類別數字是最大元素的索引。
y_pred_cls = tf.argmax(y_pred, dimension=1)n
然後創建一個布爾向量,用來告訴我們每張圖片的真實類別是否與預測類別相同。
correct_prediction = tf.equal(y_pred_cls, y_true_cls)n
上面的計算先將布爾值向量類型轉換成浮點型向量,這樣子False就變成0,True變成1,然後計算這些值的平均數,以此來計算分類的準確度。
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))n
Saver
為了保存神經網路的變數,我們創建一個稱為Saver-object的對象,它用來保存及恢復TensorFlow圖的所有變數。在這裡並未保存什麼東西,(保存操作)在後面的optimize()函數中完成。
saver = tf.train.Saver()n
由於(保存操作)常間隔著寫在(代碼)中,因此保存的文件通常稱為checkpoints。
這是用來保存或恢複數據的文件夾。
save_dir = checkpoints/n
如果文件夾不存在則創建。
if not os.path.exists(save_dir):n os.makedirs(save_dir)n
這是保存checkpoint文件的路徑。
save_path = os.path.join(save_dir, best_validation)n
運行TensorFlow
創建TensorFlow會話(session)
一旦創建了TensorFlow圖,我們需要創建一個TensorFlow會話,用來運行圖。
session = tf.Session()n
初始化變數
變數weights和biases在優化之前需要先進行初始化。我們寫一個簡單的封裝函數,後面會再次調用。
def init_variables():n session.run(tf.global_variables_initializer())n
運行函數來初始化變數。
init_variables()n
用來優化迭代的幫助函數
在訓練集中有50,000張圖。用這些圖像計算模型的梯度會花很多時間。因此我們利用隨機梯度下降的方法,它在優化器的每次迭代里只用到了一小部分的圖像。
如果內存耗盡導致電腦死機或變得很慢,你應該試著減少這些數量,但同時可能還需要更優化的迭代。
train_batch_size = 64n
每迭代100次下面的優化函數,會計算一次驗證集上的分類準確率。如果過了1000次迭代驗證準確率還是沒有提升,就停止優化。我們需要一些變數來跟蹤這個過程。
# Best validation accuracy seen so far.nbest_validation_accuracy = 0.0nn# Iteration-number for last improvement to validation accuracy.nlast_improvement = 0nn# Stop optimization if no improvement found in this many iterations.nrequire_improvement = 1000n
函數用來執行一定數量的優化迭代,以此來逐漸改善網路層的變數。在每次迭代中,會從訓練集中選擇新的一批數據,然後TensorFlow在這些訓練樣本上執行優化。每100次迭代會列印出(信息),同時計算驗證準確率,如果效果有提升的話會將它保存至文件。
# Counter for total number of iterations performed so far.ntotal_iterations = 0nndef optimize(num_iterations):n # Ensure we update the global variables rather than local copies.n global total_iterationsn global best_validation_accuracyn global last_improvementnn # Start-time used for printing time-usage below.n start_time = time.time()nn for i in range(num_iterations):nn # Increase the total number of iterations performed.n # It is easier to update it in each iteration becausen # we need this number several times in the following.n total_iterations += 1nn # Get a batch of training examples.n # x_batch now holds a batch of images andn # y_true_batch are the true labels for those images.n x_batch, y_true_batch = data.train.next_batch(train_batch_size)nn # Put the batch into a dict with the proper namesn # for placeholder variables in the TensorFlow graph.n feed_dict_train = {x: x_batch,n y_true: y_true_batch}nn # Run the optimizer using this batch of training data.n # TensorFlow assigns the variables in feed_dict_trainn # to the placeholder variables and then runs the optimizer.n session.run(optimizer, feed_dict=feed_dict_train)nn # Print status every 100 iterations and after last iteration.n if (total_iterations % 100 == 0) or (i == (num_iterations - 1)):nn # Calculate the accuracy on the training-batch.n acc_train = session.run(accuracy, feed_dict=feed_dict_train)nn # Calculate the accuracy on the validation-set.n # The function returns 2 values but we only need the first.n acc_validation, _ = validation_accuracy()nn # If validation accuracy is an improvement over best-known.n if acc_validation > best_validation_accuracy:n # Update the best-known validation accuracy.n best_validation_accuracy = acc_validationn n # Set the iteration for the last improvement to current.n last_improvement = total_iterationsnn # Save all variables of the TensorFlow graph to file.n saver.save(sess=session, save_path=save_path)nn # A string to be printed below, shows improvement found.n improved_str = *n else:n # An empty string to be printed below.n # Shows that no improvement was found.n improved_str = n n # Status-message for printing.n msg = "Iter: {0:>6}, Train-Batch Accuracy: {1:>6.1%}, Validation Acc: {2:>6.1%} {3}"nn # Print it.n print(msg.format(i + 1, acc_train, acc_validation, improved_str))nn # If no improvement found in the required number of iterations.n if total_iterations - last_improvement > require_improvement:n print("No improvement found in a while, stopping optimization.")nn # Break out from the for-loop.n breaknn # Ending time.n end_time = time.time()nn # Difference between start and end-times.n time_dif = end_time - start_timenn # Print the time-usage.n print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))n
用來繪製錯誤樣本的幫助函數
函數用來繪製測試集中被誤分類的樣本。
def plot_example_errors(cls_pred, correct):n # This function is called from print_test_accuracy() below.nn # cls_pred is an array of the predicted class-number forn # all images in the test-set.nn # correct is a boolean array whether the predicted classn # is equal to the true class for each image in the test-set.nn # Negate the boolean array.n incorrect = (correct == False)n n # Get the images from the test-set that have beenn # incorrectly classified.n images = data.test.images[incorrect]n n # Get the predicted classes for those images.n cls_pred = cls_pred[incorrect]nn # Get the true classes for those images.n cls_true = data.test.cls[incorrect]n n # Plot the first 9 images.n plot_images(images=images[0:9],n cls_true=cls_true[0:9],n cls_pred=cls_pred[0:9])n
繪製混淆(confusion)矩陣的幫助函數
def plot_confusion_matrix(cls_pred):n # This is called from print_test_accuracy() below.nn # cls_pred is an array of the predicted class-number forn # all images in the test-set.nn # Get the true classifications for the test-set.n cls_true = data.test.clsn n # Get the confusion matrix using sklearn.n cm = confusion_matrix(y_true=cls_true,n y_pred=cls_pred)nn # Print the confusion matrix as text.n print(cm)nn # Plot the confusion matrix as an image.n plt.matshow(cm)nn # Make various adjustments to the plot.n plt.colorbar()n tick_marks = np.arange(num_classes)n plt.xticks(tick_marks, range(num_classes))n plt.yticks(tick_marks, range(num_classes))n plt.xlabel(Predicted)n plt.ylabel(True)nn # Ensure the plot is shown correctly with multiple plotsn # in a single Notebook cell.n plt.show()n
計算分類的幫助函數
這個函數用來計算圖像的預測類別,同時返回一個代表每張圖像分類是否正確的布爾數組。
由於計算可能會耗費太多內存,就分批處理。如果你的電腦死機了,試著降低batch-size。
# Split the data-set in batches of this size to limit RAM usage.nbatch_size = 256nndef predict_cls(images, labels, cls_true):n # Number of images.n num_images = len(images)nn # Allocate an array for the predicted classes whichn # will be calculated in batches and filled into this array.n cls_pred = np.zeros(shape=num_images, dtype=np.int)nn # Now calculate the predicted classes for the batches.n # We will just iterate through all the batches.n # There might be a more clever and Pythonic way of doing this.nn # The starting index for the next batch is denoted i.n i = 0nn while i < num_images:n # The ending index for the next batch is denoted j.n j = min(i + batch_size, num_images)nn # Create a feed-dict with the images and labelsn # between index i and j.n feed_dict = {x: images[i:j, :],n y_true: labels[i:j, :]}nn # Calculate the predicted class using TensorFlow.n cls_pred[i:j] = session.run(y_pred_cls, feed_dict=feed_dict)nn # Set the start-index for the next batch to then # end-index of the current batch.n i = jnn # Create a boolean array whether each image is correctly classified.n correct = (cls_true == cls_pred)nn return correct, cls_predn
計算測試集上的預測類別。
def predict_cls_test():n return predict_cls(images = data.test.images,n labels = data.test.labels,n cls_true = data.test.cls)n
計算驗證集上的預測類別。
def predict_cls_validation():n return predict_cls(images = data.validation.images,n labels = data.validation.labels,n cls_true = data.validation.cls)n
分類準確率的幫助函數
這個函數計算了給定布爾數組的分類準確率,布爾數組表示每張圖像是否被正確分類。比如, cls_accuracy([True, True, False, False, False]) = 2/5 = 0.4。
def cls_accuracy(correct):n # Calculate the number of correctly classified images.n # When summing a boolean array, False means 0 and True means 1.n correct_sum = correct.sum()nn # Classification accuracy is the number of correctly classifiedn # images divided by the total number of images in the test-set.n acc = float(correct_sum) / len(correct)nn return acc, correct_sumn
計算驗證集上的分類準確率。
def validation_accuracy():n # Get the array of booleans whether the classifications are correctn # for the validation-set.n # The function returns two values but we only need the first.n correct, _ = predict_cls_validation()n n # Calculate the classification accuracy and return it.n return cls_accuracy(correct)n
展示性能的幫助函數
函數用來列印測試集上的分類準確率。
為測試集上的所有圖片計算分類會花費一段時間,因此我們直接從這個函數里調用上面的函數,這樣就不用每個函數都重新計算分類。
def print_test_accuracy(show_example_errors=False,n show_confusion_matrix=False):nn # For all the images in the test-set,n # calculate the predicted classes and whether they are correct.n correct, cls_pred = predict_cls_test()nn # Classification accuracy and the number of correct classifications.n acc, num_correct = cls_accuracy(correct)n n # Number of images being classified.n num_images = len(correct)nn # Print the accuracy.n msg = "Accuracy on Test-Set: {0:.1%} ({1} / {2})"n print(msg.format(acc, num_correct, num_images))nn # Plot some examples of mis-classifications, if desired.n if show_example_errors:n print("Example errors:")n plot_example_errors(cls_pred=cls_pred, correct=correct)nn # Plot the confusion matrix, if desired.n if show_confusion_matrix:n print("Confusion Matrix:")n plot_confusion_matrix(cls_pred=cls_pred)n
繪製卷積權重的幫助函數
def plot_conv_weights(weights, input_channel=0):n # Assume weights are TensorFlow ops for 4-dim variablesn # e.g. weights_conv1 or weights_conv2.nn # Retrieve the values of the weight-variables from TensorFlow.n # A feed-dict is not necessary because nothing is calculated.n w = session.run(weights)nn # Print mean and standard deviation.n print("Mean: {0:.5f}, Stdev: {1:.5f}".format(w.mean(), w.std()))n n # Get the lowest and highest values for the weights.n # This is used to correct the colour intensity acrossn # the images so they can be compared with each other.n w_min = np.min(w)n w_max = np.max(w)nn # Number of filters used in the conv. layer.n num_filters = w.shape[3]nn # Number of grids to plot.n # Rounded-up, square-root of the number of filters.n num_grids = math.ceil(math.sqrt(num_filters))n n # Create figure with a grid of sub-plots.n fig, axes = plt.subplots(num_grids, num_grids)nn # Plot all the filter-weights.n for i, ax in enumerate(axes.flat):n # Only plot the valid filter-weights.n if i<num_filters:n # Get the weights for the ith filter of the input channel.n # The format of this 4-dim tensor is determined by then # TensorFlow API. See Tutorial #02 for more details.n img = w[:, :, input_channel, i]nn # Plot image.n ax.imshow(img, vmin=w_min, vmax=w_max,n interpolation=nearest, cmap=seismic)n n # Remove ticks from the plot.n ax.set_xticks([])n ax.set_yticks([])n n # Ensure the plot is shown correctly with multiple plotsn # in a single Notebook cell.n plt.show()n
優化之前的性能
測試集上的準確度很低,這是由於模型只做了初始化,並沒做任何優化,所以它只是對圖像做隨機分類。
print_test_accuracy()n
Accuracy on Test-Set: 8.5% (849 / 10000)
卷積權重是隨機的,但也很難把它與下面優化過的權重區分開來。這裡也展示了平均值和標準差,因此我們可以看看是否有差別。
plot_conv_weights(weights=weights_conv1)n
Mean: 0.00880, Stdev: 0.28635
10,000次優化迭代後的性能
現在我們進行了10,000次優化迭代,並且,當經過1000次迭代驗證集上的性能卻沒有提升時就停止優化。
星號 * 代表驗證集上的分類準確度有提升。
optimize(num_iterations=10000)n
Iter: 100, Train-Batch Accuracy: 84.4%, Validation Acc: 85.2% *
Iter: 200, Train-Batch Accuracy: 92.2%, Validation Acc: 91.5% *Iter: 300, Train-Batch Accuracy: 95.3%, Validation Acc: 93.7% *Iter: 400, Train-Batch Accuracy: 92.2%, Validation Acc: 94.3% *Iter: 500, Train-Batch Accuracy: 98.4%, Validation Acc: 94.7% *Iter: 600, Train-Batch Accuracy: 93.8%, Validation Acc: 94.7%
Iter: 700, Train-Batch Accuracy: 98.4%, Validation Acc: 95.6% *Iter: 800, Train-Batch Accuracy: 100.0%, Validation Acc: 96.3% *Iter: 900, Train-Batch Accuracy: 98.4%, Validation Acc: 96.4% *Iter: 1000, Train-Batch Accuracy: 100.0%, Validation Acc: 96.9% *Iter: 1100, Train-Batch Accuracy: 96.9%, Validation Acc: 97.0% *Iter: 1200, Train-Batch Accuracy: 93.8%, Validation Acc: 97.0% *Iter: 1300, Train-Batch Accuracy: 92.2%, Validation Acc: 97.2% *Iter: 1400, Train-Batch Accuracy: 100.0%, Validation Acc: 97.3% *Iter: 1500, Train-Batch Accuracy: 96.9%, Validation Acc: 97.4% *Iter: 1600, Train-Batch Accuracy: 100.0%, Validation Acc: 97.7% *Iter: 1700, Train-Batch Accuracy: 100.0%, Validation Acc: 97.8% *Iter: 1800, Train-Batch Accuracy: 98.4%, Validation Acc: 97.7% Iter: 1900, Train-Batch Accuracy: 98.4%, Validation Acc: 98.1% *Iter: 2000, Train-Batch Accuracy: 95.3%, Validation Acc: 98.0% Iter: 2100, Train-Batch Accuracy: 98.4%, Validation Acc: 97.9% Iter: 2200, Train-Batch Accuracy: 100.0%, Validation Acc: 98.0% Iter: 2300, Train-Batch Accuracy: 96.9%, Validation Acc: 98.1% Iter: 2400, Train-Batch Accuracy: 93.8%, Validation Acc: 98.1% Iter: 2500, Train-Batch Accuracy: 98.4%, Validation Acc: 98.2% *Iter: 2600, Train-Batch Accuracy: 98.4%, Validation Acc: 98.0% Iter: 2700, Train-Batch Accuracy: 98.4%, Validation Acc: 98.0% Iter: 2800, Train-Batch Accuracy: 96.9%, Validation Acc: 98.1% Iter: 2900, Train-Batch Accuracy: 96.9%, Validation Acc: 98.2% Iter: 3000, Train-Batch Accuracy: 98.4%, Validation Acc: 98.2% Iter: 3100, Train-Batch Accuracy: 100.0%, Validation Acc: 98.1% Iter: 3200, Train-Batch Accuracy: 100.0%, Validation Acc: 98.3% *Iter: 3300, Train-Batch Accuracy: 98.4%, Validation Acc: 98.4% *Iter: 3400, Train-Batch Accuracy: 95.3%, Validation Acc: 98.0% Iter: 3500, Train-Batch Accuracy: 98.4%, Validation Acc: 98.3% Iter: 3600, Train-Batch Accuracy: 100.0%, Validation Acc: 98.5% *Iter: 3700, Train-Batch Accuracy: 98.4%, Validation Acc: 98.3% Iter: 3800, Train-Batch Accuracy: 96.9%, Validation Acc: 98.1% Iter: 3900, Train-Batch Accuracy: 96.9%, Validation Acc: 98.5% Iter: 4000, Train-Batch Accuracy: 100.0%, Validation Acc: 98.4% Iter: 4100, Train-Batch Accuracy: 100.0%, Validation Acc: 98.5% Iter: 4200, Train-Batch Accuracy: 100.0%, Validation Acc: 98.3% Iter: 4300, Train-Batch Accuracy: 100.0%, Validation Acc: 98.6% *Iter: 4400, Train-Batch Accuracy: 96.9%, Validation Acc: 98.4% Iter: 4500, Train-Batch Accuracy: 98.4%, Validation Acc: 98.5% Iter: 4600, Train-Batch Accuracy: 98.4%, Validation Acc: 98.5% Iter: 4700, Train-Batch Accuracy: 98.4%, Validation Acc: 98.4% Iter: 4800, Train-Batch Accuracy: 100.0%, Validation Acc: 98.8% *Iter: 4900, Train-Batch Accuracy: 100.0%, Validation Acc: 98.8% Iter: 5000, Train-Batch Accuracy: 98.4%, Validation Acc: 98.6% Iter: 5100, Train-Batch Accuracy: 98.4%, Validation Acc: 98.6% Iter: 5200, Train-Batch Accuracy: 100.0%, Validation Acc: 98.6% Iter: 5300, Train-Batch Accuracy: 96.9%, Validation Acc: 98.5% Iter: 5400, Train-Batch Accuracy: 98.4%, Validation Acc: 98.7% Iter: 5500, Train-Batch Accuracy: 98.4%, Validation Acc: 98.6% Iter: 5600, Train-Batch Accuracy: 100.0%, Validation Acc: 98.4% Iter: 5700, Train-Batch Accuracy: 100.0%, Validation Acc: 98.6% Iter: 5800, Train-Batch Accuracy: 100.0%, Validation Acc: 98.7% No improvement found in a while, stopping optimization.Time usage: 0:00:28
print_test_accuracy(show_example_errors=True,n show_confusion_matrix=True)n
Accuracy on Test-Set: 98.4% (9842 / 10000)
Example errors:Confusion Matrix:[[ 974 0 0 0 0 1 2 0 2 1] [ 0 1127 2 2 0 0 1 0 3 0] [ 4 4 1012 4 1 0 0 3 4 0] [ 0 0 1 1005 0 2 0 0 2 0] [ 1 0 1 0 961 0 2 0 3 14] [ 2 0 1 6 0 880 1 0 1 1] [ 4 2 0 1 3 4 942 0 2 0] [ 1 1 8 6 1 0 0 994 1 16] [ 6 0 1 4 1 1 1 2 952 6] [ 3 3 0 3 2 2 0 0 1 995]]
現在卷積權重是經過優化的。將這些與上面的隨機權重進行對比。它們看起來基本相同。實際上,一開始我以為程序有bug,因為優化前後的權重看起來差不多。
但保存圖像,並排著比較它們(你可以右鍵保存)。你會發現兩者細微的不同。
平均值和標準差也有一點變化,因此優化過的權重肯定是不一樣的。
plot_conv_weights(weights=weights_conv1)n
Mean: 0.02895, Stdev: 0.29949
再次初始化變數
再一次用隨機值來初始化所有神經網路變數。
init_variables()n
這意味著神經網路又是完全隨機地對圖片進行分類,由於只是隨機的猜測所以分類準確率很低。
print_test_accuracy()n
Accuracy on Test-Set: 13.4% (1341 / 10000)
卷積權重看起來應該與上面的不同。
plot_conv_weights(weights=weights_conv1)n
Mean: -0.01086, Stdev: 0.28023
恢復最好的變數
重新載入在優化過程中保存到文件的所有變數。
saver.restore(sess=session, save_path=save_path)n
使用之前保存的那些變數,分類準確率又提高了。
注意,準確率與之前相比可能會有細微的上升或下降,這是由於文件里的變數是用來最大化驗證集上的分類準確率,但在保存文件之後,又進行了1000次的優化迭代,因此這是兩組有輕微不同的變數的結果。有時這會導致測試集上更好或更差的表現。
print_test_accuracy(show_example_errors=True,n show_confusion_matrix=True)n
Accuracy on Test-Set: 98.3% (9826 / 10000)
Example errors:Confusion Matrix:[[ 973 0 0 0 0 0 2 0 3 2] [ 0 1124 2 2 0 0 3 0 4 0] [ 2 1 1027 0 0 0 0 1 1 0] [ 0 0 1 1005 0 2 0 0 2 0] [ 0 0 3 0 968 0 1 0 3 7] [ 2 0 1 9 0 871 3 0 3 3] [ 4 2 1 0 3 3 939 0 6 0] [ 1 3 19 11 2 0 0 972 2 18] [ 6 0 3 5 1 0 1 2 951 5] [ 3 3 0 1 4 1 0 0 1 996]]
卷積權重也與之前顯示的圖幾乎相同,同樣,由於多做了1000次優化迭代,二者並非完全一樣。
plot_conv_weights(weights=weights_conv1)n
Mean: 0.02792, Stdev: 0.29822
關閉TensorFlow會話
現在我們已經用TensorFlow完成了任務,關閉session,釋放資源。
# This has been commented out in case you want to modify and experimentn# with the Notebook without having to restart it.n# session.close()n
總結
這篇教程描述了在TensorFlow中如何保存並恢復神經網路的變數。它有許多用處。比如,當你用神經網路來識別圖像的時候,只需要訓練網路一次,然後可以在其他電腦上完成開發工作。
checkpoint的另一個用處是,如果你有一個非常大的神經網路和數據集,就可能會在中間保存一些checkpoints來避免電腦死機,這樣,你就可以在最近的checkpoint開始優化而不是重頭開始。
本教程也展示了如何用驗證集來進行所謂的Early Stopping,如果沒有降低驗證錯誤優化就會終止。這在神經網路出現過擬合以及開始學習訓練集中的雜訊時很有用;不過這在本教程的神經網路和MNIST數據集中並不是什麼大問題。
還有一個有趣的現象,最優化時卷積權重(或者叫濾波)的變化很小,即使網路的性能從隨機猜測提高到近乎完美的分類。奇怪的是隨機的權重好像已經足夠好了。你認為為什麼會有這種現象?
練習
下面使一些可能會讓你提升TensorFlow技能的一些建議練習。為了學習如何更合適地使用TensorFlow,實踐經驗是很重要的。
在你對這個Notebook進行修改之前,可能需要先備份一下。
- 在經過1000次迭代而性能沒有提升時,優化就終止了。這樣夠嗎?你能想出一個更好地進行Early Stopping的方法么?試著實現它。
- 如果checkpoint文件已經存在了,載入它而不是做優化。
- 每100次優化迭代保存一次checkpoint。通過saver.latest_checkpoint()取回最新的(保存點)。為什麼保存多個checkpoints而不是只保存最近的一個?
- 試著改變神經網路,比如添加其他層。當你從不同的網路中重新載入變數會出現什麼問題?
- 用plot_conv_weights()函數在優化前後畫出第二個卷積層的權重。它們幾乎相同的么?
- 你認為優化過的卷積權重為什麼與隨機初始化的(權重)幾乎相同?
- 不看源碼,自己重寫程序。
- 向朋友解釋程序如何工作。
推薦閱讀:
※TensorFlow-dev-summit:那些TensorFlow上好玩的和黑科技
※在Docker中部署使用Tensorflow && Docker基本用法介紹
※YJango的TensorFlow整體把握
※深度神經網路學習筆記
TAG:深度学习DeepLearning | TensorFlow | 卷积神经网络CNN |