TensorFlow與中文手寫漢字識別

Goal

本文目標是利用TensorFlow做一個簡單的圖像分類器,在比較大的數據集上,儘可能高效地做圖像相關處理,從Train,Validation到Inference,是一個比較基本的Example, 從一個基本的任務學習如果在TensorFlow下做高效地圖像讀取,基本的圖像處理,整個項目很簡單,但其中有一些trick,在實際項目當中有很大的好處, 比如絕對不要一次讀入所有的 的數據到內存(儘管在Mnist這類級別的例子上經常出現)…

最開始看到是這篇blog裡面的TensorFlow練習22: 手寫漢字識別, 但是這篇文章只用了140訓練與測試,試了下代碼 很快,但是當擴展到所有的時,發現32g的內存都不夠用,這才注意到原文中都是用numpy,會先把所有的數據放入到內存,但這個不必須的,無論在MXNet還是TensorFlow中都是不必 須的,MXNet使用的是DataIter,會在程序運行的過程中非同步讀取數據,TensorFlow也是這樣的,TensorFlow封裝了高級的api,用來做數據的讀取,比如TFRecord,還有就是從filenames中讀取, 來非同步讀取文件,然後做shuffle batch,再feed到模型的Graph中來做模型參數的更新。具體在tf如何做數據的讀取可以看看reading data in tensorflow

這裡我會拿到所有的數據集來做訓練與測試,算作是對斗大的熊貓上面那篇文章的一個擴展。

Batch Generate

數據集來自於中科院自動化研究所,感謝分享精神!!!具體下載:

wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zipwget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip

解壓後發現是一些gnt文件,然後用了斗大的熊貓裡面的代碼,將所有文件都轉化為對應label目錄下的所有png的圖片。(注意在HWDB1.1trn_gnt.zip解壓後是alz文件,需要再次解壓 我在mac沒有找到合適的工具,windows上有alz的解壓工具)。

import osimport numpy as npimport structfrom PIL import Imagedata_dir = ../datatrain_data_dir = os.path.join(data_dir, HWDB1.1trn_gnt)test_data_dir = os.path.join(data_dir, HWDB1.1tst_gnt)def read_from_gnt_dir(gnt_dir=train_data_dir): def one_file(f): header_size = 10 while True: header = np.fromfile(f, dtype=uint8, count=header_size) if not header.size: break sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24) tagcode = header[5] + (header[4]<<8) width = header[6] + (header[7]<<8) height = header[8] + (header[9]<<8) if header_size + width*height != sample_size: break image = np.fromfile(f, dtype=uint8, count=width*height).reshape((height, width)) yield image, tagcode for file_name in os.listdir(gnt_dir): if file_name.endswith(.gnt): file_path = os.path.join(gnt_dir, file_name) with open(file_path, rb) as f: for image, tagcode in one_file(f): yield image, tagcodechar_set = set()for _, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): tagcode_unicode = struct.pack(>H, tagcode).decode(gb2312) char_set.add(tagcode_unicode)char_list = list(char_set)char_dict = dict(zip(sorted(char_list), range(len(char_list))))print len(char_dict)import picklef = open(char_dict, wb)pickle.dump(char_dict, f)f.close()train_counter = 0test_counter = 0for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): tagcode_unicode = struct.pack(>H, tagcode).decode(gb2312) im = Image.fromarray(image) dir_name = ../data/train/ + %0.5d%char_dict[tagcode_unicode] if not os.path.exists(dir_name): os.mkdir(dir_name) im.convert(RGB).save(dir_name+/ + str(train_counter) + .png) train_counter += 1for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir): tagcode_unicode = struct.pack(>H, tagcode).decode(gb2312) im = Image.fromarray(image) dir_name = ../data/test/ + %0.5d%char_dict[tagcode_unicode] if not os.path.exists(dir_name): os.mkdir(dir_name) im.convert(RGB).save(dir_name+/ + str(test_counter) + .png) test_counter += 1

