標籤:

cifar10 (1) 讀取數據

參考博客:github.com/kevin28520/M

tf.slice用法tensorflow學習(三):操作圖片的tf.slice()函數

TF Boys (TensorFlow Boys ) 養成記(二): TensorFlow 數據讀取

Tensorflow樣例代碼分析cifar10 - 追憶次風暴 - 博客園

tensorflow數據讀取機制:十圖詳解tensorflow數據讀取機制(附代碼)

import matplotlib.pyplot as pltimport tensorflow as tfimport numpy as npimport os def read_cifar10(data_dir,is_train,batch_size,shuffle): img_width = 32 img_height = 32 img_depth = 3 label_bytes = 1 image_bytes = img_width*img_height*img_depth with tf.name_scope(input): if is_train: filenames = [os.path.join(data_dir,data_batch_%d %ii) for ii in np.arange(1,6)] else: filenames = [os.path.join(data_dir,test_batch)] filename_queue = tf.train.string_input_producer(filenames)#string_input_producer會產生一個文件名隊列 reader = tf.FixedLengthRecordReader(label_bytes+image_bytes)#文件讀取器,從文件中讀取固定長度的位元組 key,value = reader.read(filename_queue)#從文件隊列中,讀取鍵值對,圖像,標籤 record_bytes = tf.decode_raw(value,tf.uint8)#解碼器,將一個字元串轉換為一個uint8的張量 label = tf.slice(record_bytes,[0],[label_bytes]) label = tf.cast(label,tf.int32)#轉換數據類型 image_raw = tf.slice(record_bytes,[label_bytes],[image_bytes]) image_raw = tf.reshape(image_raw,[img_depth,img_height,img_width]) image = tf.transpose(image_raw,(1,2,0))#convert from D/H/W to H/W/D image = tf.cast(image,tf.float32) image = tf.random_crop(image,[24,24,3])#隨機裁剪圖片 image = tf.image.random_flip_left_right(image)#隨機左右翻轉 image = tf.image.random_brightness(image,max_delta=63)#隨機改變亮度 image = tf.image.random_contrast(image,lower=0.2,upper=1.8)#隨機改變對比度 image = tf.image.per_image_standardization(image)#標準化 print(image) if shuffle: images,label_batch = tf.train.shuffle_batch( [image,label], batch_size = batch_size, num_threads = 16, capacity = 2000, min_after_dequeue = 1500) else: images,label_batch = tf.train.batch( [image,label], batch_size = batch_size, num_threads = 16, capacity = 2000) return images,tf.reshape(label_batch,[batch_size])# #one-hot,用於訓練# n_classes = 10# label_batch = tf.one_hot(label_batch,depth=n_classes)# return images,tf.reshape(label_batch,[batch_size,n_classes]) data_dir = C:eclipseeclipseworkspacecifar10數據集整理cifar-10-batches-pyBATCH_SIZE=10image_batch,label_batch = read_cifar10(data_dir, is_train=True, batch_size=BATCH_SIZE, shuffle=True)with tf.Session() as sess: i = 0 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord)#使用start_queue_runners之後,才會開始填充隊列 try: while not coord.should_stop() and i<1: img, label = sess.run([image_batch, label_batch]) print(img.shape) print(label) # just test one batch for j in np.arange(BATCH_SIZE): print(label: %d %label[j]) plt.imshow(img[j,:,:,:]) plt.show() i+=1 except tf.errors.OutOfRangeError: print(done!) finally: coord.request_stop() coord.join(threads)

推薦閱讀:

【博客存檔】風格畫之最後一彈MRF-CNN
TensorFlow博客翻譯——DeepMind轉向TensorFlow
AllenNLP 基於 PyTorch 的 NLP 套裝I(安裝)
NVIDIA Jetson TX2 安裝 TensoFlow 1.6 教程
TensorFlow 安裝官方教程:Ubuntu 安裝,Mac OS X 安裝,Windows 安裝

TAG:TensorFlow |