improved-wgan-pytorch's Introduction

Improved Training of Wasserstein GANs in PyTorch

This is a PyTorch implementation of gan_64x64.py from Improved Training of Wasserstein GANs.

To do:

  • Support parameters in cli *
  • Add requirements.txt *
  • Add Dockerfile if possible
  • Multiple GPUs *
  • Clean up code, remove unused code *

* Not yet ready for the conditional GAN

Run

  • Example:

Fresh training:

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --train_dir /path/to/train --validation_dir /path/to/validation/ --output_path /path/to/output/ --dim 64 --saving_step 300 --num_workers 8

Continued training:

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --train_dir /path/to/train --validation_dir /path/to/validation/ --output_path /path/to/output/ --dim 64 --saving_step 300 --num_workers 8 --restore_mode --start_iter 5000

Model

  • train.py: This model is mainly based on the GoodGenerator and GoodDiscriminator of the gan_64x64.py model from Improved Training of Wasserstein GANs. It was trained on the LSUN dataset for around 100k iterations.
  • congan_train.py: ACGAN implementation, trained on 4 classes of the LSUN dataset.

Result

1. WGAN: trained on bedroom dataset (100k iters)

Sample 1 Sample 2

2. ACGAN: trained on 4 classes (100k iters)

  • dining_room: 1
  • bridge: 2
  • restaurant: 3
  • tower: 4
Sample 1 Sample 2

Testing

While implementing this model, we built a test module that compares the outputs of the original model (TensorFlow) and our model (PyTorch) for every layer we implemented. It is available at compare-tensorflow-pytorch.

TensorboardX

Results such as costs and generated images (every 200 iterations) are written to the ./runs folder for TensorBoard.

To display the results in TensorBoard, run: tensorboard --logdir runs

Acknowledgements

improved-wgan-pytorch's People

Contributors

chithangduong, dependabot[bot], elvisyjlin, jalola, nikasa1889

improved-wgan-pytorch's Issues

Generator loss function

During the training of the generator, the gradients are calculated w.r.t. D(G(z)) instead of -D(G(z)).

RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1]) and output[0] has a shape of torch.Size([])

I am trying to use my own dataset (64x64). When running the code I get the following:
`RuntimeError Traceback (most recent call last)
in ()
259 lib.plot.tick()
260
--> 261 train()

3 frames
in train()
155 gen_cost = aD(fake_data)
156 gen_cost = gen_cost.mean()
--> 157 gen_cost.backward(mone)
158 gen_cost = -gen_cost
159

/usr/local/lib/python3.6/dist-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
148 products. Defaults to False.
149 """
--> 150 torch.autograd.backward(self, gradient, retain_graph, create_graph)
151
152 def register_hook(self, hook):

/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
91 grad_tensors = list(grad_tensors)
92
---> 93 grad_tensors = _make_grads(tensors, grad_tensors)
94 if retain_graph is None:
95 retain_graph = create_graph

/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in _make_grads(outputs, grads)
27 + str(grad.shape) + " and output["
28 + str(outputs.index(out)) + "] has a shape of "
---> 29 + str(out.shape) + ".")
30 new_grads.append(grad)
31 elif grad is None:

RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1]) and output[0] has a shape of torch.Size([]).`

Any advice on how to fix this?
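The message says the gradient passed to backward() has shape [1] while gen_cost, after .mean(), is a 0-dim tensor. A minimal sketch of one possible fix follows. It assumes mone was built as a one-element tensor (the traceback only shows its shape), and it uses a hypothetical linear layer in place of the real critic aD.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the critic; the real aD is a convolutional network.
aD = nn.Linear(8, 1)
fake_data = torch.randn(4, 8, requires_grad=True)

gen_cost = aD(fake_data).mean()          # 0-dim tensor, shape torch.Size([])

# Assumed original definition: a one-element tensor, shape torch.Size([1]).
# Its shape differs from gen_cost, which is what the error above reports.
mone_bad = torch.FloatTensor([1]) * -1

# A 0-dim gradient has the same shape as the 0-dim output.
mone = torch.tensor(-1.0)                # shape torch.Size([])
gen_cost.backward(mone)
```

An alternative that avoids passing a gradient at all is to call (-gen_cost).backward(), which yields the same gradients.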

Questions about noise

Hello, thanks very much for your code! I have two questions about noise:
In the code that calculates gen_cost, why does noise require grad?

          noise.requires_grad_(True)

And in the later training code, how is it that

            with torch.no_grad():
                noisev = noise  # totally freeze G, training D

can freeze G?
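As a general PyTorch note rather than a statement about this repository: assigning a tensor to a new name inside torch.no_grad() creates no new tensor, so the assignment on its own blocks nothing. Gradients stay out of G's parameters during the D step when G is run inside no_grad or when its output is detached. A minimal sketch, with hypothetical linear layers standing in for aG and aD:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the generator and critic.
aG = nn.Linear(4, 8)
aD = nn.Linear(8, 1)

noise = torch.randn(2, 4)

# Plain assignment inside no_grad: noisev is the very same tensor object.
with torch.no_grad():
    noisev = noise
same_object = noisev is noise

# Option 1: run the generator inside no_grad, so no graph is recorded.
with torch.no_grad():
    fake_a = aG(noisev)

# Option 2: run the generator normally, then cut the graph with detach().
fake_b = aG(noisev).detach()

# The critic loss reaches aD's parameters but cannot reach aG's.
disc_cost = aD(fake_a).mean() + aD(fake_b).mean()
disc_cost.backward()
```

On the first question, gradients with respect to the generator's parameters do not depend on the input noise having requires_grad set; that flag only matters if the gradient with respect to the noise itself is needed.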

Increasing dimensionality

Hi, I was taking a look at the model and saw several hard-coded 64s, but they represent both batch_size and model dimensionality. Which ones should be changed in order to increase the dimensionality of the GAN?

Model Diverge!

Hi there,
I tried to reproduce the WGAN-GP results with your implementation on ImageNet (64x64) (http://image-net.org/small/download.php), but I got this error:

RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.

So I tried to avoid it by setting:
torch.backends.cudnn.enabled = False
Then it works with batch_size = 32 (64 raised a memory issue).

However, after 2700 iterations the model diverged. Is this related to that change in the code?

System info:
Windows 10
Pytorch 1.4
cuda 10.1, V10.1.243
cudnn 6.14

Missing components

It is impossible to run the script gan_train.py because it references missing classes such as DCGANDiscriminator and missing modules such as models.dcgan.

Also, I suggest adding a config.yaml and using a smaller dataset for the base example (it is very painful to download 42 GB just to verify whether this source code works).

Conditional GAN: NUM_CLASSES param for image completion task, and what's the deal with random noise?

Hi, I am wondering what to use for the NUM_CLASSES variable if I am performing image completion on generic datasets where the class labels have no significance, because the NUM_CLASSES param determines the output of the gen_rand_noise() function. Also, if I am using a conditional GAN on an input image, what is the point of incorporating random input noise? The input to the GAN should be another image, right? I don't think that is highlighted in the code.

mone in generator

Hi,
In the GAN training code, since you have already used gen_cost.backward(mone) for the generator, why are you still doing gen_cost = -gen_cost?

That is, should it not be either

  • just gen_cost.backward(mone)
    or
  • gen_cost.backward() and then gen_cost = gen_cost*-1

Please correct me if I am wrong.
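A small check of what backward(mone) does in PyTorch may help here. Passing -1 as the gradient argument scales every gradient by -1, so it is equivalent to backpropagating -gen_cost. A reassignment such as gen_cost = -gen_cost after backward() has already run cannot change those gradients; it only changes the sign of the value kept in the variable (for example, for logging, which is an assumption about intent). The modules below are hypothetical stand-ins, not the repository's networks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the generator and critic.
aG = nn.Linear(4, 8)
aD = nn.Linear(8, 1)
noise = torch.randn(5, 4)


def generator_grad(use_mone):
    """Return (gradient on aG.weight, final gen_cost value) for one G step."""
    aG.zero_grad()
    gen_cost = aD(aG(noise)).mean()
    if use_mone:
        gen_cost.backward(torch.tensor(-1.0))  # gradients of -gen_cost
        gen_cost = -gen_cost                   # sign flip of the value only
    else:
        gen_cost = -gen_cost                   # negate first ...
        gen_cost.backward()                    # ... then plain backward
    return aG.weight.grad.clone(), gen_cost.item()


grad_mone, value_mone = generator_grad(use_mone=True)
grad_plain, value_plain = generator_grad(use_mone=False)

same_gradients = torch.allclose(grad_mone, grad_plain)
same_value = abs(value_mone - value_plain) < 1e-6
```

Both variants give the generator the gradient of -D(G(z)), and both leave the same number in gen_cost.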

A question about loss_G.backward(Mone)

First, thank you for sharing this work; it has helped me understand WGAN-GP better. I have one small question:

  • What does the Mone in loss_G.backward() mean?

I have only recently learned about WGAN-GP, so I hope the author can help me with this question.

Backward

I noticed that you use only one backward() when training D, while many authors use two. What do you think about it?
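For reference, PyTorch accumulates gradients across backward() calls, so one backward() on a summed loss and separate backward() calls on each term produce the same parameter gradients. A minimal sketch, using a hypothetical linear critic and leaving out the gradient penalty term for brevity:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

aD = nn.Linear(8, 1)            # hypothetical stand-in for the critic
real = torch.randn(4, 8)
fake = torch.randn(4, 8)

# Variant 1: a single backward() on the combined critic loss.
aD.zero_grad()
disc_cost = aD(fake).mean() - aD(real).mean()
disc_cost.backward()
grad_single = aD.weight.grad.clone()

# Variant 2: two backward() calls; gradients add up in .grad.
aD.zero_grad()
aD(fake).mean().backward()
(-aD(real).mean()).backward()
grad_double = aD.weight.grad.clone()

gradients_match = torch.allclose(grad_single, grad_double)
```

The choice between the two is mostly about style and memory: separate calls free each term's graph as soon as that term has been backpropagated.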

about the size of images

Hello, I want to ask whether your code supports 32 * 32 images, and how well training works at that size.

A question about model.py

I don't know what "class DepthToSpace(nn.Module):" does. Can you give me a brief introduction? Thank you!
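In general, a depth-to-space operation moves values from the channel dimension into spatial blocks, turning a [N, C, H, W] tensor into [N, C / b², H·b, W·b]; it is a common way to upsample. The function below is an illustrative sketch of that idea, not the repository's DepthToSpace class, and the channel ordering it assumes may differ from other implementations.

```python
import torch


def depth_to_space(x, block_size):
    """Rearrange channel blocks into space (illustrative sketch).

    Input  shape: [N, C, H, W], with C divisible by block_size ** 2.
    Output shape: [N, C // block_size**2, H * block_size, W * block_size].
    """
    n, c, h, w = x.shape
    b = block_size
    if c % (b * b) != 0:
        raise ValueError("channels must be divisible by block_size ** 2")
    out_c = c // (b * b)
    # Split channels into (row offset, column offset, output channel).
    x = x.reshape(n, b, b, out_c, h, w)
    # Interleave the offsets with the spatial axes: N, out_c, H, b, W, b.
    x = x.permute(0, 3, 4, 1, 5, 2)
    return x.reshape(n, out_c, h * b, w * b)


# Four channels holding 0, 1, 2, 3 at a single pixel become one 2x2 block.
pixel = torch.arange(4.0).reshape(1, 4, 1, 1)
block = depth_to_space(pixel, block_size=2)

# A larger example, to show the shape change only.
features = torch.randn(2, 8, 3, 5)
upsampled = depth_to_space(features, block_size=2)
```

No values are created or lost; the output holds exactly the input elements in a new layout.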

lsun data issue

lmdb.Error: /datasets/lsun\dining_room_train_lmdb: The system cannot find the path specified.

Regarding conditioning in the generator

Dear jalola,

Thank you for your enlightening repository on conditional WGAN-GP. I am particularly interested in the conditional part, to generate better images than the unconditional version. In other repositories, people concatenate a label embedding with a noise vector of the same size. For example, if the noise vector is 640-dimensional, we end up with a [batch size, 1280] matrix after concatenating the embedding and the noise vector, and then pass it through the network. However, I noticed that you replace the first c columns with the class labels, where c is the number of classes. Is there any theoretical justification or reference for that? Thanks in advance.
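To make the two conditioning schemes in this question concrete, here is a minimal sketch of both. The sizes are arbitrary, the embedding layer is hypothetical, and writing the labels as one-hot vectors into the first c columns is an assumption about what "class labels" means here, not a description of the repository's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, noise_dim, num_classes = 4, 16, 3
labels = torch.tensor([0, 2, 1, 2])

# Scheme A: concatenate a learned label embedding with the noise vector.
embed = nn.Embedding(num_classes, noise_dim)   # hypothetical embedding layer
noise_a = torch.randn(batch_size, noise_dim)
z_concat = torch.cat([noise_a, embed(labels)], dim=1)   # [batch, 2 * noise_dim]

# Scheme B: overwrite the first num_classes columns of the noise
# with a one-hot encoding of the label (assumed encoding).
z_overwrite = torch.randn(batch_size, noise_dim)
z_overwrite[:, :num_classes] = F.one_hot(labels, num_classes).float()
```

Scheme A doubles the input width and adds trainable parameters; scheme B keeps the input width fixed and gives up num_classes noise dimensions.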

One and Mone

Sorry to bother you; I am new to this, so I have many questions. Why do we use gen_cost.backward(mone) rather than gen_cost.backward(one)? I think in WGAN it should be gen_cost.backward(one), as in the code linked below. I'm not sure, because I am also learning from other code, but I cannot understand this part. Is the code in https://github.com/NUS-Tim/Pytorch-WGAN/tree/master/models right? I think that in the papers WGAN uses real loss - fake loss while WGAN-GP uses fake loss - real loss for D, but in the code above the loss is the same. Does that mean there is something wrong with the code?

tflibtorch package

Hi, I was going over the code in congan_train and came across line 11 (import tflibtorch as lib). Just wondering where this library is?