處理好的數據,放到了雲盤,大家可以直接在我的雲盤來下載處理好的數據集HWDB1. 這裡說明下,char_dict是漢字和對應的數字label的記錄。

得到數據集後,就要考慮如何讀取了,一次用numpy讀入內存在很多小數據集上是可以行的,但是在稍微大點的數據集上內存就成了瓶頸,但是不要害怕,TensorFlow有自己的方法:

def batch_data(file_labels,sess, batch_size=128): image_list = [file_label[0] for file_label in file_labels] label_list = [int(file_label[1]) for file_label in file_labels] print tag2 {0}.format(len(image_list)) images_tensor = tf.convert_to_tensor(image_list, dtype=tf.string) labels_tensor = tf.convert_to_tensor(label_list, dtype=tf.int64) input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor]) labels = input_queue[1] images_content = tf.read_file(input_queue[0]) # images = tf.image.decode_png(images_content, channels=1) images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32) # images = images / 256 images = pre_process(images) # print images.get_shape() # one hot labels = tf.one_hot(labels, 3755) image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,min_after_dequeue=10000) # print image_batch, image_batch.get_shape() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) return image_batch, label_batch, coord, threads

簡單介紹下,首先你需要得到所有的圖像的path和對應的label的列表,利用tf.convert_to_tensor轉換為對應的tensor, 利用tf.train.slice_input_producer將image_list ,label_list做一個slice處理,然後做圖像的讀取、預處理,以及label的one_hot表示,然後就是傳到tf.train.shuffle_batch產生一個個shuffle batch,這些就可以feed到你的 模型。 slice_input_producer和shuffle_batch這類操作內部都是基於queue,是一種非同步的處理方式,會在設備中開闢一段空間用作cache,不同的進程會分別一直往cache中塞數據 和取數據,保證內存或顯存的佔用以及每一個mini-batch不需要等待,直接可以從cache中獲取。

Data Augmentation

由於圖像場景不複雜,只是做了一些基本的處理,包括圖像翻轉,改變下亮度等等,這些在TensorFlow裡面有現成的api,所以盡量使用TensorFlow來做相關的處理:

def pre_process(images): if FLAGS.random_flip_up_down: images = tf.image.random_flip_up_down(images) if FLAGS.random_flip_left_right: images = tf.image.random_flip_left_right(images) if FLAGS.random_brightness: images = tf.image.random_brightness(images, max_delta=0.3) if FLAGS.random_contrast: images = tf.image.random_contrast(images, 0.8, 1.2) new_size = tf.constant([FLAGS.image_size,FLAGS.image_size], dtype=tf.int32) images = tf.image.resize_images(images, new_size) return images

Build Graph

這裡很簡單的構造了一個兩個卷積+一個全連接層的網路,沒有做什麼更深的設計,感覺意義不大,設計了一個dict,用來返回後面要用的所有op,還有就是為了方便再訓練中查看loss和accuracy, 沒有什麼特別的,很容易理解, labels 為None時 方便做inference。

def network(images, labels=None): endpoints = {} conv_1 = slim.conv2d(images, 32, [3,3],1, padding=SAME) max_pool_1 = slim.max_pool2d(conv_1, [2,2],[2,2], padding=SAME) conv_2 = slim.conv2d(max_pool_1, 64, [3,3],padding=SAME) max_pool_2 = slim.max_pool2d(conv_2, [2,2],[2,2], padding=SAME) flatten = slim.flatten(max_pool_2) out = slim.fully_connected(flatten,3755, activation_fn=None) global_step = tf.Variable(initial_value=0) if labels is not None: loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(out, labels)) train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss, global_step=global_step) accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(labels, 1)), tf.float32)) tf.summary.scalar(loss, loss) tf.summary.scalar(accuracy, accuracy) merged_summary_op = tf.summary.merge_all() output_score = tf.nn.softmax(out) predict_val_top3, predict_index_top3 = tf.nn.top_k(output_score, k=3) endpoints[global_step] = global_step if labels is not None: endpoints[labels] = labels endpoints[train_op] = train_op endpoints[loss] = loss endpoints[accuracy] = accuracy endpoints[merged_summary_op] = merged_summary_op endpoints[output_score] = output_score endpoints[predict_val_top3] = predict_val_top3 endpoints[predict_index_top3] = predict_index_top3 return endpoints

