貓狗大戰(1)處理數據
來自專欄 《人工智慧》
參考博客:https://github.com/kevin28520/My-TensorFlow-tutorials/blob/master/01%20cats%20vs%20dogs/input_data.py
tensorflow 實戰 貓狗大戰(一)訓練自己的數據 - CSDN博客
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Directory holding the training images; files are expected to be named
# "<class>.<index>.jpg", e.g. "cat.249.jpg" or "dog.17.jpg".
file_dir = 'J:\\移動硬碟\\數據集\\train\\'


def get_files(file_dir):
    """Collect the image paths in *file_dir* and label them from their names.

    The part of the file name before the first "." decides the class:
    "cat" maps to label 0 and "dog" maps to label 1. Files with any other
    prefix are skipped so that stray files in the directory are not fed
    to the input pipeline as mislabelled samples.

    Args:
        file_dir: Directory containing the image files. A trailing path
            separator is optional.

    Returns:
        A tuple ``(image_list, label_list)`` of two equally long lists:
        the full image paths and their integer labels, shuffled together
        with NumPy's global random state.
    """
    cats = []
    dogs = []
    for filename in os.listdir(file_dir):
        prefix = filename.split('.')[0]
        path = os.path.join(file_dir, filename)
        if prefix == 'cat':
            cats.append(path)
        elif prefix == 'dog':
            dogs.append(path)
    print('There are %d cats\nThere are %d dogs' % (len(cats), len(dogs)))

    # Cats first, then dogs, with matching labels (0 = cat, 1 = dog).
    image_list = cats + dogs
    label_list = [0] * len(cats) + [1] * len(dogs)

    # Apply one random permutation to both lists so every path keeps its label.
    order = np.random.permutation(len(image_list))
    image_list = [image_list[i] for i in order]
    label_list = [label_list[i] for i in order]
    return image_list, label_list


def get_batch(image, label, image_w, image_h, batch_size, capacity):
    """Build a TF1 queue-based input pipeline that yields image/label batches.

    Args:
        image: List of JPEG file paths.
        label: List of integer labels, aligned with *image*.
        image_w: Target image width in pixels.
        image_h: Target image height in pixels.
        batch_size: Number of samples per batch.
        capacity: Maximum number of elements held in the batching queue.

    Returns:
        A tuple ``(image_batch, label_batch)`` of tensors produced by
        ``tf.train.batch``. Images are uint8 so they can be displayed
        directly; labels are int32.
    """
    # Convert the Python lists into tensors the queue runners can slice.
    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)

    input_queue = tf.train.slice_input_producer([image, label])
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)

    # resize_images takes the size as [height, width]. The original post also
    # kept tf.image.resize_image_with_crop_or_pad as a commented-out option.
    image = tf.image.resize_images(
        image,
        [image_h, image_w],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

    # uint8 keeps the pictures displayable with matplotlib.
    # NOTE: for training, cast to tf.float32 and apply
    # tf.image.per_image_standardization(image) instead of this uint8 cast.
    image = tf.cast(image, tf.uint8)

    image_batch, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=64,  # worker threads filling the batching queue
        capacity=capacity,
    )
    return image_batch, label_batch


# Visual sanity check: pull a few batches and show every image with its label.
batch_size = 10
capacity = 256
image_w = 200
image_h = 200

image_list, label_list = get_files(file_dir)
image_batch, label_batch = get_batch(
    image_list, label_list, image_w, image_h, batch_size, capacity
)

with tf.Session() as sess:
    i = 0
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop() and i < 10:  # display 10 batches
            img, label = sess.run([image_batch, label_batch])
            for j in range(batch_size):
                print('label: %d' % label[j])
                plt.imshow(img[j, :, :, :])
                plt.show()
            i += 1
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        # Always stop and join the queue-runner threads before the session closes.
        coord.request_stop()
        coord.join(threads)
推薦閱讀:
※tensorflow的共享變數,tf.Variable(),tf.get_variable(),tf.variable_scope(),tf.name_scope()聯繫與區別
※學習筆記TF058:人臉識別
※Tensorflow op放置策略
※怎樣使用tensorflow導入已經下載好的mnist數據集?
TAG:TensorFlow |