帶你進入GAN(三)WGAN-gp

帶你進入GAN(三)WGAN-gp

來自專欄機器學習與深度學習2 人贊了文章

WGAN

關於WGAN,請看這篇文章:zhuanlan.zhihu.com/p/25

WGAN-GP (improved wgan)

論文:arxiv.org/abs/1704.0002

在WGAN中,需要進行截斷,clip(w,-c,c)。在實驗中發現: 對於very deep WAGN,它不容易收斂。

大致原因如下:

  1. 實驗發現最後大多數的權重都在-0.01 和0.01上,這就意味了大部分權重只有兩個可能數,這太簡單了,作為一個深度神經網路來說,這實在是對它強大的擬合能力的浪費。
  2. 實驗發現容易導致梯度消失或梯度爆炸。我們知道,w1w1的梯度的一個因子是w2w2, w2w2的梯度的因子是w3w3. 若 w>0.1w>0.1,它被clip後,w變小了,導致梯度變小,多層傳遞,容易導致梯度消失;反之,若w<?0.1w<?0.1, 被clip後,w變大,則容易導致梯度爆炸。

於是,很快提出了更加優秀的解決方法:Gradient penalty (梯度懲罰)

論文:「Gulrajani I, Ahmed F, Arjovsky M, et al. Improved Training of Wasserstein GANs. 2017.」

思路

前面提到,Lipschitz限制是要求判別器的梯度不超過K,那我們何不直接設置一個額外的loss項來體現這一點呢?

比如 :

( || 
abla_x D(x) ||_p - K )^2

解釋

既然判別器希望儘可能拉大真假樣本的分數差距,那自然是希望梯度越大越好,變化幅度越大越好,所以判別器在充分訓練之後,其梯度norm其實就會是在K附近。所以期望梯度norm離K越近越好。

簡單的把k定為1,則D的優化目標改進為:

E_{x sim P_{data}}[D(x)] - E_{x sim P_G}[D(x)] - lambda E_{hat x sim P_{hat x}}[(||
abla_{hat x}D(hat x)||_2-1)^2]

細節說明

  1. 對於 P_{hat x} ,這裡並不是對整個樣本空間採樣,而只對 P_{data} , P_G 之間的空間採樣。

    因為對整個樣本空間採樣,所需要的樣本是指數級的,是實際無法做到的,所以,論文作者就非常機智地提出,其實沒必要在整個樣本空間上施加Lipschitz限制,只要抓住生成樣本、真實樣本以及夾在它們中間的區域就行了。
  2. 由於我們是對每個樣本獨立地施加梯度懲罰,所以判別器的模型架構中不能使用Batch Normalization。因為它會引入同個batch中不同樣本的相互依賴關係,如果需要的話,可以選擇其他normalization方法,如Layer Normalization、Weight NormalizationInstance Normalization,這些方法就不會引入樣本之間的依賴。論文推薦的是Layer normalization。

具體訓練步驟:

代碼實現

對dcn的batch norm進行改造,如果是wgan-gp,則去掉bn層。

def dis_dcn(self, x_image, training=False): c=self.img_dim[2] depths = [c]+[32, 64] activation_fn=self.leaky_relu # x_image.shape=[28,28,1] with tf.variable_scope(conv1): outputs = tf.layers.conv2d(x_image, depths[1], [5, 5], strides=(2, 2), padding=SAME) if not self.gp: outputs = tf.layers.batch_normalization(outputs, training=training) outputs = activation_fn(outputs ) # 14*14*32 with tf.variable_scope(conv2): outputs = tf.layers.conv2d(outputs, depths[2], [5, 5], strides=(2, 2), padding=SAME) if not self.gp: outputs = tf.layers.batch_normalization(outputs, training=training) outputs = activation_fn(outputs ) # 7*7*64 with tf.variable_scope(fc): # flatten reshape = self.flatten(outputs) # fc outputs = tf.layers.dense(reshape, 1, name=outputs) return outputs

loss定義

def get_loss(self, d_logits_r, d_logits_f): d_loss = -(tf.reduce_mean(d_logits_r) - tf.reduce_mean(d_logits_f)) g_loss = -tf.reduce_mean(d_logits_f) return d_loss, g_loss