Train

train函數包括從已有checkpoint中restore,得到step,快速恢復訓練過程,訓練主要是每一次得到mini-batch,更新參數,每隔eval_steps後做一次train batch的eval,每隔save_steps 後保存一次checkpoint。

def train(): sess = tf.Session() file_labels = get_imagesfile(FLAGS.train_data_dir) images, labels, coord, threads = batch_data(file_labels, sess) endpoints = network(images, labels) saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) train_writer = tf.train.SummaryWriter(./log + /train,sess.graph) test_writer = tf.train.SummaryWriter(./log + /val) start_step = 0 if FLAGS.restore: ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt: saver.restore(sess, ckpt) print "restore from the checkpoint {0}".format(ckpt) start_step += int(ckpt.split(-)[-1]) logger.info(:::Training Start:::) try: while not coord.should_stop(): # logger.info(step {0} start.format(i)) start_time = time.time() _, loss_val, train_summary, step = sess.run([endpoints[train_op], endpoints[loss], endpoints[merged_summary_op], endpoints[global_step]]) train_writer.add_summary(train_summary, step) end_time = time.time() logger.info("the step {0} takes {1} loss {2}".format(step, end_time-start_time, loss_val)) if step > FLAGS.max_steps: break # logger.info("the step {0} takes {1} loss {2}".format(i, end_time-start_time, loss_val)) if step % FLAGS.eval_steps == 1: accuracy_val,test_summary, step = sess.run([endpoints[accuracy], endpoints[merged_summary_op], endpoints[global_step]]) test_writer.add_summary(test_summary, step) logger.info(===============Eval a batch in Train data=======================) logger.info( the step {0} accuracy {1}.format(step, accuracy_val)) logger.info(===============Eval a batch in Train data=======================) if step % FLAGS.save_steps == 1: logger.info(Save the ckpt of {0}.format(step)) saver.save(sess, os.path.join(FLAGS.checkpoint_dir, my-model), global_step=endpoints[global_step]) except tf.errors.OutOfRangeError: # print "============train finished=========" logger.info(==================Train Finished================) saver.save(sess, os.path.join(FLAGS.checkpoint_dir, my-model), global_step=endpoints[global_step]) finally: coord.request_stop() coord.join(threads) sess.close()

Graph

Loss and Accuracy

Validation

訓練完成之後,想對最終的模型在測試數據集上做一個評估,這裡我也曾經嘗試利用batch_data,將slice_input_producer中epoch設置為1,來做相關的工作,但是發現這裡無法和train 共用,會出現epoch無初始化值的問題(train中傳epoch為None),所以這裡自己寫了shuffle batch的邏輯,將測試集的images和labels通過feed_dict傳進到網路,得到模型的輸出, 然後做相關指標的計算:

