機器學習進階筆記之八 | TensorFlow與中文手寫漢字識別
引言
TensorFlow是Google基於DistBelief進行研發的第二代人工智慧學習系統,被廣泛用於語音識別或圖像識別等多項機器深度學習領域。其命名來源於本身的運行原理。Tensor(張量)意味著N維數組,Flow(流)意味著基於數據流圖的計算,TensorFlow代表著張量從圖象的一端流動到另一端計算過程,是將複雜的數據結構傳輸至人工智慧神經網中進行分析和處理的過程。
TensorFlow完全開源,任何人都可以使用。可在小到一部智能手機、大到數千台數據中心伺服器的各種設備上運行。
『機器學習進階筆記』系列將深入解析TensorFlow系統的技術實踐,從零開始,由淺入深,與大家一起走上機器學習的進階之路。
Goal
本文目標是利用TensorFlow做一個簡單的圖像分類器,在比較大的數據集上,儘可能高效地做圖像相關處理,從Train,Validation到Inference,是一個比較基本的Example, 從一個基本的任務學習如果在TensorFlow下做高效地圖像讀取,基本的圖像處理,整個項目很簡單,但其中有一些trick,在實際項目當中有很大的好處, 比如絕對不要一次讀入所有的 的數據到內存(儘管在Mnist這類級別的例子上經常出現)…
最開始看到是這篇blog裡面的TensorFlow練習22: 手寫漢字識別, 但是這篇文章只用了140訓練與測試,試了下代碼 很快,但是當擴展到所有的時,發現32g的內存都不夠用,這才注意到原文中都是用numpy,會先把所有的數據放入到內存,但這個不必須的,無論在MXNet還是TensorFlow中都是不必 須的,MXNet使用的是DataIter,會在程序運行的過程中非同步讀取數據,TensorFlow也是這樣的,TensorFlow封裝了高級的api,用來做數據的讀取,比如TFRecord,還有就是從filenames中讀取, 來非同步讀取文件,然後做shuffle batch,再feed到模型的Graph中來做模型參數的更新。具體在tf如何做數據的讀取可以看看reading data in tensorflow
這裡我會拿到所有的數據集來做訓練與測試,算作是對斗大的熊貓上面那篇文章的一個擴展。
Batch Generate
數據集來自於中科院自動化研究所,感謝分享精神!!!具體下載:
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zipnwget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zipn
解壓後發現是一些gnt文件,然後用了斗大的熊貓裡面的代碼,將所有文件都轉化為對應label目錄下的所有png的圖片。(注意在HWDB1.1trn_gnt.zip解壓後是alz文件,需要再次解壓 我在mac沒有找到合適的工具,windows上有alz的解壓工具)。
import osnimport numpy as npnimport structnfrom PIL import Imagennndata_dir = ../datantrain_data_dir = os.path.join(data_dir, HWDB1.1trn_gnt)ntest_data_dir = os.path.join(data_dir, HWDB1.1tst_gnt)nnndef read_from_gnt_dir(gnt_dir=train_data_dir):n def one_file(f):n header_size = 10n while True:n header = np.fromfile(f, dtype=uint8, count=header_size)n if not header.size: breakn sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)n tagcode = header[5] + (header[4]<<8)n width = header[6] + (header[7]<<8)n height = header[8] + (header[9]<<8)n if header_size + width*height != sample_size:n breakn image = np.fromfile(f, dtype=uint8, count=width*height).reshape((height, width))n yield image, tagcoden for file_name in os.listdir(gnt_dir):n if file_name.endswith(.gnt):n file_path = os.path.join(gnt_dir, file_name)n with open(file_path, rb) as f:n for image, tagcode in one_file(f):n yield image, tagcodenchar_set = set()nfor _, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):n tagcode_unicode = struct.pack(>H, tagcode).decode(gb2312)n char_set.add(tagcode_unicode)nchar_list = list(char_set)nchar_dict = dict(zip(sorted(char_list), range(len(char_list))))nprint len(char_dict)nimport picklenf = open(char_dict, wb)npickle.dump(char_dict, f)nf.close()ntrain_counter = 0ntest_counter = 0nfor image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):n tagcode_unicode = struct.pack(>H, tagcode).decode(gb2312)n im = Image.fromarray(image)n dir_name = ../data/train/ + %0.5d%char_dict[tagcode_unicode]n if not os.path.exists(dir_name):n os.mkdir(dir_name)n im.convert(RGB).save(dir_name+/ + str(train_counter) + .png)n train_counter += 1nfor image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):n tagcode_unicode = struct.pack(>H, tagcode).decode(gb2312)n im = Image.fromarray(image)n dir_name = ../data/test/ + %0.5d%char_dict[tagcode_unicode]n if not os.path.exists(dir_name):n os.mkdir(dir_name)n im.convert(RGB).save(dir_name+/ + str(test_counter) + .png)n test_counter += 1n
處理好的數據,放到了雲盤,大家可以直接在我的雲盤來下載處理好的數據集HWDB1. 這裡說明下,char_dict是漢字和對應的數字label的記錄。
得到數據集後,就要考慮如何讀取了,一次用numpy讀入內存在很多小數據集上是可以行的,但是在稍微大點的數據集上內存就成了瓶頸,但是不要害怕,TensorFlow有自己的方法:
def batch_data(file_labels,sess, batch_size=128):n image_list = [file_label[0] for file_label in file_labels]n label_list = [int(file_label[1]) for file_label in file_labels]n print tag2 {0}.format(len(image_list))n images_tensor = tf.convert_to_tensor(image_list, dtype=tf.string)n labels_tensor = tf.convert_to_tensor(label_list, dtype=tf.int64)n input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor])nn labels = input_queue[1]n images_content = tf.read_file(input_queue[0])n # images = tf.image.decode_png(images_content, channels=1)n images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)n # images = images / 256n images = pre_process(images)n # print images.get_shape()n # one hotn labels = tf.one_hot(labels, 3755)n image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,min_after_dequeue=10000)n # print image_batch, image_batch.get_shape()nn coord = tf.train.Coordinator()n threads = tf.train.start_queue_runners(sess=sess, coord=coord)n return image_batch, label_batch, coord, threadsn
簡單介紹下,首先你需要得到所有的圖像的path和對應的label的列表,利用tf.convert_to_tensor轉換為對應的tensor, 利用tf.train.slice_input_producer將image_list ,label_list做一個slice處理,然後做圖像的讀取、預處理,以及label的one_hot表示,然後就是傳到tf.train.shuffle_batch產生一個個shuffle batch,這些就可以feed到你的 模型。 slice_input_producer和shuffle_batch這類操作內部都是基於queue,是一種非同步的處理方式,會在設備中開闢一段空間用作cache,不同的進程會分別一直往cache中塞數據 和取數據,保證內存或顯存的佔用以及每一個mini-batch不需要等待,直接可以從cache中獲取。
Data Augmentation
由於圖像場景不複雜,只是做了一些基本的處理,包括圖像翻轉,改變下亮度等等,這些在TensorFlow裡面有現成的api,所以盡量使用TensorFlow來做相關的處理:
def pre_process(images):n if FLAGS.random_flip_up_down:n images = tf.image.random_flip_up_down(images)n if FLAGS.random_flip_left_right:n images = tf.image.random_flip_left_right(images)n if FLAGS.random_brightness:n images = tf.image.random_brightness(images, max_delta=0.3)n if FLAGS.random_contrast:n images = tf.image.random_contrast(images, 0.8, 1.2)n new_size = tf.constant([FLAGS.image_size,FLAGS.image_size], dtype=tf.int32)n images = tf.image.resize_images(images, new_size)n return imagesn
Build Graph
這裡很簡單的構造了一個兩個卷積+一個全連接層的網路,沒有做什麼更深的設計,感覺意義不大,設計了一個dict,用來返回後面要用的所有op,還有就是為了方便再訓練中查看loss和accuracy, 沒有什麼特別的,很容易理解, labels 為None時 方便做inference。
def network(images, labels=None):n endpoints = {}n conv_1 = slim.conv2d(images, 32, [3,3],1, padding=SAME)n max_pool_1 = slim.max_pool2d(conv_1, [2,2],[2,2], padding=SAME)n conv_2 = slim.conv2d(max_pool_1, 64, [3,3],padding=SAME)n max_pool_2 = slim.max_pool2d(conv_2, [2,2],[2,2], padding=SAME)n flatten = slim.flatten(max_pool_2)n out = slim.fully_connected(flatten,3755, activation_fn=None)n global_step = tf.Variable(initial_value=0)n if labels is not None:n loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(out, labels))n train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss, global_step=global_step)n accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(labels, 1)), tf.float32))n tf.summary.scalar(loss, loss)n tf.summary.scalar(accuracy, accuracy)n merged_summary_op = tf.summary.merge_all()n output_score = tf.nn.softmax(out)n predict_val_top3, predict_index_top3 = tf.nn.top_k(output_score, k=3)nn endpoints[global_step] = global_stepn if labels is not None:n endpoints[labels] = labelsn endpoints[train_op] = train_opn endpoints[loss] = lossn endpoints[accuracy] = accuracyn endpoints[merged_summary_op] = merged_summary_opn endpoints[output_score] = output_scoren endpoints[predict_val_top3] = predict_val_top3n endpoints[predict_index_top3] = predict_index_top3n return endpointsn
Train
train函數包括從已有checkpoint中restore,得到step,快速恢復訓練過程,訓練主要是每一次得到mini-batch,更新參數,每隔eval_steps後做一次train batch的eval,每隔save_steps 後保存一次checkpoint。
def train():n sess = tf.Session()n file_labels = get_imagesfile(FLAGS.train_data_dir)n images, labels, coord, threads = batch_data(file_labels, sess)n endpoints = network(images, labels)n saver = tf.train.Saver()n sess.run(tf.global_variables_initializer())n train_writer = tf.train.SummaryWriter(./log + /train,sess.graph)n test_writer = tf.train.SummaryWriter(./log + /val)n start_step = 0n if FLAGS.restore:n ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)n if ckpt:n saver.restore(sess, ckpt)n print "restore from the checkpoint {0}".format(ckpt)n start_step += int(ckpt.split(-)[-1])n logger.info(:::Training Start:::)n try:n while not coord.should_stop():n # logger.info(step {0} start.format(i))n start_time = time.time()n _, loss_val, train_summary, step = sess.run([endpoints[train_op], endpoints[loss], endpoints[merged_summary_op], endpoints[global_step]])n train_writer.add_summary(train_summary, step)n end_time = time.time()n logger.info("the step {0} takes {1} loss {2}".format(step, end_time-start_time, loss_val))n if step > FLAGS.max_steps:n breakn # logger.info("the step {0} takes {1} loss {2}".format(i, end_time-start_time, loss_val))n if step % FLAGS.eval_steps == 1:n accuracy_val,test_summary, step = sess.run([endpoints[accuracy], endpoints[merged_summary_op], endpoints[global_step]])n test_writer.add_summary(test_summary, step)n logger.info(===============Eval a batch in Train data=======================)n logger.info( the step {0} accuracy {1}.format(step, accuracy_val))n logger.info(===============Eval a batch in Train data=======================)n if step % FLAGS.save_steps == 1:n logger.info(Save the ckpt of {0}.format(step))n saver.save(sess, os.path.join(FLAGS.checkpoint_dir, my-model), global_step=endpoints[global_step])n except tf.errors.OutOfRangeError:n # print "============train finished========="n logger.info(==================Train Finished================)n saver.save(sess, os.path.join(FLAGS.checkpoint_dir, my-model), global_step=endpoints[global_step])n finally:n coord.request_stop()n coord.join(threads)n sess.close()n
Graph
Loss and Accuracy
Validation
訓練完成之後,想對最終的模型在測試數據集上做一個評估,這裡我也曾經嘗試利用batch_data,將slice_input_producer中epoch設置為1,來做相關的工作,但是發現這裡無法和train 共用,會出現epoch無初始化值的問題(train中傳epoch為None),所以這裡自己寫了shuffle batch的邏輯,將測試集的images和labels通過feed_dict傳進到網路,得到模型的輸出, 然後做相關指標的計算:
def validation():n # it should be fixed by using placeholder with epoch num in train stagen sess = tf.Session()nn file_labels = get_imagesfile(FLAGS.test_data_dir)n test_size = len(file_labels)n print test_sizen val_batch_size = FLAGS.val_batch_sizen test_steps = test_size / val_batch_sizen print test_stepsn # images, labels, coord, threads= batch_data(file_labels, sess)n images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])n labels = tf.placeholder(dtype=tf.int32, shape=[None,3755])n # read batch images from file_labelsn # images_batch = np.zeros([128,64,64,1])n # labels_batch = np.zeros([128,3755])n # labels_batch[0][20] = 1n #n endpoints = network(images, labels)n saver = tf.train.Saver()n ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)n if ckpt:n saver.restore(sess, ckpt)n # logger.info("restore from the checkpoint {0}".format(ckpt))n # logger.info(Start validation)n final_predict_val = []n final_predict_index = []n groundtruth = []n for i in range(test_steps):n start = i* val_batch_sizen end = (i+1)*val_batch_sizen images_batch = []n labels_batch = []n labels_max_batch = []n logger.info(=======start validation on {0}/{1} batch=========.format(i, test_steps))n for j in range(start,end):n image_path = file_labels[j][0]n temp_image = Image.open(image_path).convert(L)n temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS)n temp_label = np.zeros([3755])n label = int(file_labels[j][1])n # print labeln temp_label[label] = 1n # print "====",np.asarray(temp_image).shapen labels_batch.append(temp_label)n # print "====",np.asarray(temp_image).shapen images_batch.append(np.asarray(temp_image)/255.0)n labels_max_batch.append(label)n # print images_batchn images_batch = np.array(images_batch).reshape([-1, 64, 64, 1])n labels_batch = np.array(labels_batch)n batch_predict_val, batch_predict_index = sess.run([endpoints[predict_val_top3],n endpoints[predict_index_top3]], feed_dict={images:images_batch, labels:labels_batch})n logger.info(=======validation on {0}/{1} batch end=========.format(i, test_steps))n final_predict_val += batch_predict_val.tolist()n final_predict_index += batch_predict_index.tolist()n groundtruth += labels_max_batchn sess.close()n return final_predict_val, final_predict_index, groundtruthn
在訓練20w個step之後,大概能達到在測試集上能夠達到:
相信如果在網路設計上多花點時間能夠在一定程度上提升accuracy和top 3 accuracy.有興趣的小夥伴們可以玩玩這個數據集。
Inference
def inference(image):n temp_image = Image.open(image).convert(L)n temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS)n sess = tf.Session()n logger.info(========start inference============)n images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])n endpoints = network(images)n saver = tf.train.Saver()n ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)n if ckpt:n saver.restore(sess, ckpt)n predict_val, predict_index = sess.run([endpoints[predict_val_top3],endpoints[predict_index_top3]], feed_dict={images:temp_image})n sess.close()n return final_predict_val, final_predict_indexn
運氣挺好,隨便找了張圖片就能準確識別出來
Summary
綜上,就是利用tensorflow做中文手寫識別的全部,從如何使用tensorflow內部的queue來有效讀入數據,到如何設計network, 到如何做train,validation,inference,珍格格流程比較清晰, 美中不足的是,原本打算是在訓練過程中,來對測試集做評估,但是在使用queue讀test_data_dir下的filenames,和train本身的好像有點問題,不過應該是可以解決的,我這裡就pass了。另外可能 還有一些可以改善的地方,比如感覺可以把batch data one hot的部分寫入到network,這樣,減緩在validation時內存會因為onehot的sparse開銷比較大。
感覺這個中文手寫漢字數據集價值很大,後面感覺會有好多可以玩的,比如
- 可以參考項亮大神的這篇文章端到端的OCR:驗證碼識別做定長的字元識別和不定長的字元識別,定長的基本原理是說,可以把最終輸出擴展為k個輸出, 每個值表示對應的字元label,這樣cnn模型在feature extract之後就可以自己去識別對應字元而無需人工切割;而LSTM+CTC來解決不定長的驗證碼,類似於將音頻解碼為漢字
- 最近GAN特別火,感覺可以考慮用這個數據來做某個字的生成,和text2img那個項目text-to-image
這部分的代碼都在我的github上tensorflow-101,有遇到相關功能,想參考代碼的可以去上面找找,沒準就能解決你們遇到的一些小問題.
——————
相關閱讀推薦:
機器學習進階筆記之七 | MXnet初體驗
機器學習進階筆記之六 | 深入理解Fast Neural Style
機器學習進階筆記之五 | 深入理解VGGResidual Network
機器學習進階筆記之四 | 深入理解GoogLeNet
機器學習進階筆記之三 | 深入理解Alexnet
機器學習進階筆記之二 | 深入理解Neural Style
機器學習進階筆記之一 | TensorFlow安裝與入門
本文由『UCloud內核與虛擬化研發團隊』提供。關於作者:
Burness(@段石石 ), UCloud平台研發中心深度學習研發工程師,tflearn Contributor & tensorflow Contributor,做過電商推薦、精準化營銷相關演算法工作,專註於分散式深度學習框架、計算機視覺演算法研究,平時喜歡玩玩演算法,研究研究開源的項目,偶爾也會去一些數據比賽打打醬油,生活中是個極客,對新技術、新技能痴迷。
你可以在Github上找到他:http://hacker.duanshishi.com/
「UCloud機構號」將獨家分享雲計算領域的技術洞見、行業資訊以及一切你想知道的相關訊息。
歡迎提問&求關注 o(*////▽////*)q~
以上。
推薦閱讀:
※基於tensorflow的最簡單的強化學習入門-part1.5: 基於上下文老虎機問題(Contextual Bandits)
※功率密度成深度學習設計難題,數據中心市場展現新機遇
※在科學的危機下踏浪前行
※Alpha Go 的影響