vita-group / autogan

[ICCV 2019] "AutoGAN: Neural Architecture Search for Generative Adversarial Networks" by Xinyu Gong, Shiyu Chang, Yifan Jiang and Zhangyang Wang

License: MIT License

Languages: Python 97.30%, Shell 2.70%
Topics: autogan, neural-architecture-search, gan, pytorch, cifar10

autogan's Introduction

AutoGAN: Neural Architecture Search for Generative Adversarial Networks

Code used for AutoGAN: Neural Architecture Search for Generative Adversarial Networks.

Updates

  • Oct-02-2019: Search code is released.

Introduction

We've designed a novel neural architecture search framework for generative adversarial networks (GANs), dubbed AutoGAN. Experiments validate the effectiveness of AutoGAN on unconditional image generation. Specifically, our discovered architectures achieve highly competitive performance on CIFAR-10, obtaining a record FID score of 12.42 and a competitive Inception score of 8.55.

RNN controller: (figure)

Search space: (figure)

Discovered network architecture: (figure)

Performance

Unconditional image generation on CIFAR-10: (results figure)

Unconditional image generation on STL-10: (results figure)

Set-up

Environment requirements:

  • python >= 3.6
  • torch >= 1.1.0

Install the remaining dependencies:

pip install -r requirements.txt

Prepare the FID statistics file:

mkdir fid_stat

Download the pre-calculated statistics to ./fid_stat.
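The statistics file simply stores the mean vector and covariance matrix of Inception features computed over the reference images; the TTUR files keep them under the keys mu and sigma. A minimal sketch for inspecting the downloaded file (assuming the CIFAR-10 file name used by this repo):

import numpy as np

# Inspect a pre-calculated FID statistics file (key names follow the TTUR convention).
stats = np.load('fid_stat/fid_stats_cifar10_train.npz')
mu, sigma = stats['mu'], stats['sigma']
print(mu.shape, sigma.shape)   # typically (2048,) and (2048, 2048)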

How to search & train the derived architecture by yourself

sh exps/autogan_search.sh

When the search algorithm finishes, you will get a vector denoting the discovered architecture, which can be found in the "*.log" file.

To train your discovered architecture from scratch and evaluate its performance, run the following command (replace the architecture vector after "--arch" with your own):

python train_derived.py \
-gen_bs 128 \
-dis_bs 64 \
--dataset cifar10 \
--bottom_width 4 \
--img_size 32 \
--max_iter 50000 \
--gen_model shared_gan \
--dis_model shared_gan \
--latent_dim 128 \
--gf_dim 256 \
--df_dim 128 \
--g_spectral_norm False \
--d_spectral_norm True \
--g_lr 0.0002 \
--d_lr 0.0002 \
--beta1 0.0 \
--beta2 0.9 \
--init_type xavier_uniform \
--n_critic 5 \
--val_freq 20 \
--arch 1 0 1 1 1 0 0 1 1 1 0 1 0 3 \
--exp_name derive

How to train & test the discovered architecture reported in the paper

train

sh exps/autogan_cifar10_a.sh

test

Run the following script:

python test.py \
--dataset cifar10 \
--img_size 32 \
--bottom_width 4 \
--gen_model autogan_cifar10_a \
--latent_dim 128 \
--gf_dim 256 \
--g_spectral_norm False \
--load_path /path/to/*.pth \
--exp_name test_autogan_cifar10_a

Pre-trained models are provided (Google Drive).
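If you want to sample from a pre-trained generator outside of test.py, the hedged sketch below may help. The Generator(args) constructor, the attributes collected in Args, and the avg_gen_state_dict checkpoint key are assumptions inferred from the test command above; fall back to test.py if they do not match your checkpoint.

import torch
from torchvision.utils import save_image

from models.autogan_cifar10_a import Generator

class Args:
    # Mirrors the flags of the test command above; these attribute names are
    # an assumption about what the Generator constructor reads.
    latent_dim = 128
    gf_dim = 256
    bottom_width = 4
    img_size = 32
    g_spectral_norm = False

gen_net = Generator(Args()).eval()
ckpt = torch.load('autogan_cifar10_a.pth', map_location='cpu')
# Assumption: the checkpoint stores the EMA generator weights under this key.
gen_net.load_state_dict(ckpt['avg_gen_state_dict'])

with torch.no_grad():
    z = torch.randn(25, Args.latent_dim)
    samples = gen_net(z)                      # outputs roughly in [-1, 1]
    save_image(samples, 'samples.png', nrow=5, normalize=True)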

Citation

If you find this work useful for your research, please cite our paper:

@InProceedings{Gong_2019_ICCV,
  author    = {Gong, Xinyu and Chang, Shiyu and Jiang, Yifan and Wang, Zhangyang},
  title     = {AutoGAN: Neural Architecture Search for Generative Adversarial Networks},
  booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
  month     = {Oct},
  year      = {2019}
}

Acknowledgement

  1. Inception Score code from OpenAI's Improved GAN (official).
  2. FID code and CIFAR-10 statistics file from https://github.com/bioinf-jku/TTUR (official).

autogan's People

Contributors

gongxinyuu, yifanjiang19


autogan's Issues

Training error

[Epoch 14/15] [Batch 300/782] [D loss: 1.949833] [G loss: 0.263498]
[Epoch 14/15] [Batch 400/782] [D loss: 1.945486] [G loss: -0.067332]
[Epoch 14/15] [Batch 500/782] [D loss: 1.483213] [G loss: 0.168653]
[Epoch 14/15] [Batch 600/782] [D loss: 2.166453] [G loss: 0.740634]
[Epoch 14/15] [Batch 700/782] [D loss: 1.869209] [G loss: 0.684715]
=> train controller...
arch: tensor([0, 0, 2, 0], device='cuda:1')
calculate Inception score...
get Inception score of 1.725852608680725
Traceback (most recent call last):
  File "search.py", line 199, in <module>
    main()
  File "search.py", line 170, in main
    train_controller(args, controller, ctrl_optimizer, gen_net, prev_hiddens, prev_archs, writer_dict)
  File "/home/ht/projects/AutoGAN/functions.py", line 213, in train_controller
    cur_batch_rewards = cur_batch_rewards.unsqueeze(-1) + args.entropy_coeff * entropies  # bs * 1
RuntimeError: expected device cuda:0 but got device cuda:1
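The traceback shows the reward tensor and the entropies living on different GPUs (cuda:0 vs cuda:1). A generic, hedged illustration of the fix, not the repo's code: pin both tensors to one explicit device before combining them.

import torch

# Hedged illustration: keep the reward and entropy tensors on the same device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

cur_batch_rewards = torch.randn(8, device=device)     # placeholder rewards
entropies = torch.randn(8, 1, device=device)          # placeholder entropies
entropy_coeff = 1e-3                                  # placeholder coefficient

cur_batch_rewards = cur_batch_rewards.unsqueeze(-1) + entropy_coeff * entropies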

FID score comparison

Hi, thanks for providing the code. The discovered architecture looks very interesting to experiment with and sheds some insight on how generators could be well designed. In particular, I'm very intrigued by how AutoGAN exploits upsample-convs rather than transposed convs.

To clarify: is the model definition (cifar10_a) at [1] the one that reproduces the best reported AutoGAN results?

Furthermore, could I check whether the FID score you report is computed with 50k real and 50k fake images, since the official FID statistics use 50k examples (all training data) and your config at [2] also uses 50k images by default? The other works you compare against seem to use different numbers of images (e.g. SNGAN uses 10k real, 5k fake; see pg 16, appendix B.1 of [3]). Since the hyperparameters are quite similar to SNGAN's, have you tried computing AutoGAN's FID score with 10k real / 5k fake images for comparison?

For example, I ran the official SNGAN code [4] and also computed FID with 50k real and 50k fake images (using the official code and stats from the TTUR repo), and the FID score I got is around 15 for unconditional SNGAN (compared to 21.7 for the 10k/5k FID score).

Thanks for the help! 👍 👍

[1] https://github.com/TAMU-VITA/AutoGAN/blob/master/models/autogan_cifar10_a.py
[2] https://github.com/TAMU-VITA/AutoGAN/blob/master/cfg.py#L143
[3] https://arxiv.org/pdf/1802.05957.pdf
[4] https://github.com/pfnet-research/sngan_projection
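For reference on the comparison above: given the mean and covariance of Inception features for the real and generated sets, FID is the Fréchet distance between the two Gaussians. A minimal sketch of that formula (following TTUR, on which this repo's FID code is based; not the repo's own implementation):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Matrix square root of the covariance product; disp=False returns (sqrtm, error).
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Numerical stabilization: nudge the diagonals before retrying.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

The number of real/fake images only changes the (mu, sigma) estimates fed into this distance, which is exactly why the 50k/50k and 10k/5k protocols give different numbers.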

pip install is giving errors

This is in a Python 3.6.9 conda environment on Windows:

ERROR: Could not find a version that satisfies the requirement torch==1.1.0 (from -r requirements.txt (line 9)) (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)
ERROR: No matching distribution found for torch==1.1.0 (from -r requirements.txt (line 9))

Can you please fix the version issue?

Trouble with training discovered GANs

I am using Google Colab to train a GAN (after running the search) or the GAN discovered in the paper, following the instructions in 'README.md', but the following error keeps occurring.


Traceback (most recent call last):
  File "train.py", line 167, in <module>
    main()
  File "train.py", line 138, in main
    inception_score, fid_score = validate(args, fixed_z, fid_stat, gen_net, writer_dict)
  File "/content/AutoGAN/functions.py", line 279, in validate
    img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/utils.py", line 75, in make_grid
    norm_range(t, range)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/utils.py", line 71, in norm_range
    norm_ip(t, float(t.min()), float(t.max()))
  File "/usr/local/lib/python3.6/dist-packages/torchvision/utils.py", line 64, in norm_ip
    img.clamp_(min=min, max=max)
RuntimeError: Output 0 of UnbindBackward is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
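This error comes from newer torchvision versions normalizing in place on an autograd view. A hedged workaround sketch (not an official fix from this repo): hand make_grid a detached copy of the samples.

import torch
from torchvision.utils import make_grid

# Hedged workaround: clone the generated batch so make_grid's in-place
# normalization runs on a plain tensor rather than an autograd view.
sample_imgs = torch.randn(25, 3, 32, 32)   # stand-in for gen_net(fixed_z)
img_grid = make_grid(sample_imgs.detach().clone(), nrow=5,
                     normalize=True, scale_each=True)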

pre-cal stat

Hi, may I ask what the pre-calculated statistics are used for? Thanks.

Mistakes in training

Dear Professor,
At the beginning of training (epoch 0), my server reports: OSError: Not a gzipped file (b'<! ').
A Google search suggests the training dataset may be corrupted. Or is there some other reason I don't understand?
Could you please let me know? Thank you very much.
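The b'<!' prefix indicates that an HTML page was saved in place of the CIFAR-10 archive, i.e. the download is indeed corrupted. A hedged sketch of the usual remedy ('./data' is an assumed location; match it to your --data_path):

import os
from torchvision import datasets

# Delete the corrupted archive and let torchvision re-download CIFAR-10.
archive = './data/cifar-10-python.tar.gz'
if os.path.exists(archive):
    os.remove(archive)

datasets.CIFAR10('./data', train=True, download=True)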

Does the program need special configuration before it runs?

Hi, thanks for providing the source code. I have trained the optimal network architecture reported in this paper several times (using the command 'sh exps/autogan_cifar10_a.sh'), but surprisingly the FID value is always above 13. So I wonder whether the program needs special configuration before it runs (I am just using the default configuration now).

Trouble Getting Pre-Trained Models to Work

Hi, thanks for supplying this repo.

I've been trying to use your .py files and saved weights to load your pre-trained model, but there are a few problems.

If I understand correctly, I am supposed to download/use these weights with 'exps/autogan_cifar10_a.sh', and use the associated .py file.

I used the following command:

python test.py \
--dataset cifar10 \
--img_size 32 \
--bottom_width 4 \
--gen_model autogan_cifar10_a \
--latent_dim 128 \
--gf_dim 256 \
--g_spectral_norm False \
--load_path /path/to/*.pth \
--exp_name test_autogan_cifar10_a

...and put it manually into the autogan_cifar10_a.py file. But I get the following error when I try to load the weights into the class.

---> 52 G.load_state_dict(torch.load('weights/autogan_cifar10_a.pth', map_location='cpu'))
     53 cnn.load_state_dict(torch.load('weights/cifar-10_cnn Resnet.pth', map_location='cpu'))
     54

/anaconda3/envs/exp1/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    775 if len(error_msgs) > 0:
    776     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 777         self.__class__.__name__, "\n\t".join(error_msgs)))
    778 return _IncompatibleKeys(missing_keys, unexpected_keys)
    779

RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "cell1.n1.weight", "cell1.n1.bias", "cell1.n1.running_mean", "cell1.n1.running_var", "cell1.n2.weight", "cell1.n2.bias", "cell1.n2.running_mean", "cell1.n2.running_var".

Thank you in advance for any help. Perhaps you could provide a working example where you load the pre-trained model and generate some samples in a Jupyter notebook? Happy new year :-)

how to generate fid_stats_cifar10_train.npz

Hello everyone,
I'm training a new model on other datasets. The other datasets are ready and train normally, but the Inception score and FID score are not correct, because I am using the downloaded fid_stats_cifar10_train.npz.
My question: what is stored in fid_stats_cifar10_train.npz?
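It stores the mean and covariance of Inception features over the reference (real) images, so for a new dataset those statistics have to be computed from your own data. Below is a hedged sketch using torchvision's Inception-v3 as a stand-in; the official numbers in this repo rely on the TensorFlow Inception model from TTUR, so treat this as illustrative rather than a drop-in replacement.

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import inception_v3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = inception_v3(pretrained=True).to(device).eval()
net.fc = torch.nn.Identity()        # expose the 2048-d pooled features

transform = transforms.Compose([
    transforms.Resize(299),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# CIFAR-10 as an example reference set; swap in your own dataset here.
loader = DataLoader(datasets.CIFAR10('./data', train=True, download=True,
                                     transform=transform),
                    batch_size=50, num_workers=2)

feats = []
with torch.no_grad():
    for imgs, _ in loader:
        feats.append(net(imgs.to(device)).cpu().numpy())
feats = np.concatenate(feats, axis=0)

mu, sigma = feats.mean(axis=0), np.cov(feats, rowvar=False)
np.savez('fid_stat/fid_stats_custom.npz', mu=mu, sigma=sigma)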