def validation(): # it should be fixed by using placeholder with epoch num in train stage sess = tf.Session() file_labels = get_imagesfile(FLAGS.test_data_dir) test_size = len(file_labels) print test_size val_batch_size = FLAGS.val_batch_size test_steps = test_size / val_batch_size print test_steps # images, labels, coord, threads= batch_data(file_labels, sess) images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1]) labels = tf.placeholder(dtype=tf.int32, shape=[None,3755]) # read batch images from file_labels # images_batch = np.zeros([128,64,64,1]) # labels_batch = np.zeros([128,3755]) # labels_batch[0][20] = 1 # endpoints = network(images, labels) saver = tf.train.Saver() ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt: saver.restore(sess, ckpt) # logger.info("restore from the checkpoint {0}".format(ckpt)) # logger.info(Start validation) final_predict_val = [] final_predict_index = [] groundtruth = [] for i in range(test_steps): start = i* val_batch_size end = (i+1)*val_batch_size images_batch = [] labels_batch = [] labels_max_batch = [] logger.info(=======start validation on {0}/{1} batch=========.format(i, test_steps)) for j in range(start,end): image_path = file_labels[j][0] temp_image = Image.open(image_path).convert(L) temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS) temp_label = np.zeros([3755]) label = int(file_labels[j][1]) # print label temp_label[label] = 1 # print "====",np.asarray(temp_image).shape labels_batch.append(temp_label) # print "====",np.asarray(temp_image).shape images_batch.append(np.asarray(temp_image)/255.0) labels_max_batch.append(label) # print images_batch images_batch = np.array(images_batch).reshape([-1, 64, 64, 1]) labels_batch = np.array(labels_batch) batch_predict_val, batch_predict_index = sess.run([endpoints[predict_val_top3], endpoints[predict_index_top3]], feed_dict={images:images_batch, labels:labels_batch}) logger.info(=======validation on {0}/{1} batch end=========.format(i, test_steps)) final_predict_val += batch_predict_val.tolist() final_predict_index += batch_predict_index.tolist() groundtruth += labels_max_batch sess.close() return final_predict_val, final_predict_index, groundtruth

在訓練20w個step之後,大概能達到在測試集上能夠達到:

相信如果在網路設計上多花點時間能夠在一定程度上提升accuracy和top 3 accuracy.有興趣的小夥伴們可以玩玩這個數據集。

Inference

def inference(image): temp_image = Image.open(image).convert(L) temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS) sess = tf.Session() logger.info(========start inference============) images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1]) endpoints = network(images) saver = tf.train.Saver() ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt: saver.restore(sess, ckpt) predict_val, predict_index = sess.run([endpoints[predict_val_top3],endpoints[predict_index_top3]], feed_dict={images:temp_image}) sess.close() return final_predict_val, final_predict_index

運氣挺好,隨便找了張圖片就能準確識別出來

Summary

綜上,就是利用tensorflow做中文手寫識別的全部,從如何使用tensorflow內部的queue來有效讀入數據,到如何設計network, 到如何做train,validation,inference,珍格格流程比較清晰, 美中不足的是,原本打算是在訓練過程中,來對測試集做評估,但是在使用queue讀test_data_dir下的filenames,和train本身的好像有點問題,不過應該是可以解決的,我這裡就pass了。另外可能 還有一些可以改善的地方,比如感覺可以把batch data one hot的部分寫入到network,這樣,減緩在validation時內存會因為onehot的sparse開銷比較大。

感覺這個中文手寫漢字數據集價值很大,後面感覺會有好多可以玩的,比如

  • 可以參考項亮大神的這篇文章端到端的OCR:驗證碼識別做定長的字元識別和不定長的字元識別,定長的基本原理是說,可以把最終輸出擴展為k個輸出, 每個值表示對應的字元label,這樣cnn模型在feature extract之後就可以自己去識別對應字元而無需人工切割;而LSTM+CTC來解決不定長的驗證碼,類似於將音頻解碼為漢字
  • 最近GAN特別火,感覺可以考慮用這個數據來做某個字的生成,和text2img那個項目text-to-image

這部分的代碼都在我的github上tensorflow-101,有遇到相關功能,想參考代碼的可以去上面找找,沒準就能解決你們遇到的一些小問題.

Update in 2017.02.13

感謝@soloice的PR,使得代碼更簡潔, 並且修改了網路的結構,使得模型準確率上升很高, 最後top1和top3的結果:


推薦閱讀:

深入淺出Tensorflow(四):卷積神經網路
利用TensorFlow搞定知乎驗證碼之《讓你找中文倒轉漢字》
深入淺出Tensorflow(五):循環神經網路簡介
cs20si:tensorflow for research 學習筆記2

TAG:TensorFlow | 深度学习DeepLearning | 计算机视觉 |