wgan的train_op

def get_train_op(self,d_loss,g_loss): """ 對d網路的w進行clip """ all_vars = tf.trainable_variables() g_vars = [v for v in all_vars if gen in v.name] d_vars = [v for v in all_vars if dis in v.name] for v in all_vars: print(v) opt_d = tf.train.RMSPropOptimizer(self.learning_rate) opt_g = tf.train.RMSPropOptimizer(self.learning_rate) # train_op_d = opt_d.minimize(d_loss, var_list=d_vars) train_op_g = opt_g.minimize(g_loss, var_list=g_vars) with tf.control_dependencies(train_op_d): # clip net ds weight cc = 0.01 clip_ops = [tf.assign(d, tf.clip_by_value(d, -cc, cc)) for d in d_vars] clip_d_w = tf.group(*clip_ops) return clip_d_w,train_op_g

wgan-gp的train_op

def get_train_op_gp(self,d_loss,g_loss): """ WGAN-GP """ all_vars = tf.trainable_variables() g_vars = [v for v in all_vars if gen in v.name] d_vars = [v for v in all_vars if dis in v.name] for v in all_vars: print(v) opt_d = tf.train.AdamOptimizer(self.learning_rate,beta1=0.5, beta2=0.9) opt_g = tf.train.AdamOptimizer(self.learning_rate,beta1=0.5, beta2=0.9) # 生成生成圖片與真實圖片之間的插值 alpha = tf.random_uniform(shape=[self.batch_size, 1,1,1], minval=0., maxval=1. ) # interpolates = self.x + (alpha * (self.G_img - self.x)) interpolates = alpha * self.x + (1-alpha)*self.G_img # 計算插值的梯度 # tf.gradients return a list, [0]表示取裡面的元素,是個tensor grads = tf.gradients(self.discriminator(interpolates,reuse=True), [interpolates])[0] # 計算插值梯度的懲罰,類L2正則 slopes = tf.sqrt(tf.reduce_sum(tf.square(tf.reshape(grads,[self.batch_size,-1])), 1)) lambda1 = 10 gradient_penalty = lambda1 * tf.reduce_mean((slopes - 1.) ** 2) # d_loss += gradient_penalty # train_op_d = opt_d.minimize(d_loss, var_list=d_vars) train_op_g = opt_g.minimize(g_loss, var_list=g_vars) return train_op_d, train_op_g

build_model

def build_model(self): # y 表示label # self.label = tf.placeholder(tf.float32, shape=[None, self.y_dim], name=label) # x_img表示真實圖片 self.x = tf.placeholder(tf.float32, shape=[None, self.img_dim[0],self.img_dim[1], self.img_dim[2]], name=real_img) # z 表示隨機雜訊 self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name=noise) # 由生成器生成圖像 G self.G_img = self.generator(self.z,reuse=False,if_train=self.if_train ) #shape=[batch_size,28,28,1] # 真實圖像送入判別器 d_logits_r = self.discriminator(self.x, reuse=False,if_train=self.if_train) # 生成圖像送入辨別器 d_logits_f = self.discriminator(self.G_img, reuse=True,if_train=self.if_train) # loss self.d_loss,self.g_loss = self.get_loss(d_logits_r,d_logits_f) # if self.gp: self.train_op_d,self.train_op_g = self.get_train_op_gp(self.d_loss,self.g_loss) else: self.train_op_d, self.train_op_g = self.get_train_op_(self.d_loss, self.g_loss) self.saver = tf.train.Saver() # if self.if_train: init = tf.global_variables_initializer() self.sess.run(init) else: ckpt = tf.train.get_checkpoint_state(self.model_path) self.saver.restore(self.sess, ckpt.model_checkpoint_path) print(model load done)

推薦閱讀:

機器學習必須熟悉的演算法之word2vector(二)
結合google facets進行機器學習數據可視化
如何從一個「行業小白」走進「機器學習、CV、深度學習等」這個行業,並取得一定的成績?
KDD 2017 參會報告
無需編程,即刻擁有你的機器學習模型

TAG:神經網路 | 機器學習 | 深度學習DeepLearning |