深度學習一行一行敲cyclegan-tensorflow版(網路訓練)

對源碼進行逐句解析,盡量說的很細緻。

歡迎各位看官捧場!

源碼地址:CycleGAN-tensorflow

論文地址:[1703.10593] Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

這是在main.py文件

parser = argparse.ArgumentParser(description=)nparser.add_argument(--dataset_dir, dest=dataset_dir, default=horse2zebra, help=path of the dataset)n......nargs = parser.parse_args()n

  1. dest - 解析後的參數名稱
  2. default - 不指定參數時的默認值
  3. type - 命令行參數應該被轉換成的類型【argparse - Python 之旅 - 極客學院Wiki】

需要說明的是訓練中所涉及的一些參數:

  1. 學習率:0.0002
  2. beta10.5(雖然不太清楚這個參數的影響)
  3. L1_lambda:10(weight on L1 term in objective)
  4. use_lsgan:True(不太清楚這個是用來幹嘛的)
  5. ngf,ndf:64(of G and D filters in first conv layer)

if not os.path.exists(args.checkpoint_dir):n os.makedirs(args.checkpoint_dir)n.......n

  1. os.path.exists:判斷是否存在該文件
  2. os.makedirs:創造文件

model.py的train方法

self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) n.minimize(self.d_loss, var_list=self.d_vars)n

  1. :是繼續的意思【&在python中什麼意思?】
  2. var_list:調整var_list里的變數來減少d_loss

init_op = tf.global_variables_initializer()nself.sess.run(init_op)nself.writer = tf.summary.FileWriter("./logs", self.sess.graph)n

  1. 前兩句是初始化,程序中所有的變數。
  2. FileWriter見附錄

import timenstart_time = time.time()n

記錄當前時間,經常用到

if args.continue_train:n if self.load(args.checkpoint_dir):n print(" [*] Load SUCCESS")n else:n print(" [!] Load failed...")n

如果是接這上一次的訓練,則需要載入模型。load函數見附錄。

