TextGAN 代碼閱讀筆記
目的是從0生成和 訓練輸入文本 近似的文本
源碼位置: guotong1988/TextGAN
D和G主要部分都是tensorflow.contrib.legacy_seq2seq.rnn_decoder
gen_model.py里loss的定義:
self.outputs = self.generate()self.loss = 1. - tf.reduce_mean(disc_model.discriminate_wv(self.outputs))
可見隨著梯度下降,會使 生成的文本 越來越接近 真實文本
disc_model.py里loss的定義:
predicted_classes_text = self.discriminate_text(self.input_data_text)self.loss_text = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=predicted_classes_text, labels=np.ones((self.args.batch_size, 1), dtype=np.float32)))generated_wv = gen_model.generate()predicted_classes_wv = self.discriminate_wv(generated_wv)self.loss_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=predicted_classes_wv, labels=np.zeros((self.args.batch_size, 1), dtype=np.float32)))self.loss = .5 * self.loss_gen + .5 * self.loss_text
可見隨著梯度下降,會使 D認為的 生成的文本 越來越接近 假數據,會使 D認為的 真實文本 越來越接近 真數據
由G和D的loss定義可見其對抗性
train.py里
G的feed的數據:
gen_model_latent_state = np.random.uniform(-1., 1., (args.batch_size, args.latent_size)).astype(float32)gen_feed = {disc_model.input_data_text: np.zeros_like(x), gen_model.input_data: x, gen_model.latent_state: gen_model_latent_state}
其中x是真實文本
D的feed的數據:
disc_feed = {disc_model.input_data_text: x, gen_model.input_data: x, gen_model.latent_state: gen_model_latent_state}
StandardModel的目的是regularize the embedding
輸入是標準的x和shift一詞的x,雖然target是shift一詞的x,也是生成它本身的目的
推薦閱讀:
※016【NLP】word2vec新手項目
※Learning to Skim Text 閱讀筆記
※Joint Extraction of Entities and RelationsBased on a Novel Tagging Scheme
※關於語音交互的不得不知的技術知識
※為何讀不懂你的那個TA
TAG:自然語言處理 |