
The code for the paper "Long Text Generation via Adversarial Training with Leaked Information" (AAAI 2018). Text generation using GANs and hierarchical reinforcement learning.

Home Page: https://arxiv.org/abs/1709.08624

Python 100.00%
natural-language-processing generative-adversarial-network hierarchical-reinforcement-learning reinforcement-learning text-generation

leakgan's Introduction

LeakGAN

The code for the research paper "Long Text Generation via Adversarial Training with Leaked Information".

This paper has been accepted at the Thirty-Second AAAI Conference on Artificial Intelligence (AAAI-18).

Requirements

  • Tensorflow r1.2.1
  • Python 2.7
  • CUDA 7.5+ (For GPU)

Introduction

Automatically generating coherent and semantically meaningful text has many applications in machine translation, dialogue systems, image captioning, etc. Recently, by combining with policy gradient, Generative Adversarial Nets (GANs) that use a discriminative model to guide the training of the generative model as a reinforcement learning policy have shown promising results in text generation. However, the scalar guiding signal is only available after the entire text has been generated and lacks intermediate information about text structure during the generative process. This limits its success when the generated text samples are long (more than 20 words). In this project, we propose a new framework, called LeakGAN, to address this problem for long text generation. We allow the discriminative net to leak its own high-level extracted features to the generative net to further help the guidance. The generator incorporates such informative signals into all generation steps through an additional Manager module, which takes the extracted features of the currently generated words and outputs a latent vector to guide the Worker module for next-word generation. Our extensive experiments on synthetic data and various real-world tasks with a Turing test demonstrate that LeakGAN is highly effective in long text generation and also improves performance in short text generation scenarios. More importantly, without any supervision, LeakGAN is able to implicitly learn sentence structures purely through the interaction between Manager and Worker.

As illustrated in the LeakGAN figure, we introduce a hierarchical generator G, which consists of a high-level MANAGER module and a low-level WORKER module. The MANAGER is a long short-term memory network (LSTM) and serves as a mediator. In each step, it receives the discriminator D's high-level feature representation, e.g., the feature map of the CNN, and uses it to form the guiding goal for the WORKER module in that timestep. Since this information is internal to D, and in an adversarial game D is not supposed to provide G with such information, we call it a leakage of information from D.

Next, given the goal embedding produced by the MANAGER, the WORKER first encodes the currently generated words with another LSTM, then combines the output of the LSTM and the goal embedding to take a final action at the current state. As such, the guiding signals from D are available to G not only at the end, in the form of scalar reward signals, but also during the generation process, in the form of a goal embedding vector that guides G on how to improve.
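To make this concrete, here is a minimal numpy sketch of one MANAGER/WORKER generation step; the shapes, weight matrices, and the dot-product combination of worker logits with the goal are illustrative assumptions, not the repository's exact TensorFlow graph.

import numpy as np

def generation_step(leaked_feature, manager_state, worker_state,
                    W_goal, W_worker, vocab_size):
    # MANAGER: turn the leaked feature f_t from D into a unit goal vector g_t
    manager_state = np.tanh(np.dot(leaked_feature, W_goal) + manager_state)
    goal = manager_state / np.linalg.norm(manager_state)
    # WORKER: produce per-word logit vectors and modulate them with the goal
    worker_logits = np.dot(worker_state, W_worker).reshape(vocab_size, goal.size)
    logits = np.dot(worker_logits, goal)
    # sample the next word x_{t+1} from the resulting softmax distribution
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    next_word = np.random.choice(vocab_size, p=probs)
    return next_word, manager_state

# toy usage with made-up dimensions: feature size 6, hidden size 4, vocab 10
rng = np.random.RandomState(0)
word, state = generation_step(rng.randn(6), np.zeros(4), rng.randn(4),
                              rng.randn(6, 4), rng.randn(4, 10 * 4), 10)
print(word)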

Reference

@article{guo2017long,
  title={Long Text Generation via Adversarial Training with Leaked Information},
  author={Guo, Jiaxian and Lu, Sidi and Cai, Han and Zhang, Weinan and Yu, Yong and Wang, Jun},
  journal={arXiv preprint arXiv:1709.08624},
  year={2017}
}

You can find the code and run the experiments in the following folders.

Folders

Synthetic Data: the synthetic data experiment

Image COCO: a real-text experiment for our model, using the Image COCO dataset (http://cocodataset.org/#download)

Note: this code is based on the previous work by LantaoYu. Many thanks to LantaoYu.

leakgan's People

Contributors

cr-gjx, wnzhang


leakgan's Issues

An issue about your code

I am a little confused by your synthetic data code with sequence length 40.
The following code may replace the positive samples (real data) with LSTM-generated ones.
generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file, 0)
Since the oracle LSTM does not reload its parameters from the pkl file, how can you treat samples generated by a randomly initialized LSTM as positive samples?
Many thanks!

What is the latent space in this GAN model?

GAN models usually generate samples from a latent space. This enables the GAN to generate distinct samples, and to generate samples conditioned on some information. However, in your paper and your code, I cannot find a latent-space input. In the "Generation Process" subsection of the paper, you state that the Manager and Worker both start from a zero hidden state. I also found that in your code the initial g (goal vector) is set to a trainable tensor, which means it is fixed during evaluation or within a single training batch. Moreover, the initial input, say x_0, is fixed. These are the usual latent-space inputs in RNN GANs, yet none of them is used as a latent vector input in your model, and this really confuses me.
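For reference, a sketch of the convention the question refers to (an assumption about generic RNN GANs, not this repository's design): sample a latent vector z once per sequence and derive the generator's initial recurrent state from it.

import numpy as np

def init_generator_state(batch_size, hidden_dim, rng=np.random):
    z = rng.normal(size=(batch_size, hidden_dim))  # one latent vector per sample
    h0 = np.tanh(z)          # initial hidden state derived from z
    c0 = np.zeros_like(z)    # LSTM cell state
    return h0, c0

h0, c0 = init_generator_state(64, 128)
print(h0.shape, c0.shape)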

Rewards don't change while training

When I run the code, I try to print the mean value of the rewards. Strangely, the mean of the rewards didn't change during training. The code snippet I used is below; I just added a print under line 285:

samples = leakgan.generate(sess, 1.0, 1)
rewards = get_reward(leakgan, discriminator, sess, samples, 4, dis_dropout_keep_prob)
print('rewards: ', np.mean(rewards))
feed = {leakgan.x: samples, leakgan.reward: rewards, leakgan.drop_out: 1.0}
_, _, g_loss, w_loss = sess.run([leakgan.manager_updates, leakgan.worker_updates, leakgan.goal_loss, leakgan.worker_loss], feed_dict=feed)
print('total_batch: ', total_batch, "  ", g_loss, "  ", w_loss)

The output is here:

(64, ?, 1720)
(?, ?, 1720)
2018-08-28 18:00:47.803815: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-08-28 18:00:47.955611: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1356] Found device 0 with properties: 
name: GeForce GTX 1080 major: 6 minor: 1 memoryClockRate(GHz): 1.8225
pciBusID: 0000:03:00.0
totalMemory: 7.92GiB freeMemory: 5.43GiB
2018-08-28 18:00:47.955644: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1435] Adding visible gpu devices: 0
2018-08-28 18:00:48.257207: I tensorflow/core/common_runtime/gpu/gpu_device.cc:923] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-08-28 18:00:48.257245: I tensorflow/core/common_runtime/gpu/gpu_device.cc:929]      0 
2018-08-28 18:00:48.257253: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 0:   N 
2018-08-28 18:00:48.257465: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 4057 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080, pci bus id: 0000:03:00.0, compute capability: 6.1)
[[1395 2108 1587 ... 4713 4964  369]
 [3043 2382 2235 ... 1873   40 3757]
 [1811  411 4354 ...  670  492 3540]
 ...
 [4757 2083 4780 ... 2464 1251 1335]
 [ 571  679 2516 ... 3131 1198 2000]
 [1581  985  414 ... 3967 1530  983]]
('epoch:', 0, '  ')
ERROR:tensorflow:Couldn't match files for checkpoint ./ckpts/leakgan_pre
None
Start pre-training discriminator...
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 9.732324)
('Groud-Truth:', 5.7501173)
('pre-train epoch ', 5, 'test_loss ', 9.205654)
('Groud-Truth:', 5.751481)
[... eight further pre-training rounds elided; test_loss moves between roughly 8.07 and 8.67 while the ground-truth loss stays near 5.75 ...]
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.370322)
('Groud-Truth:', 5.7427726)
('pre-train epoch ', 5, 'test_loss ', 8.384528)
('Groud-Truth:', 5.7542734)
#########################################################################
Start Adversarial Training...
('rewards: ', 0.12695181503855285)
('total_batch: ', 0, '  ', -0.08002976, '  ', 2.572021)
('total_batch: ', 0, 'test_loss: ', 8.316683)
('Groud-Truth:', 5.7433925)
('rewards: ', 0.12695181503855285)
('total_batch: ', 1, '  ', -0.078070164, '  ', 2.4859025)
('rewards: ', 0.12695181503855285)
('total_batch: ', 2, '  ', -0.076734476, '  ', 2.4410596)
('rewards: ', 0.12695181503855285)
('total_batch: ', 3, '  ', -0.07695893, '  ', 2.448696)
('rewards: ', 0.12695181503855285)
('total_batch: ', 4, '  ', -0.075939886, '  ', 2.4610584)
('rewards: ', 0.12695181503855285)
('total_batch: ', 5, '  ', -0.07431904, '  ', 2.3714132)
('total_batch: ', 5, 'test_loss: ', 8.068953)
('Groud-Truth:', 5.7565346)
[... batches 6 through 114 elided; the printed reward is identical in every batch (0.12695181503855285, occasionally 0.12695181503855288 in the last digit) while g_loss, w_loss, and the periodic test_loss keep changing ...]
('rewards: ', 0.12695181503855285)
('total_batch: ', 115, '  ', -0.06892277, '  ', 1.5064849)
('total_batch: ', 115, 'test_loss: ', 7.555312)
('Groud-Truth:', 5.7560487)

Custom dataset

If we want to run the model on a custom dataset, how should we do this?

The code could not run

/home/lishaomei/anaconda2/bin/python "/home/lishaomei/LeakGAN-master/Synthetic Data/Main.py"
20
WARNING:tensorflow:From /home/lishaomei/LeakGAN-master/Synthetic Data/Discriminator.py:92: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See tf.nn.softmax_cross_entropy_with_logits_v2.

(64, ?, 1720)
(?, ?, 1720)
2018-12-27 16:49:16.358266: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
2018-12-27 16:49:16.378827: I tensorflow/core/common_runtime/process_util.cc:69] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
2018-12-27 16:49:33.167244: W tensorflow/core/framework/op_kernel.cc:1273] OP_REQUIRES failed at mkl_concat_op.cc:814 : Aborted: Operation received an exception:Status: 5, message: could not create a concat primitive descriptor, in file tensorflow/core/kernels/mkl_concat_op.cc:811
Traceback (most recent call last):
File "/home/lishaomei/LeakGAN-master/Synthetic Data/Main.py", line 324, in
main()
File "/home/lishaomei/LeakGAN-master/Synthetic Data/Main.py", line 189, in main
g = sess.run(leakgan.gen_x,feed_dict={leakgan.drop_out:0.8,leakgan.train:1})
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 929, in run
run_metadata_ptr)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1152, in _run
feed_dict_tensor, options, run_metadata)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1328, in _do_run
run_metadata)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1348, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.AbortedError: Operation received an exception:Status: 5, message: could not create a concat primitive descriptor, in file tensorflow/core/kernels/mkl_concat_op.cc:811
[[node while/feature/FeatureExtractor/concat (defined at /home/lishaomei/LeakGAN-master/Synthetic Data/Discriminator.py:146) = _MklConcatV2[N=12, T=DT_FLOAT, Tidx=DT_INT32, _kernel="MklOp", _device="/job:localhost/replica:0/task:0/device:CPU:0"](while/feature/FeatureExtractor/conv-maxpool-1/pool-1, while/feature/FeatureExtractor/conv-maxpool-2/pool-2, while/feature/FeatureExtractor/conv-maxpool-3/pool-3, while/feature/FeatureExtractor/conv-maxpool-4/pool-4, while/feature/FeatureExtractor/conv-maxpool-5/pool-5, while/feature/FeatureExtractor/conv-maxpool-6/pool-6, while/feature/FeatureExtractor/conv-maxpool-7/pool-7, while/feature/FeatureExtractor/conv-maxpool-8/pool-8, while/feature/FeatureExtractor/conv-maxpool-9/pool-9, while/feature/FeatureExtractor/conv-maxpool-10/pool-10, while/feature/FeatureExtractor/conv-maxpool-15/pool-15, while/feature/FeatureExtractor/conv-maxpool-20/pool-20, while/feature/FeatureExtractor/concat/axis, while/feature/FeatureExtractor/conv-maxpool-1/pool-1:2, while/feature/FeatureExtractor/conv-maxpool-2/pool-2:2, while/feature/FeatureExtractor/conv-maxpool-3/pool-3:2, while/feature/FeatureExtractor/conv-maxpool-4/pool-4:2, while/feature/FeatureExtractor/conv-maxpool-5/pool-5:2, while/feature/FeatureExtractor/conv-maxpool-6/pool-6:2, while/feature/FeatureExtractor/conv-maxpool-7/pool-7:2, while/feature/FeatureExtractor/conv-maxpool-8/pool-8:2, while/feature/FeatureExtractor/conv-maxpool-9/pool-9:2, while/feature/FeatureExtractor/conv-maxpool-10/pool-10:2, while/feature/FeatureExtractor/conv-maxpool-15/pool-15:2, while/feature/FeatureExtractor/conv-maxpool-20/pool-20:2, DMT/_51)]]

Caused by op u'while/feature/FeatureExtractor/concat', defined at:
File "/home/lishaomei/LeakGAN-master/Synthetic Data/Main.py", line 324, in
main()
File "/home/lishaomei/LeakGAN-master/Synthetic Data/Main.py", line 177, in main
learning_rate=LEARNING_RATE)
File "/home/lishaomei/LeakGAN-master/Synthetic Data/LeakGANModel.py", line 144, in init
gen_o, gen_x,goal,tf.zeros([self.batch_size,self.goal_out_size]),self.goal_init,step_size,gen_real_goal_array,gen_o_worker_array),parallel_iterations=1)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3291, in while_loop
return_same_structure)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3004, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2939, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "/home/lishaomei/LeakGAN-master/Synthetic Data/LeakGANModel.py", line 104, in _g_recurrence
feature = self.FeatureExtractor_unit(cur_sen,self.drop_out)
File "/home/lishaomei/LeakGAN-master/Synthetic Data/Discriminator.py", line 146, in unit
h_pool = tf.concat(pooled_outputs, 3)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1124, in concat
return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1033, in concat_v2
"ConcatV2", values=values, axis=axis, name=name)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
op_def=op_def)
File "/home/lishaomei/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1770, in init
self._traceback = tf_stack.extract_stack()

AbortedError (see above for traceback): Operation received an exception:Status: 5, message: could not create a concat primitive descriptor, in file tensorflow/core/kernels/mkl_concat_op.cc:811
[[node while/feature/FeatureExtractor/concat (defined at /home/lishaomei/LeakGAN-master/Synthetic Data/Discriminator.py:146) = _MklConcatV2[N=12, T=DT_FLOAT, Tidx=DT_INT32, _kernel="MklOp", _device="/job:localhost/replica:0/task:0/device:CPU:0"](while/feature/FeatureExtractor/conv-maxpool-1/pool-1, while/feature/FeatureExtractor/conv-maxpool-2/pool-2, while/feature/FeatureExtractor/conv-maxpool-3/pool-3, while/feature/FeatureExtractor/conv-maxpool-4/pool-4, while/feature/FeatureExtractor/conv-maxpool-5/pool-5, while/feature/FeatureExtractor/conv-maxpool-6/pool-6, while/feature/FeatureExtractor/conv-maxpool-7/pool-7, while/feature/FeatureExtractor/conv-maxpool-8/pool-8, while/feature/FeatureExtractor/conv-maxpool-9/pool-9, while/feature/FeatureExtractor/conv-maxpool-10/pool-10, while/feature/FeatureExtractor/conv-maxpool-15/pool-15, while/feature/FeatureExtractor/conv-maxpool-20/pool-20, while/feature/FeatureExtractor/concat/axis, while/feature/FeatureExtractor/conv-maxpool-1/pool-1:2, while/feature/FeatureExtractor/conv-maxpool-2/pool-2:2, while/feature/FeatureExtractor/conv-maxpool-3/pool-3:2, while/feature/FeatureExtractor/conv-maxpool-4/pool-4:2, while/feature/FeatureExtractor/conv-maxpool-5/pool-5:2, while/feature/FeatureExtractor/conv-maxpool-6/pool-6:2, while/feature/FeatureExtractor/conv-maxpool-7/pool-7:2, while/feature/FeatureExtractor/conv-maxpool-8/pool-8:2, while/feature/FeatureExtractor/conv-maxpool-9/pool-9:2, while/feature/FeatureExtractor/conv-maxpool-10/pool-10:2, while/feature/FeatureExtractor/conv-maxpool-15/pool-15:2, while/feature/FeatureExtractor/conv-maxpool-20/pool-20:2, DMT/_51)]]

Process finished with exit code 1

leakgan----generator

What does sub_goal mean in the _g_recurrence function in the LeakGANModel.py file?

Get “ValueError: No gradients provided for any variable” when running Image_COCO.Main

Traceback (most recent call last):
File "/Users/zhouye/Documents/dtwave/LeakGAN/run.py", line 14, in
main()
File "/Users/zhouye/Documents/dtwave/LeakGAN/Image_COCO/Main.py", line 161, in main
batch_size=BATCH_SIZE,hidden_dim=HIDDEN_DIM,start_token=START_TOKEN,goal_out_size=GOAL_OUT_SIZE,goal_size=GOAL_SIZE,step_size=4,D_model=discriminator)
File "/Users/zhouye/Documents/dtwave/LeakGAN/Image_COCO/LeakGANModel.py", line 284, in init
self.pretrain_manager_updates = pretrain_manager_opt.apply_gradients(pretrain_manager_grads_vars)
File "/Users/zhouye/miniconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 593, in apply_gradients
([str(v) for _, v, _ in converted_grads_and_vars],))
ValueError: No gradients provided for any variable: ["<tf.Variable 'Manager/Manager/Variable:0' shape=(1880, 128) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_1:0' shape=(128, 128) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_2:0' shape=(128,) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_3:0' shape=(1880, 128) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_4:0' shape=(128, 128) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_5:0' shape=(128,) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_6:0' shape=(1880, 128) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_7:0' shape=(128, 128) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_8:0' shape=(128,) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_9:0' shape=(1880, 128) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_10:0' shape=(128, 128) dtype=float32_ref>", "<tf.Variable 'Manager/Manager/Variable_11:0' shape=(128,) dtype=float32_ref>", "<tf.Variable 'Manager/Manager_1/Variable:0' shape=(128, 1880) dtype=float32_ref>", "<tf.Variable 'Manager/Manager_1/Variable_1:0' shape=(1880,) dtype=float32_ref>", "<tf.Variable 'Manager/goal_init:0' shape=(64, 1880) dtype=float32_ref>"].

Python 3.6 + TensorFlow 1.13.

Many thanks.

Modulo by Zero error when attempting to train on custom dataset

I've been trying to get your LeakGAN to work for some time now, but I just don't understand what format your realtrain_cotra file is in or what kind of pre-encoding you've done to it. I have a text dataset, but replacing realtrain_cotra doesn't work. I thought it needed to be one-hot encoded, but that didn't work. I thought maybe it just needed to be label encoded, but that doesn't work either. I converted the label-encoded integer array into a simple string to try to closely match the kind of data you have in realtrain_cotra, but then it gives me a divide-by-zero error.

epoch: 0
Traceback (most recent call last):
File "Main.py", line 285, in
main()
File "Main.py", line 177, in main
gen_data_loader.create_batches(positive_file)
File "/content/LeakGAN/Image COCO/dataloader.py", line 22, in create_batches
self.sequence_batch = np.split(np.array(self.token_stream), self.num_batch, 0)
File "/usr/local/lib/python2.7/dist-packages/numpy/lib/shape_base.py", line 847, in split
if N % sections:
ZeroDivisionError: integer division or modulo by zero

What in the world do you do to get this system working on custom data?
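For anyone hitting the same wall, a hedged sketch of the encoding I believe realtrain_cotra uses: each line is exactly SEQ_LENGTH space-separated integer token ids, and the dataloader appears to drop lines of any other length, which can leave zero batches and trigger the modulo-by-zero above. The helper and toy vocabulary below are my assumptions for illustration.

SEQ_LENGTH = 32  # must match the sequence length hard-coded in Main.py

def encode_corpus(sentences, word2id, pad_id=0, eos_id=1):
    # encode each sentence as exactly SEQ_LENGTH space-separated token ids
    lines = []
    for sent in sentences:
        ids = [word2id[w] for w in sent.split() if w in word2id]
        ids = ids[:SEQ_LENGTH - 1] + [eos_id]      # truncate and end with EOS
        ids += [pad_id] * (SEQ_LENGTH - len(ids))  # pad to the fixed length
        lines.append(' '.join(str(i) for i in ids))
    return lines

# toy usage: write a file shaped like realtrain_cotra.txt
word2id = {'a': 2, 'cat': 3, 'sat': 4}
with open('realtrain_cotra.txt', 'w') as f:
    f.write('\n'.join(encode_corpus(['a cat sat'], word2id)))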

Where is the code for the mentioned "Long Text Generation: EMNLP2017 WMT News" experiment?

Hello @CR-Gjx @wnzhang. First of all, really great work with the LeakGAN framework. The paper is very well organised and indeed very insightful. I have been reading it to understand the concepts of the framework, and I am interested in the Long Text Generation experiment in particular. Unfortunately, I couldn't find the code for this experiment in this repository (the code for the other two experiments is there). Could you point me to where it is? Thank you!

Where is the real data input?

Where is the real data input? generate_samples overwrites real_data.txt no matter what, so the original data is never used. And if the target LSTM is randomly initialized, this file is simply replaced with generated data as well.

GPU Usage

How can we know for sure that the model is benefiting from the GPU? I saw on the terminal that the RTX 2060 device is recognized, but the code has been running for 5 days and is still looping over TOTAL_BATCH, which is 800; it is currently on the 422nd iteration. When I ran nvidia-smi, it showed that python3 is using 81 MiB of GPU memory.
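One generic TF 1.x diagnostic (not something from this repository) is to enable device-placement logging, which prints the device each op is assigned to; GPU-resident ops show up as .../device:GPU:0.

import tensorflow as tf

# Log where every op is placed when the session executes.
config = tf.ConfigProto(log_device_placement=True)
with tf.Session(config=config) as sess:
    a = tf.constant([[1.0, 2.0]])
    b = tf.constant([[3.0], [4.0]])
    print(sess.run(tf.matmul(a, b)))  # the matmul's device is logged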

Improve LeakGAN by Changing Policy Gradient Structure

Hey @CR-Gjx, thanks for providing this open-source code. It is very helpful to study, and I love the idea of hierarchical reinforcement learning.

In the recent AlphaGo Zero paper and the Thinking Fast and Slow paper, they both show that replacing the classic policy gradient with MCTS-guided gradients reaches far better results and is more stable.

In those papers, the RL problems addressed are self-play, but I believe their techniques could be applied to LeakGAN and would improve its performance substantially.

Currently, LeakGAN takes a sequence and calculates a scalar reward. This scalar reward is then used in REINFORCE to improve the Worker; the Manager's objective would stay the same.

Instead of having just one correct target, the papers suggest having multiple correct targets (a distribution). There is a reward for each target within the distribution, so when you compute the cross entropy, you do it over n targets, each with its respective reward.

To generate individual rewards for each target, MCTS is used to improve upon the original policy. They use action = Q(s,a) + U(s,a) to build the decision tree, then use the visit counts to calculate the reward (rather than a value function, because a value function leads to overfitting).

The fundamental difference here is that we are optimizing a distribution rather than a single target. This distribution naturally carries far more information for the Generator to benefit from. I think it would help immensely with mode collapse, which is currently only partially remedied by occasionally training with MLE. Thoughts on this?
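As a concrete illustration of the proposal (my own sketch of the AlphaGo-Zero-style target, not anything in this repository): the training target becomes a distribution derived from MCTS visit counts N(s, a), sharpened by a temperature tau.

import numpy as np

def visit_count_policy(visit_counts, tau=1.0):
    # pi(a) proportional to N(s, a)^(1/tau): the cross-entropy target is a
    # whole distribution over next tokens instead of a single scalar reward
    counts = np.asarray(visit_counts, dtype=np.float64) ** (1.0 / tau)
    return counts / counts.sum()

# toy usage: visit counts for 4 candidate next tokens
print(visit_count_policy([10, 40, 5, 45]))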

In /Image COCO, what is training stopped based on?

I tried to train LeakGAN in /Image COCO, and I saw that both worker_loss and manager_loss are unstable:
total_batch: 440 -0.0632318 0.985172
...
total_batch: 450 -4.78807e-05 3.20244
total_batch: 451 -0.0882713 0.516656
...
total_batch: 455 -0.0578834 1.14354
I'm wondering how to know when to stop training. During training, are g_loss or w_loss meaningful?

And why is g_loss < 0?
Thank you very much.

A few more instructions in the README

Would you mind adding some introductory instructions to your README about how to load new data into the synthetic data experiment for training? I would like to replicate your results and then introduce new data. Thank you for this resource.

I can't get the same results as the paper.

I ran the code at \LeakGAN-master\Image COCO\Main.py and \LeakGAN-master\Image COCO\eval_bleu.py, but I can't get the same results as the paper. I only changed the Python version (3.6) and the tensorflow-gpu version (1.8.0).
This is the result I got:
BLEU2: 0.902
BLEU3: 0.7837
BLEU4: 0.6206
BLEU5: 0.4208
Please tell me how I can modify it. Thanks a lot!

Why CNN instead of LSTM?

There are some recent text-generation GAN papers in which the discriminator is an LSTM rather than a CNN. Why was a CNN used in this paper?

Why are there some 'tail tokens' after the period?

I have trained the code in /Image COCO. During training, I found there are always some 'tail tokens' after the period, and some paddings; even the 451st epoch still has this kind of problem.
Like this sentence: 'A large bottle of wine sitting on a table . sauerkraut ' or
'A bathroom with a tub , sink and mirror .(pad)(pad)(pad)across '
I'm wondering why this happens.
Please help, thank you very much.

Number of batches

How many batches are we supposed to go through?
I started adversarial training and I'm at batch 405 for target length 20. But on your GitHub page, the maximum number of batches is 91. I cloned everything you had and ran it accordingly. Am I doing something wrong, or is this expected?

Thanks

I think the Bootstrapped Rescaled Activation in the code is not useful

def rescale(reward, rollout_num=1.0):
    reward = np.array(reward)
    x, y = reward.shape
    ret = np.zeros((x, y))
    for i in range(x):
        l = reward[i]
        rescalar = {}
        for s in l:
            rescalar[s] = s
        idxx = 1
        min_s = 1.0
        max_s = 0.0
        for s in rescalar:
            rescalar[s] = redistribution(idxx, len(l), min_s)
            idxx += 1
        for j in range(y):
            ret[i, j] = rescalar[reward[i, j]]
    return ret

I read the code, but I didn't see any sorting behavior; it looks more like scaling by insertion order. If this code is meant to rank the rewards, the loop "for s in rescalar" should probably iterate in some sorted order. Please advise, thank you very much.
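If the intent is indeed rank-based rescaling (as the paper's Bootstrapped Rescaled Activation suggests), here is a hedged sketch of iterating over the distinct rewards in sorted order before calling redistribution; the descending order and the stand-in redistribution function are my assumptions, not the authors' fix.

import numpy as np

def redistribution(idx, total, min_v):
    # stand-in for the repo's redistribution(), assumed to map a rank to a score
    return 1.0 - float(idx) / total * (1.0 - min_v)

def rescale_sorted(reward, min_s=0.0):
    reward = np.array(reward)
    x, y = reward.shape
    ret = np.zeros((x, y))
    for i in range(x):
        row = reward[i]
        # rank each distinct reward, highest first, so equal rewards share a rank
        mapping = {}
        for rank, s in enumerate(sorted(set(row), reverse=True), 1):
            mapping[s] = redistribution(rank, len(row), min_s)
        for j in range(y):
            ret[i, j] = mapping[row[j]]
    return ret

print(rescale_sorted([[0.2, 0.8, 0.5, 0.8]]))  # -> [[0.25 0.75 0.5 0.75]]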

A little confused about BLEU

I am a little confused about the BLEU calculation nltk.translate.bleu_score.sentence_bleu(reference, h, weight), in which reference is save/realtest_coco.txt and h is save/generator_sample.txt.
Does it depend somewhat on the size and contents of realtest_coco.txt?
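For context, a minimal sketch of how I read that call (assumption: each line of save/realtest_coco.txt is one tokenized reference sentence, and each line of save/generator_sample.txt is one hypothesis):

from nltk.translate.bleu_score import sentence_bleu

# every line of the real test file serves as a reference for the hypothesis
references = [line.split() for line in open('save/realtest_coco.txt')]
hypothesis = 'a bathroom with a tub , sink and mirror .'.split()

# the weights select the n-gram order; four equal weights gives BLEU-4
print(sentence_bleu(references, hypothesis, weights=(0.25, 0.25, 0.25, 0.25)))

Since each reference can contribute matching n-grams, the score does depend on the size and contents of realtest_coco.txt.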

Also, about this sentence in the paper:

In each step, it receives generator D’s high-level feature representation, e.g., the feature map of the CNN, and uses it to form the guiding goal for the WORKER module in that timestep.

I am not sure whether it should be "generator" or "discriminator".

By the way, it's really a very nice piece of work, so many thanks to you.