dataA = glob(./datasets/{}/*.*.format(self.dataset_dir + /trainA))ndataB = glob(./datasets/{}/*.*.format(self.dataset_dir + /trainB))nnp.random.shuffle(dataA)nnp.random.shuffle(dataB)n

得到所有需要訓練文件的路徑。glob真的是好用

  1. *代表0個或多個字元
  2. 該方法返回所有匹配的文件路徑列表【python中的一個好用的文件名操作模塊glob - CSDN博客】分析見附錄

batch_idxs = min(min(len(dataA), len(dataB)), args.train_size) // self.batch_sizen

得到究竟有多少個batch

lr = args.lr if epoch < args.epoch_step else args.lr*(args.epoch-epoch)/(args.epoch-args.epoch_step)n

  1. 變數甲 if 事件丙 else 變數乙:如果事件丙發生則返回變數甲,否則返回變數乙
  2. 當epoch大於設定的args.epoch_step時,學習率隨著epoch的增加而減少,在之前保持不變

batch_files = list(zip(dataA[idx * self.batch_size:(idx + 1) * self.batch_size],n dataB[idx * self.batch_size:(idx + 1) * self.batch_size]))nbatch_images = [load_train_data(batch_file, args.load_size, args.fine_size) for batch_file in batch_files]nbatch_images = np.array(batch_images).astype(np.float32)n

得到每一次訓練batch的數據

zip和list一起用:返回的是一個交叉的路徑list【Python zip() 函數 | 菜鳥教程】

需要注意的是:batch_images是將AB兩張圖片在深度上拼接在返回的。

# Update G network and record fake outputsnfake_A, fake_B, _, summary_str = self.sess.run(n [self.fake_A, self.fake_B, self.g_optim, self.g_sum],n feed_dict={self.real_data: batch_images, self.lr: lr})nself.writer.add_summary(summary_str, counter)n[fake_A, fake_B] = self.pool([fake_A, fake_B])n

  1. [self.fake_A, self.fake_B, self.g_optim, self.g_sum]:sess.run一下得到假A,假B,跑一下G的優化器,跑一下相關的loss
  2. {self.real_data: batch_images, self.lr: lr}:這是與tf.placehold對應的,在run的時候以字典的形式傳入需要的值
  3. 將各種loss加入到add_summary裡面[self.pool見附錄的Imagepool函數]

# Update D networkn_, summary_str = self.sess.run(n [self.d_optim, self.d_sum],n feed_dict={self.real_data: batch_images,n self.fake_A_sample: fake_A,n self.fake_B_sample: fake_B,n self.lr: lr})nself.writer.add_summary(summary_str, counter)n

功能和上面相似

print(("Epoch: [%2d] [%4d/%4d] time: %4.4f" % (n epoch, idx, batch_idxs, time.time() - start_time)))nnif np.mod(counter, args.print_freq) == 1:n self.sample_model(args.sample_dir, epoch, idx)nnif np.mod(counter, args.save_freq) == 2:n self.save(args.checkpoint_dir, counter)n

  1. 列印出相關訓練信息
  2. np.mod:取余(和np.remainder功能一致)【numpy.mod - NumPy v1.13 Manual】

用到的個函數解析見附錄


附錄:

The FileWriter class provides a mechanism to create an event file nin a given directory and add summaries and events to it. nThe class updates the file contents asynchronously. nThis allows a training program to call methods to add data to the file directlynfrom the training loop, without slowing down training.n

翻譯:FileWriter類提供了一種機制來在給定的目錄中創建一個事件文件並向其中添加摘要和事件。該類非同步更新文件內容。 這允許訓練程序調用直接從訓練循環向文件添加數據的方法,而不減慢訓練。

If you pass a Graph to the constructor it is added to the event file.n(This is equivalent to calling add_graph() later).nTensorBoard will pick the graph from the file and display it graphically nso you can interactively explore the graph you built.nYou will usually pass the graph from the session in which you launched it:n

翻譯:如果將一個Graph傳遞給構造函數,它將被添加到事件文件中。 (這相當於稍後調用add_graph())。TensorBoard將從文件中選擇圖形並以圖形方式顯示,以便您可以互動式地瀏覽您創建的圖形。您通常會從您啟動它的會話中傳遞圖表


def load(self, checkpoint_dir):n print(" [*] Reading checkpoint...")nn model_dir = "%s_%s" % (self.dataset_dir, self.image_size)n checkpoint_dir = os.path.join(checkpoint_dir, model_dir)nn ckpt = tf.train.get_checkpoint_state(checkpoint_dir)n if ckpt and ckpt.model_checkpoint_path:n ckpt_name = os.path.basename(ckpt.model_checkpoint_path)n self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))n return Truen else:n return Falsen

  1. %:格式化輸出【Python補充05 字元串格式化 (%操作符),而format是更簡單的格式化輸出python字元串格式化方法 format函數的使用】 self.dataset_dir =『horse2zebra』,self.image_size=『256』

2.self.saver.restore:從指定路徑中恢復出變數

import tensorflow as tfnimport osndataset_dir = horse2zebranimage_size = 256ncheckpoint_dir = ./checkpointnmodel_dir = "%s_%s" % (dataset_dir, image_size)ncheckpoint_dir = os.path.join(checkpoint_dir, model_dir)nckpt = tf.train.get_checkpoint_state(checkpoint_dir)nckpt_name = os.path.basename(ckpt.model_checkpoint_path)nprint(checkpoint_dir,1)nprint(ckpt,2)nprint(ckpt.model_checkpoint_path,3)nprint(os.path.basename(ckpt.model_checkpoint_path),4)nprint(os.path.join(checkpoint_dir, ckpt_name),5)n輸出:n./checkpointhorse2zebra_256 1nmodel_checkpoint_path: "./checkpointhorse2zebra_256cyclegan.model-56002"nall_model_checkpoint_paths: "./checkpointhorse2zebra_256cyclegan.model-52002"nall_model_checkpoint_paths: "./checkpointhorse2zebra_256cyclegan.model-53002"nall_model_checkpoint_paths: "./checkpointhorse2zebra_256cyclegan.model-54002"nall_model_checkpoint_paths: "./checkpointhorse2zebra_256cyclegan.model-55002"nall_model_checkpoint_paths: "./checkpointhorse2zebra_256cyclegan.model-56002"n 2n./checkpointhorse2zebra_256cyclegan.model-56002 3ncyclegan.model-56002 4n./checkpointhorse2zebra_256cyclegan.model-56002 5n

總結起來就是從存儲的checkpoint中,找到最後一次存儲的文件路徑

  1. tf.train.get_checkpoint_state:得到該文件下,所有存儲的文件對象
  2. ckpt.model_checkpoint_path:得到最新的文件路徑
  3. os.path.basename:得到文件的後綴名
  4. tf.train.Saver()

Saves and restores variables.n用來存儲和讀取變數的,而FileWriter是用來存儲訓練信息的n


dataA = glob(./datasets/{}/*.*.format(dataset_dir + /trainA))nprint(dataA)n輸出:n[./datasets/horse2zebra/trainAn02381460_1001.jpg, ./datasets/horse2zebra/trainAn02381460_1002.jpg, ........n


batch_files = list(zip(dataA[idx * batch_size:(idx + 1) * batch_size],n dataB[idx * batch_size:(idx + 1) * batch_size]))nprint(batch_files)n輸出:n[(./datasets/horse2zebra/trainAn02381460_1001.jpg, ./datasets/horse2zebra/trainBn02391049_10007.jpg), (./datasets/horse2zebra/trainAn02381460_1002.jpg, ./datasets/horse2zebra/trainBn02391049_10027.jpg)]n


下面是load_train_data函數:

def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False):n img_A = imread(image_path[0])n img_B = imread(image_path[1])n#元組是沒有shape的,要拿到裡面的數據就是按照它的位置拿。每一條數據佔據一個位置n if not is_testing:n img_A = scipy.misc.imresize(img_A, [load_size, load_size])n img_B = scipy.misc.imresize(img_B, [load_size, load_size])n#將圖片規整到指定的大小n h1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))n w1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))n img_A = img_A[h1:h1+fine_size, w1:w1+fine_size]n img_B = img_B[h1:h1+fine_size, w1:w1+fine_size]n#1.np.random.uniform:從一個均勻分布[low,high)中隨機採樣n#2.np.ceil:向上取整nn if np.random.random() > 0.5:n img_A = np.fliplr(img_A)n img_B = np.fliplr(img_B)n#3.np.fliplr:翻轉n#fine_size要小於load_size,它的作用是從中截取一部分圖作為訓練n else:n img_A = scipy.misc.imresize(img_A, [fine_size, fine_size])n img_B = scipy.misc.imresize(img_B, [fine_size, fine_size])nn img_A = img_A/127.5 - 1.n img_B = img_B/127.5 - 1.nn img_AB = np.concatenate((img_A, img_B), axis=2)n # img_AB shape: (fine_size, fine_size, input_c_dim + output_c_dim),即在深度上合併n return img_ABn

原圖(fine_size為200,load_size256)

最終圖(fine_size為200,load_size256)


class ImagePool(object):n def __init__(self, maxsize=50):n self.maxsize = maxsizen self.num_img = 0n self.images = []nn def __call__(self, image):n if self.maxsize <= 0:n return imagen if self.num_img < self.maxsize:n self.images.append(image)n self.num_img += 1n return imagen if np.random.rand() > 0.5:n idx = int(np.random.rand()*self.maxsize)n tmp1 = copy.copy(self.images[idx])[0]n self.images[idx][0] = image[0]n idx = int(np.random.rand()*self.maxsize)n tmp2 = copy.copy(self.images[idx])[1]n self.images[idx][1] = image[1]n return [tmp1, tmp2]n else:n return imagen

?搞不明白這個函數想幹嘛,但這個仍然是非常重要的。因為D網路就是根據這個圖片工作的

答:這個主要是用來選擇圖片的,我們放入到D網路訓練的圖片不是G網路剛剛生成的,而是從過去生成的所有圖片中挑選的。現在好多交替訓練都有採用此方法的趨勢。

  1. self.maxsize:最大的緩存圖片的個數
  2. self.num_img:當前緩存圖片的個數
  3. self.images:用來存放緩存的圖片的

call函數:

  1. 當最大的緩存個數小於等於0時,直接返回當前的圖片
  2. 當目前已經緩存的圖片個數小於需要緩存的圖片時,直接返回當前圖片
  3. 當已經緩存的圖片達到要求時,則隨機的選擇是返回當前圖片還是已經產生的圖片【numpy.random.randn()與rand()的區別 - CSDN博客】【註:當返回過去緩存值時,會將過去的值替換成現在的值】

def sample_model(self, sample_dir, epoch, idx):n dataA = glob(./datasets/{}/*.*.format(self.dataset_dir + /testA))n dataB = glob(./datasets/{}/*.*.format(self.dataset_dir + /testB))n np.random.shuffle(dataA)n np.random.shuffle(dataB)n batch_files = list(zip(dataA[:self.batch_size], dataB[:self.batch_size]))n sample_images = [load_train_data(batch_file, is_testing=True) for batch_file in batch_files]n sample_images = np.array(sample_images).astype(np.float32)nn fake_A, fake_B = self.sess.run(n [self.fake_A, self.fake_B],n feed_dict={self.real_data: sample_images}n )n save_images(fake_A, [self.batch_size, 1],n ./{}/A_{:02d}_{:04d}.jpg.format(sample_dir, epoch, idx))n save_images(fake_B, [self.batch_size, 1],n ./{}/B_{:02d}_{:04d}.jpg.format(sample_dir, epoch, idx))n

這個的功能是:從新打亂數據,並保存一張訓練結果

def save_images(images, size, image_path):n return imsave(inverse_transform(images), size, image_path)nndef imsave(images, size, path):n return scipy.misc.imsave(path, merge(images, size))nndef merge(images, size):n h, w = images.shape[1], images.shape[2]n img = np.zeros((h * size[0], w * size[1], 3))n for idx, image in enumerate(images):n i = idx % size[1]n j = idx // size[1]n img[j*h:j*h+h, i*w:i*w+w, :] = imagenn return imgnndef inverse_transform(images):n return (images+1.)/2.n

merge(images, size)函數是:將一個batchsize裡面的圖片拼在一起顯示

inverse_transform函數是:恢復原來的圖片:

img_A = img_A / 127.5 - 1.nprint(img_A)ncv2.imshow(img_A,img_A)ncv2.waitKey(0)nimg_A = (img_A+1.)/2.nprint(img_A)ncv2.imshow(img_A,img_A)ncv2.waitKey(0)n輸出:n [[ 0.03529412 0.04313725 -0.00392157]n [-0.18431373 -0.17647059 -0.23921569]n [-0.25490196 -0.23921569 -0.3254902 ]nn [ 0.61568627 0.54509804 0.46666667]n [ 0.61568627 0.54509804 0.46666667]n [ 0.54509804 0.4745098 0.39607843]]n

img_A / 127.5 - 1.圖

(img_A+1.)/2.圖


def save(self, checkpoint_dir, step):n model_name = "cyclegan.model"n model_dir = "%s_%s" % (self.dataset_dir, self.image_size)n checkpoint_dir = os.path.join(checkpoint_dir, model_dir)nn if not os.path.exists(checkpoint_dir):n os.makedirs(checkpoint_dir)nn self.saver.save(self.sess,n os.path.join(checkpoint_dir, model_name),n global_step=step)n

self.saver.save:與restore的用法相似,global_step=step這個是指定當前的訓練步數

最後放一張效果圖:

原圖

訓練後

歡迎關注公眾號:huangxiaobai880

https://www.zhihu.com/video/924615852177887232
推薦閱讀:

【源碼眾讀】進度報告,這不是一個假活動
【Vlpp源碼閱讀】集合篇(一)
如何閱讀Tomcat源代碼?
Android Framework源碼當中哪些類有必要進行深入學習?

TAG:深度学习DeepLearning | 机器学习 | 源码阅读 |