TextGAN 代碼閱讀筆記

目的是從0生成和 訓練輸入文本 近似的文本

源碼位置: guotong1988/TextGAN

D和G主要部分都是tensorflow.contrib.legacy_seq2seq.rnn_decoder

gen_model.py里loss的定義:

self.outputs = self.generate()self.loss = 1. - tf.reduce_mean(disc_model.discriminate_wv(self.outputs))

可見隨著梯度下降,會使 生成的文本 越來越接近 真實文本

disc_model.py里loss的定義:

predicted_classes_text = self.discriminate_text(self.input_data_text)self.loss_text = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=predicted_classes_text, labels=np.ones((self.args.batch_size, 1), dtype=np.float32)))generated_wv = gen_model.generate()predicted_classes_wv = self.discriminate_wv(generated_wv)self.loss_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=predicted_classes_wv, labels=np.zeros((self.args.batch_size, 1), dtype=np.float32)))self.loss = .5 * self.loss_gen + .5 * self.loss_text

可見隨著梯度下降,會使 D認為的 生成的文本 越來越接近 假數據,會使 D認為的 真實文本 越來越接近 真數據

由G和D的loss定義可見其對抗性

train.py里

G的feed的數據:

gen_model_latent_state = np.random.uniform(-1., 1., (args.batch_size, args.latent_size)).astype(float32)gen_feed = {disc_model.input_data_text: np.zeros_like(x), gen_model.input_data: x, gen_model.latent_state: gen_model_latent_state}

其中x是真實文本

D的feed的數據:

disc_feed = {disc_model.input_data_text: x, gen_model.input_data: x, gen_model.latent_state: gen_model_latent_state}

StandardModel的目的是regularize the embedding

輸入是標準的x和shift一詞的x,雖然target是shift一詞的x,也是生成它本身的目的

推薦閱讀:

016【NLP】word2vec新手項目
Learning to Skim Text 閱讀筆記
Joint Extraction of Entities and RelationsBased on a Novel Tagging Scheme
關於語音交互的不得不知的技術知識
為何讀不懂你的那個TA

TAG:自然語言處理 |