使用 Tensorflow 構建生成式對抗網路(GAN)

近些年來,隨著生成模型的作用越來越大。我們可以使用生成模型做一些無中生有的事情,比如風格顏色填充、圖片高清化、圖片補全修復等等。本文要講的就是如何使用Tensorflow實現對抗生成網路。

1.作用

從細節上來看,生成模型可以做一些無中生有的事情,比如圖片高清化,智能填充(圖片被遮住一部分,修復完整),使用輪廓渲染栩栩如生的圖片等

2.發展限制

在最大似然估計及相關策略上,很多概率計算的模擬非常困難

將piecewise linear units(線性分段單元)用在生成模型上比較困難

環境安裝

安裝tensorflow

pip install tensorflow

或者使用更快的gpu版

pip install tensorflow-gpu

注意:首先你要知道你自己的顯卡有沒有gpu,不是n卡的同學可以直接放棄了,然後還需要安裝nvidia的CUDA Toolkit 和 cudnn。

安裝科學計算庫numpy,和作圖庫matplotlib

pip install numpypip install matplotlib

強烈推薦使用Anaconda作為python的集成環境,可以讓你在環境配置上面少走很多彎路。推薦使用wheel直接安裝。

生成式對抗網路(Generative Adversarial Net)

  • 對抗生成網路基本原理
  • 對抗生成網路基本流程
  • 對抗生成網路代碼實現

基本原理

假設我們有一個生成模型G(generator),他的目的是生成一張盡量真實的狗的圖片。於此同時我們有一個圖像判別的模型D(discriminator),他的目的是區分這張圖片中有沒有真實的狗。不斷的調整生成模型G和分類模型D,一直到分類模型不能將生成的圖片和真實狗的圖片區分出來為止,那麼我們就可以認為,生成模型生成了一張真實的狗的圖片。

我們的判別模型目標優化函數:

我們的生成模型目標優化函數:

基本流程

我們的GAN網路需要設計的結構如圖所示:

Python代碼

GAN 主要由兩個部分構成:generator 和 discriminator

  1. generator 主要作用是從訓練數據中產生相同分布的samples
  2. discriminator 則還是採用傳統的監督學習的方法

最終達到數學層面上的"納什均衡",全局最優

import matplotlib.pyplot as pltimport numpy as npimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

這一步我們引入了庫和mnist手寫數字數據。假如你還不了解mnist,可以點擊這裡

mnist image

mnist label

1.初始化訓練參數

num_steps = 1000 # 迭代次數batch_size = 128 # 批大小learning_rate = 0.0002 # 學習率image_dim = 784 # 28 * 28gen_hidden_dim = 256disc_hidden_dim = 256noise_dim = 100 # Glorot Initializationdef glorot_init(shape): return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))

# 權重與偏移weights = { gen_hidden1: tf.Variable(glorot_init([noise_dim, gen_hidden_dim])), gen_out: tf.Variable(glorot_init([gen_hidden_dim, image_dim])), disc_hidden1: tf.Variable(glorot_init([image_dim, disc_hidden_dim])), disc_out: tf.Variable(glorot_init([disc_hidden_dim, 1])),}biases = { gen_hidden1: tf.Variable(tf.zeros([gen_hidden_dim])), gen_out: tf.Variable(tf.zeros([image_dim])), disc_hidden1: tf.Variable(tf.zeros([disc_hidden_dim])), disc_out: tf.Variable(tf.zeros([1])),}

設計權重向量(weight)和偏移(biases)

# Generator 生成網路def generator(x): hidden_layer = tf.matmul(x, weights[gen_hidden1]) hidden_layer = tf.add(hidden_layer, biases[gen_hidden1]) hidden_layer = tf.nn.relu(hidden_layer) out_layer = tf.matmul(hidden_layer, weights[gen_out]) out_layer = tf.add(out_layer, biases[gen_out]) out_layer = tf.nn.sigmoid(out_layer) return out_layer# Discriminator 辨別網路def discriminator(x): hidden_layer = tf.matmul(x, weights[disc_hidden1]) hidden_layer = tf.add(hidden_layer, biases[disc_hidden1]) hidden_layer = tf.nn.relu(hidden_layer) out_layer = tf.matmul(hidden_layer, weights[disc_out]) out_layer = tf.add(out_layer, biases[disc_out]) out_layer = tf.nn.sigmoid(out_layer) return out_layer# Network Inputs gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name=input_noise)disc_input = tf.placeholder(tf.float32, shape=[None, image_dim], name=disc_input)

設計生成模型(generator)和分類模型(discriminator),並且聲明輸入數據

# Build Generator Networkgen_sample = generator(gen_input)# Build 2 Discriminator Networks (one from noise input, one from generated samples)disc_real = discriminator(disc_input)disc_fake = discriminator(gen_sample)# Build Lossgen_loss = -tf.reduce_mean(tf.log(disc_fake))disc_loss = -tf.reduce_mean(tf.log(disc_real) + tf.log(1. - disc_fake))# 因為tensorflow梯度下降過程默認更新所有,所以我們手動設置參數gen_vars = [weights[gen_hidden1], weights[gen_out], biases[gen_hidden1], biases[gen_out]]disc_vars = [weights[disc_hidden1], weights[disc_out], biases[disc_hidden1], biases[disc_out]]# 創建優化器optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate)optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate)# 創建梯度下降train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)# Initialize the variablesinit = tf.global_variables_initializer()

建立判別網路,建立兩個生成網路,計算loss,梯度下降,初始化變數。

sess = tf.Session()sess.run(init)# Trainfor i in range(1, num_steps+1) batch_x, _ = mnist.train.next_batch(batch_size) # 隨機生成-1到1的浮點數 z = np.random.uniform(-1., 1., size=[batch_size, noise_dim]) # 訓練數據為樣本數據和隨機初始化數據(用於生成網路) feed_dict = {disc_input: batch_x, gen_input: z} _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss], feed_dict=feed_dict) if i % 2000 == 0 or i == 1: print(Step %i: Generator Loss: %f, Discriminator Loss: %f % (i, gl, dl))

訓練模型

n = 6canvas = np.empty((28 * n, 28 * n))for i in range(n): z = np.random.uniform(-1., 1., size=[n, noise_dim]) g = sess.run(gen_sample, feed_dict={gen_input: z}) # 反轉顏色顯示 g = -1 * (g - 1) for j in range(n): # 在 matplotlib 中畫出結果 canvas[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = g[j].reshape([28, 28])plt.figure(figsize=(n, n))plt.imshow(canvas, origin="upper", cmap="gray")plt.show()

設置隨機輸入,反轉顏色,生成測試圖

上傳一張效果圖:


推薦閱讀:

感知機(PLA)
機器學習基石筆記4:機器學習可行性論證 上
如何測量這個世界的混亂-1-定義混亂
機器學習基本套路
對比了 18000 個 Python 項目,這 TOP45 值得學習!

TAG:機器學習 | 生成對抗網路GAN | AI技術 |