
msgan's Introduction

Mode Seeking Generative Adversarial Networks for Diverse Image Synthesis

PyTorch implementation for our MSGAN (Miss-GAN). We propose a simple yet effective mode seeking regularization term that can be applied to arbitrary conditional generative adversarial networks in different tasks to alleviate mode collapse and improve diversity.

Contact: Qi Mao ([email protected]), Hsin-Ying Lee ([email protected]), and Hung-Yu Tseng ([email protected])

Paper

Mode Seeking Generative Adversarial Networks for Diverse Image Synthesis
Qi Mao*, Hsin-Ying Lee*, Hung-Yu Tseng*, Siwei Ma, and Ming-Hsuan Yang
IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019 (* equal contribution)
[arxiv]

Citing MSGAN

If you find MSGAN useful in your research, please consider citing:

@inproceedings{MSGAN,
  author = {Mao, Qi and Lee, Hsin-Ying and Tseng, Hung-Yu and Ma, Siwei and Yang, Ming-Hsuan},
  booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
  title = {Mode Seeking Generative Adversarial Networks for Diverse Image Synthesis},
  year = {2019}
}

Example Results

Usage

Prerequisites

Install

  • Clone this repo:
git clone https://github.com/HelenMao/MSGAN.git

Training Examples

Download the datasets for each task into the datasets folder:

mkdir datasets

Conditioned on Label

cd MSGAN/DCGAN-Mode-Seeking
python train.py --dataroot ./datasets/Cifar10

Conditioned on Image

  • Paired Data: facades and maps
  • Baseline: Pix2Pix

You can download the facades and maps datasets from the BicycleGAN [GitHub project].
We employ the network architecture of BicycleGAN and follow the training process of Pix2Pix.

cd MSGAN/Pix2Pix-Mode-Seeking
python train.py --dataroot ./datasets/facades

  • Unpaired Data: Yosemite (summer <-> winter) and Cat2Dog (cat <-> dog)
  • Baseline: DRIT

You can download the datasets from the DRIT [GitHub project].
Specify --concat 0 for Cat2Dog to handle translation with large shape variations.

cd MSGAN/DRIT-Mode-Seeking
python train.py --dataroot ./datasets/cat2dog

Conditioned on Text

  • Dataset: CUB-200-2011
  • Baseline: StackGAN++

You can download the datasets from the StackGAN++ [GitHub project].

cd MSGAN/StackGAN++-Mode-Seeking
python main.py --cfg cfg/birds_3stages.yml

Pre-trained Models

Download the pre-trained models and save them into

./models/

Evaluation

For Pix2Pix, DRIT, and StackGAN++, please follow the instructions in the corresponding GitHub projects of the baseline frameworks for evaluation details.

Testing Examples

DCGAN-Mode-Seeking

python test.py --dataroot ./datasets/Cifar10 --resume ./models/DCGAN-Mode-Seeking/00199.pth

Pix2Pix-Mode-Seeking

python test.py --dataroot ./datasets/facades --checkpoints_dir ./models/Pix2Pix-Mode-Seeking/facades --epoch 400
python test.py --dataroot ./datasets/maps --checkpoints_dir ./models/Pix2Pix-Mode-Seeking/maps --epoch 400

DRIT-Mode-Seeking

python test.py --dataroot ./datasets/yosemite --resume ./models/DRIT-Mode-Seeking/yosemite/01200.pth --concat 1
python test.py --dataroot ./datasets/cat2dog --resume ./models/DRIT-Mode-Seeking/cat2dog/01999.pth --concat 0

StackGAN++-Mode-Seeking

python main.py --cfg cfg/eval_birds.yml 

Reference

Quantitative Evaluation Metrics

msgan's People

Contributors

helenmao, hsinyinglee, hytseng0509

msgan's Issues

Questions about DRIT

Hello, thanks for sharing this great work! I would like to ask some questions, since I find DRIT somewhat hard to follow.

Num1: What do fake_B_encoded, fake_B_random, and fake_AA_encoded each mean? In other words, what does output_fakeA of the generators represent? (model.py line 153)

Num2: I trained my DRIT model on the cat2dog dataset with the default training options for 1000 epochs, so I think it's time to test. I tested my model with
python test.py --dataroot ./datasets/cat2dog --resume ./results/trial/00979.pth
and got 5 dog images per folder. But the 5 dogs all look the same; the only difference is color. Does that mean mode collapse occurred, or am I testing my model the wrong way? How can I get results like your "Example Results"?

Num3: How can I calculate scores such as FID or LPIPS?

Thanks again!

Where is the appendix of the paper?

I found the appendix in arXiv:1903.05628v6.
I do not understand the following sentence: "For LPIPS, we randomly select 50 pairs of the 50 images of each context in the test set to compute LPIPS and average all the values for this trial."

What does "50 pairs" mean? For one pair, which two pictures are compared?
There are three ways to calculate LPIPS; should I use way 2 or way 3?

way 1:python compute_dists.py -p0 imgs/ex_ref.png -p1 imgs/ex_p0.png --use_gpu
way 2:python compute_dists_dirs.py -d0 imgs/ex_dir0 -d1 imgs/ex_dir1 -o imgs/example_dists.txt --use_gpu
way 3:python compute_dists_pair.py -d imgs/ex_dir_pair -o imgs/example_dists_pair.txt --use_gpu
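One plausible reading of the appendix sentence: for each test input (each "context"), 50 outputs are generated, 50 random pairs of distinct outputs are drawn from them, LPIPS is computed for each pair, and the values are averaged. The sketch below shows only that sampling-and-averaging step under this reading. pairwise_diversity, dataset_diversity, and the dist callable are hypothetical names; in a real evaluation dist would wrap an LPIPS model, not a toy distance.

```python
import random
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def pairwise_diversity(outputs: Sequence[T],
                       dist: Callable[[T, T], float],
                       num_pairs: int = 50,
                       seed: int = 0) -> float:
    """Average dist() over num_pairs random pairs of distinct outputs.

    `outputs` are the samples generated for ONE input (one context);
    each pair consists of two different samples (distinct indices).
    """
    if len(outputs) < 2:
        raise ValueError("need at least two outputs to form a pair")
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_pairs):
        i, j = rng.sample(range(len(outputs)), 2)  # two distinct indices
        total += dist(outputs[i], outputs[j])
    return total / num_pairs


def dataset_diversity(per_context_outputs, dist, num_pairs=50, seed=0):
    """Mean of the per-context averages over the whole test set."""
    scores = [pairwise_diversity(outs, dist, num_pairs, seed + k)
              for k, outs in enumerate(per_context_outputs)]
    return sum(scores) / len(scores)
```

Under this reading, identical outputs for the same input (mode collapse) give an average close to 0, which is why the score is used as a diversity measure.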

Were some of the pre-trained embedding layers not uploaded?

Could the authors please upload the three pre-trained embedding layers referenced in datasets.py under the StackGAN++-Mode-Seeking directory?

    def load_embedding(self, data_dir, embedding_type):
        if embedding_type == 'cnn-rnn':
            embedding_filename = '/char-CNN-RNN-embeddings.pickle'
        elif embedding_type == 'cnn-gru':
            embedding_filename = '/char-CNN-GRU-embeddings.pickle'
        elif embedding_type == 'skip-thought':
            embedding_filename = '/skip-thought-embeddings.pickle'

        with open(data_dir + embedding_filename, 'rb') as f:
            embeddings = pickle.load(f, encoding="bytes")
            embeddings = np.array(embeddings)
            # embedding_shape = [embeddings.shape[-1]]
            print('embeddings: ', embeddings.shape)
        return embeddings

Some trouble in DCGAN

Hi, thanks for such excellent work. When I ran the DCGAN code you provided, the following error occurred:
usage: train.py [-h] --dataroot DATAROOT [--img_size IMG_SIZE] [--nz NZ] [--class_num CLASS_NUM] [--phase PHASE] [--batch_size BATCH_SIZE] [--nThreads NTHREADS] [--name NAME] [--display_dir DISPLAY_DIR] [--result_dir RESULT_DIR] [--display_freq DISPLAY_FREQ] [--img_save_freq IMG_SAVE_FREQ] [--model_save_freq MODEL_SAVE_FREQ] [--no_display_img] [--n_ep N_EP] [--resume RESUME] [--gpu GPU]
train.py: error: the following arguments are required: --dataroot
Then, when I changed required = True to required = False in BaseOptions in options.py, that error disappeared, but it was followed by this error:
Traceback (most recent call last):
  File "D:/Python/MSGAN-master/DCGAN-Mode-Seeking/train.py", line 68, in <module>
    main()
  File "D:/Python/MSGAN-master/DCGAN-Mode-Seeking/train.py", line 16, in main
    os.makedirs(opts.dataroot, exist_ok=True)
  File "C:\ProgramData\Anaconda3\lib\os.py", line 205, in makedirs
    head, tail = path.split(name)
  File "C:\ProgramData\Anaconda3\lib\ntpath.py", line 204, in split
    p = os.fspath(p)
TypeError: expected str, bytes or os.PathLike object, not NoneType
I have tried many ways but could not fix the error. What do you suggest?
Thank you!
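The first error only means that --dataroot was not passed; options.py does not need to be edited. The second error follows from that edit: with required=False the value defaults to None, and os.makedirs(None) raises exactly the TypeError shown. A minimal sketch reproducing both behaviours with argparse; build_parser is a hypothetical stand-in for BaseOptions that defines only --dataroot.

```python
import argparse
import os
import tempfile


def build_parser():
    # Hypothetical, minimal stand-in for BaseOptions in options.py: only the
    # --dataroot flag is reproduced here; the real parser defines many more.
    parser = argparse.ArgumentParser(prog="train.py")
    parser.add_argument("--dataroot", type=str, required=True,
                        help="path to the dataset folder")
    return parser


def main(argv=None):
    opts = build_parser().parse_args(argv)
    # With required=True, argparse exits (status 2) before this line if the
    # flag is missing. With required=False, opts.dataroot would be None here
    # and os.makedirs(None) would raise the TypeError from the traceback.
    os.makedirs(opts.dataroot, exist_ok=True)
    return opts.dataroot


if __name__ == "__main__":
    # Equivalent to: python train.py --dataroot <some folder>
    print(main(["--dataroot", os.path.join(tempfile.mkdtemp(), "Cifar10")]))
```

In practice, pass the flag on the command line (for example python train.py --dataroot ./datasets/Cifar10) or add it to the script parameters of the IDE's run configuration.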

Numpy as training input

I am new to Python and deep learning. How can we directly use NumPy arrays as training input? I assume that is possible, right?
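Yes: in PyTorch, in-memory NumPy arrays can be wrapped with torch.from_numpy and TensorDataset and then fed through a DataLoader. A minimal sketch, assuming the images are uint8 arrays of shape (N, H, W, C) with integer class labels; make_loader is a hypothetical helper, and the repository's own data pipeline may load and normalize images differently.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(images: np.ndarray, labels: np.ndarray,
                batch_size: int = 64, shuffle: bool = True) -> DataLoader:
    """Wrap in-memory NumPy arrays in a PyTorch DataLoader.

    Assumes `images` is uint8 with shape (N, H, W, C) and values in [0, 255],
    and `labels` holds integer class ids of shape (N,).
    """
    # HWC uint8 -> CHW float in [-1, 1], the range a tanh-output generator
    # is typically trained against (an assumption about the model).
    x = torch.from_numpy(images).permute(0, 3, 1, 2).float() / 127.5 - 1.0
    y = torch.from_numpy(labels).long()
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=shuffle)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    imgs = rng.integers(0, 256, size=(10, 32, 32, 3), dtype=np.uint8)
    labels = rng.integers(0, 10, size=10)
    for xb, yb in make_loader(imgs, labels, batch_size=4):
        print(xb.shape, yb.shape)
```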

NDB and JSD

Hi, thank you for your excellent work and for providing links to the evaluation code. However, I don't know how to use NDB on other datasets, and I could not find JSD. Could you provide the NDB and JSD evaluation code for Facades and Maps?
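For reference, NDB (Number of statistically Different Bins) clusters real samples into K bins with k-means, assigns generated samples to the nearest centroid, and counts bins whose proportions differ significantly under a two-proportion z-test; JSD is the Jensen-Shannon divergence between the two bin histograms. The NumPy sketch below follows that recipe on flattened (N, D) arrays. It is not the gans-n-gmms ndb.py code nor the authors' evaluation script: the function names, the plain Lloyd k-means, the 1.96 threshold (two-sided, 0.05), and the natural-log JSD are assumptions.

```python
import numpy as np


def assign(x, centers):
    """Index of the nearest center (squared L2) for every row of x."""
    d2 = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1)


def kmeans(x, k, iters=20, seed=0):
    """Plain Lloyd's k-means; returns (k, D) centroids."""
    rng = np.random.default_rng(seed)
    centers = x[rng.choice(len(x), size=k, replace=False)].astype(np.float64)
    for _ in range(iters):
        labels = assign(x, centers)
        for c in range(k):
            members = x[labels == c]
            if len(members):  # keep the old center if a cluster empties
                centers[c] = members.mean(axis=0)
    return centers


def ndb_jsd(real, fake, k=10, z_crit=1.96, seed=0):
    """Return (NDB, JSD) between two sample sets of shape (N, D)."""
    centers = kmeans(real, k, seed=seed)
    p = np.bincount(assign(real, centers), minlength=k) / len(real)
    q = np.bincount(assign(fake, centers), minlength=k) / len(fake)
    # Two-proportion z-test per bin, using the pooled proportion.
    n1, n2 = len(real), len(fake)
    pooled = (p * n1 + q * n2) / (n1 + n2)
    se = np.sqrt(pooled * (1 - pooled) * (1 / n1 + 1 / n2))
    z = np.divide(p - q, se, out=np.zeros_like(p), where=se > 0)
    ndb = int((np.abs(z) > z_crit).sum())
    # Jensen-Shannon divergence between the bin histograms (natural log).
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

    jsd = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    return ndb, jsd
```

For RGB images, one option is to flatten each image to a (C*H*W,) vector before calling ndb_jsd; whether the paper did this or converted to greyscale is not stated on this page.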

Evaluation codes

Hi, thanks for such excellent work. I am wondering whether you could make the evaluation code you used publicly available.

Mode Seeking Loss does not decrease

Hi,
I used your loss in my own code on my own dataset with a pix2pix model. When using only the GAN loss and L1 loss, the model trains well. However, when I add the MS loss to the generator, the MS loss does not decrease and the whole training crashes. I tried different values of the MS loss weight lambda, but all of them failed. Have you encountered such issues during training? What could cause this problem? Thank you.
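For reference, a minimal PyTorch sketch of the mode seeking term in the form used by the DCGAN code quoted in another issue on this page: lz is the mean absolute difference between two outputs for the same condition divided by the mean absolute difference of their latent codes, and the loss is 1 / (lz + eps). The function name and the default eps are assumptions, not values taken from the repository.

```python
import torch


def mode_seeking_loss(fake1, fake2, z1, z2, eps=1e-5):
    """Reciprocal mode seeking term, as in the quoted DCGAN code.

    fake1/fake2: generator outputs for the SAME condition with latent codes
    z1/z2. Minimizing the returned value increases lz, i.e. pushes the two
    outputs apart relative to the distance between their latent codes.
    eps=1e-5 is an assumed default.
    """
    lz = torch.mean(torch.abs(fake2 - fake1)) / torch.mean(torch.abs(z2 - z1))
    return 1 / (lz + eps)
```

In a pix2pix-style setup the term would be added to the generator objective with a weight, for example loss_G = loss_gan + lambda_l1 * loss_l1 + lambda_ms * mode_seeking_loss(...), where the weights are your own choice. One property worth checking when training diverges: if the two outputs are nearly identical, lz is close to 0 and the loss approaches 1 / eps, so both the loss value and its gradient can become very large.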

Replicating Pix2Pix experiment on maps dataset

Hi @HelenMao,

I tried to replicate the experiment on the maps dataset, but the FID is much higher (200-300) than what is reported in the paper. I initially ran the default hyperparameter settings in the code. However, I found that the pre-trained model was trained for 400 epochs, while the default setting is 200 epochs. I still get an FID of around 200 after running 400 epochs. Could you please let me know whether there are specific hyperparameters I need to set to reproduce the result? Thank you so much and happy holidays!

Gradient-based mode seeking loss

Hi, thanks for sharing the code of the paper.
Have you tried a gradient-based mode seeking loss, as below?

[equation image not preserved]

In the conditional image generation task on CIFAR-10, I got almost the same FID with the gradient-based mode seeking loss as with the original distance-based mode seeking loss.

How many images were used to compute FID?

Hi,
I couldn't find in your paper how many images you used to compute FID.
Since FID is sensitive to the number of samples, how many samples did you use? For example, for the maps dataset, did you use every image in the validation set? And how many generated samples per input map image?

Thanks!
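FID fits a Gaussian (mean and covariance) to Inception features of the real and the generated images and reports the Fréchet distance between the two Gaussians. Because the covariance is estimated from samples, the score depends on how many images are used, which is why the sample count matters for comparisons. The NumPy sketch below shows only the distance itself; fid_from_features is a hypothetical helper, and extracting the Inception features is assumed to happen elsewhere. It is not the evaluation script the authors used.

```python
import numpy as np


def _sqrt_psd(m):
    """Symmetric square root of a positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(m)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})."""
    diff = mu1 - mu2
    s1_half = _sqrt_psd(sigma1)
    # Tr((S1 S2)^{1/2}) equals Tr((S1^{1/2} S2 S1^{1/2})^{1/2}); the latter
    # matrix is symmetric PSD, so eigh can replace a general sqrtm.
    covmean_trace = np.trace(_sqrt_psd(s1_half @ sigma2 @ s1_half))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * covmean_trace)


def fid_from_features(feats1, feats2):
    """FID between two (N, D) feature arrays (e.g. Inception pool features)."""
    mu1, mu2 = feats1.mean(axis=0), feats2.mean(axis=0)
    s1 = np.cov(feats1, rowvar=False)
    s2 = np.cov(feats2, rowvar=False)
    return frechet_distance(mu1, s1, mu2, s2)
```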

LPIPS

Hi @HelenMao,

For LPIPS, I am assuming you use AlexNet and version 0.0. For NDB, could you please share the hyperparameters (other than K) passed to ImageBatchProvider, NDB, and the evaluate function? Thanks!

Reciprocal of mode seeking loss

You seem to use the reciprocal of your mode seeking term at line 131 instead of minimizing its negative in order to maximize it. Nowhere in your paper do you state this. Is there a particular reason? Optimizing the reciprocal of a function is not the same as optimizing its negative.
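A short derivation of how the two choices relate, using the ratio r that the DCGAN code computes as lz (this is our own reasoning, not the authors' stated motivation):

```latex
\begin{aligned}
r &= \frac{\operatorname{mean}\lvert G(c, z_2) - G(c, z_1)\rvert}{\operatorname{mean}\lvert z_2 - z_1\rvert} \;\ge 0,
\qquad
\mathcal{L}_{\mathrm{rec}} = \frac{1}{r + \epsilon},
\qquad
\mathcal{L}_{\mathrm{neg}} = -r, \\
\frac{\partial \mathcal{L}_{\mathrm{rec}}}{\partial \theta}
  &= -\frac{1}{(r + \epsilon)^2}\,\frac{\partial r}{\partial \theta},
\qquad
\frac{\partial \mathcal{L}_{\mathrm{neg}}}{\partial \theta} = -\frac{\partial r}{\partial \theta}.
\end{aligned}
```

Both losses are strictly decreasing in r, so lowering either one raises r, and the gradients point in the same direction with respect to the generator parameters θ. They differ in scale: the reciprocal multiplies the push by 1/(r + ε)², which is large when outputs are nearly collapsed (r small) and fades as r grows, and it is bounded below by 0, whereas -r is unbounded below.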

NDB & JSD Reproduction Problems

Hello, I'm trying to replicate the NDB & JSD results on CDCGAN reported in the MSGAN paper but have been unable to do so (my JSD results are an order of magnitude larger). May I know how you performed the evaluation? I checked the code in gans-n-gmms/utils/ndb.py but noticed that their metric is for greyscale data, whereas the datasets in the MSGAN paper (e.g. CIFAR10, CUB200) are RGB.

  1. Did you flatten the images (B, CHW) to calculate NDB/JSD, or did you first convert the RGB results & real data into greyscale (so (B, H*W)) and then calculate NDB/JSD?

  2. Also, to confirm, were the NDB, JSD, FID scores calculated from comparing the generated results with the test set (or train set) of the benchmark datasets?

  3. When calculating the class-wise NDB/JSD, how many generated samples and real samples do you use respectively for comparison?

  4. When calculating the FID & LPIPS of the entire dataset, how many generated samples and real samples do you use respectively for comparison?

Thanks a lot!

Learning rate decay for cat2dog dataset

Hi @HelenMao,

python test.py --dataroot ./datasets/cat2dog --resume ./models/DRIT-Mode-Seeking/cat2dog/01999.pth --concat 0

Unlike summer2winter, I see that the model was run for 2k epochs on the cat2dog dataset. I am wondering whether the number of epochs before the learning rate decays is 600 or 1,000. I could not find the exact hyperparameter settings.
In the code, I saw in the comment that n_ep = 400 * d_iter. Does that also mean d_iter = 2000/400 = 5 in this case?
Thank you for your help!

A question about DCGAN

Thank you very much for your excellent work!
Hello, I would like to ask: when testing, what criteria do you use to choose which saved checkpoint to load? What indicators tell you that the generator is good enough to be used for testing and generating images? During training, only the number of iterations and the learning rate are printed.
Another question: can the program you provide run on multiple GPUs?
Thank you very much for your help!

Computation of the mode seeking loss differs from the paper!

lz = torch.mean(torch.abs(self.fake_image2 - self.fake_image1)) / torch.mean(
    torch.abs(self.z_random2 - self.z_random))
self.loss_lz = 1 / (lz + eps)

Hi, thank you for sharing the code of the paper. I have a question about the mode seeking loss lz in your code.

In your paper, the formulation is:

[equation image not preserved]
but in your DCGAN code, lz is the reciprocal of the paper's formulation. Why?

Do you use L1 loss in the Pix2Pix model?

In your paper, there are two kinds of losses: the original GAN loss and your proposed mode seeking loss.

I want to ask whether you use L1 loss when training the Pix2Pix model.

It seems that the L1 loss would fight against your mode seeking loss, since the L1 loss tries to make every output the same, while the mode seeking loss aims to push outputs apart according to the noise distance.

Your work is really nice.

Help needed for reproducing FID

Hi, thanks for making your work public.

I would like to build my idea upon your nice work, but running your DCGAN train/test code as-is did not give me similar FID scores, although the images (gen_00199.jpg) look fine.
[image: gen_00199.jpg, not preserved]

airplane 67.26209697186295
automobile 56.615038948216466
bird 67.71981481266164
cat 64.21582516792154
deer 50.58379969002419
dog 74.1233251819657
frog 55.22868149492308
horse 60.88559157869355
ship 52.45243167542941
truck 52.44226222359282
fid: 60.15288677452914 +_ 7.530598089037249

Testing with your pre-trained weights gives better scores, but still not at the reported level:
airplane 55.98297725079635
automobile 48.51736857887687
bird 62.36735954146019
cat 54.30139778488274
deer 50.4328713386401
dog 68.39370845428152
frog 47.34936598559659
horse 54.073287595819295
ship 47.548950778636936
truck 48.33605156842515
fid: 53.73033388774157 +_ 6.6423165837502305
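The fid summary lines appear to be the mean and the population standard deviation (ddof = 0) of the ten per-class scores; recomputing them from the listed values gives the same numbers. A small stdlib-only check; the dict names are hypothetical, and the values are copied from the two runs above.

```python
from statistics import mean, pstdev

# Per-class FID values copied from the user's own training run above.
trained_fid = {
    "airplane": 67.26209697186295, "automobile": 56.615038948216466,
    "bird": 67.71981481266164, "cat": 64.21582516792154,
    "deer": 50.58379969002419, "dog": 74.1233251819657,
    "frog": 55.22868149492308, "horse": 60.88559157869355,
    "ship": 52.45243167542941, "truck": 52.44226222359282,
}

# Per-class FID values copied from the run with the pre-trained weights.
pretrained_fid = {
    "airplane": 55.98297725079635, "automobile": 48.51736857887687,
    "bird": 62.36735954146019, "cat": 54.30139778488274,
    "deer": 50.4328713386401, "dog": 68.39370845428152,
    "frog": 47.34936598559659, "horse": 54.073287595819295,
    "ship": 47.548950778636936, "truck": 48.33605156842515,
}


def summarize(scores):
    """Return (mean, population std) of a dict of per-class scores."""
    values = list(scores.values())
    return mean(values), pstdev(values)


for name, scores in (("trained", trained_fid), ("pretrained", pretrained_fid)):
    m, s = summarize(scores)
    print(f"{name}: fid {m:.4f} +/- {s:.4f}")
```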

Can you spot anything wrong in my procedure below?

  • Train
    python train.py --dataroot ~/data/cifar10
  • Test
    python test.py --dataroot ~/data/cifar10 --resume results/trial/00199.pth --num 5000 --name trial/images_test
  • Prepare training images
    python cifar10_to_imagefolder.py --dir_dataset ~/data/cifar10/ --dir_dest ~/data/cifar10/images/train
  • Prepare per-class stats of training images
    python precalc_stats_perclass.py --root_dataset ~/data/cifar10/images/train/ --dir_dest ~/data/cifar10/fid_stats
  • Compute per-class FID
    python fid_perclass.py --dir_stats ~/data/cifar10/fid_stats --root_samples ~/MSGAN/DCGAN-Mode-Seeking/results/trial/images_test/

Am I missing something?

For your convenience in tracking this down,
my per-class evaluation code is here,
and the code for preparing training images is here (cifar10_to_imagefolder.py above).
