
noahvl / explaining-in-style-reproducibility-study


Re-implementation of the StylEx paper, training a GAN to explain a classifier in StyleSpace, paper by Lang et al. (2021).

License: Other

Python 2.27% Jupyter Notebook 97.73%
cnn computer-vision convolutional-neural-network deep-learning gan generative-adversarial-network google pytorch stylex


[Re] Explaining in Style: Training a GAN to explain a classifier in StyleSpace

This repository is a re-implementation of Explaining in Style: Training a GAN to explain a classifier in StyleSpace by Lang et al. (2021).



Paper

[Re] Explaining in Style: Training a GAN to explain a classifier in StyleSpace
Noah Van der Vleuten, Tadija Radusinović, Rick Akkerman, Meilina Reksoprodjo

Paper: https://rescience.github.io/bibliography/Vleuten_2022.html

Presented at NeurIPS 2022: https://nips.cc/virtual/2022/poster/56097
Contains a poster and a presentation/slides about our reproducibility efforts.

If you use this for research, please cite our paper:

@article{Vleuten:2022,
  author = {van der Vleuten, Noah and Radusinović, Tadija and Akkerman, Rick and Reksoprodjo, Meilina},
  title = {{[Re] Explaining in Style: Training a GAN to explain a classifier in StyleSpace}},
  journal = {ReScience C},
  year = {2022},
  month = may,
  volume = {8},
  number = {2},
  pages = {{#42}},
  doi = {10.5281/zenodo.6574709},
  url = {https://zenodo.org/record/6574709/files/article.pdf},
  code_url = {https://github.com/NoahVl/Explaining-In-Style-Reproducibility-Study},
  code_doi = {10.5281/zenodo.6512392},
  code_swh = {swh:1:dir:04e11a55f476b115b40fd6af9d06ed70eb248535},
  data_url = {},
  data_doi = {},
  review_url = {https://openreview.net/forum?id=SYUxyazQh0Y},
  type = {Replication},
  language = {Python},
  domain = {ML Reproducibility Challenge 2021},
  keywords = {rescience c, machine learning, deep learning, python, pytorch, explainable ai, xai, gan, stylegan2, stylex}
}

Requirements

Running this notebook requires a CUDA-enabled graphics card. Installing the environment requires Conda.

Instructions

  1. Create the conda environment from the .yml file.
  2. Activate the environment.
  3. Open jupyter notebook.
  4. Open the stylex/all_results_notebook.ipynb notebook.
  5. Download the model files as described in the notebook.
  6. Select model_to_choose to pick the dataset/model on which to show results. Default is 'plant'.

Verifying results

The all_results_notebook.ipynb works with pre-calculated latent vectors to generate results and run the experiments. If you want to generate the latent embeddings yourself, make use of the run_attfind_combined.ipynb notebook (similarly, select the appropriate model_to_choose). Note that you will have to download the datasets if you want to run AttFind (you can make use of the notebooks in the data folder).

Warning: The AttFind procedure is quite slow and may take over an hour depending on your hardware.
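AttFind looks for the StyleSpace coordinates whose change most affects the classifier's output. The following is a heavily simplified, framework-free sketch of a greedy selection in that spirit, loosely following the idea in the paper rather than the notebooks' implementation. It assumes the per-image effects have already been computed: the hypothetical input `effects[i][j]` is the change in target-class probability when style coordinate j of image i is perturbed, and `flip_threshold` is an illustrative parameter.

```python
from typing import List, Sequence, Set, Tuple


def greedy_top_coordinates(
    effects: Sequence[Sequence[float]],
    k: int,
    flip_threshold: float = 0.5,
) -> List[Tuple[int, float]]:
    """Greedily pick up to k style coordinates with the largest mean effect.

    effects[i][j]: change in target-class probability for image i when style
    coordinate j is perturbed (hypothetical, precomputed input).
    After a coordinate is chosen, images where its effect exceeds
    flip_threshold count as "explained" and are dropped from later rounds.
    """
    remaining: Set[int] = set(range(len(effects)))
    picked: Set[int] = set()
    chosen: List[Tuple[int, float]] = []
    num_coords = len(effects[0]) if effects else 0
    for _ in range(k):
        if not remaining:
            break
        best_j, best_mean = -1, float("-inf")
        for j in range(num_coords):
            if j in picked:
                continue
            mean_effect = sum(effects[i][j] for i in remaining) / len(remaining)
            if mean_effect > best_mean:
                best_j, best_mean = j, mean_effect
        if best_j < 0:
            break
        picked.add(best_j)
        chosen.append((best_j, best_mean))
        # Keep only the images this coordinate did not already "flip".
        remaining = {i for i in remaining if effects[i][best_j] <= flip_threshold}
    return chosen
```

Dropping already-explained images is what lets later rounds surface coordinates that matter for a different subset of images, instead of re-selecting near-duplicates of the first one.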

How to train the models?

The StylEx framework consists of two parts, the "pretrained" classifier and the Encoder+GAN.

If you want to train a StylEx model on a new dataset, we suggest you first train a new classifier and then provide it to cli.py, which trains the StylEx model on this dataset with the new classifier in evaluation mode. If you use a ResNet/MobileNet model, you should only have to change the classifier_name parameter in cli.py, or pass it as --classifier_name <mobilenet/resnet> when you call cli.py.

If you want to use a new classifier architecture you should add support for this in one of the stylex_train.py files.
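The traceback in the issues below shows that cli.py exposes train_from_folder through python-fire, so each keyword argument doubles as a command-line flag. Here is a minimal sketch of how a classifier_name switch could dispatch to an architecture. The registry, the builder functions, `num_classes` and the returned dicts are hypothetical placeholders, not the repository's code; a real builder would construct and load a torch classifier and put it in eval mode.

```python
from typing import Callable, Dict

# Hypothetical registry: classifier_name -> builder. Real builders would load a
# pretrained torch classifier and call .eval() on it.
CLASSIFIER_BUILDERS: Dict[str, Callable[[int], dict]] = {}


def register_classifier(name: str):
    """Decorator that adds a builder to the registry under `name`."""
    def decorator(builder: Callable[[int], dict]) -> Callable[[int], dict]:
        CLASSIFIER_BUILDERS[name] = builder
        return builder
    return decorator


@register_classifier("resnet")
def build_resnet(num_classes: int) -> dict:
    return {"arch": "resnet18", "num_classes": num_classes}  # placeholder


@register_classifier("mobilenet")
def build_mobilenet(num_classes: int) -> dict:
    return {"arch": "mobilenet_v2", "num_classes": num_classes}  # placeholder


def train_from_folder(classifier_name: str = "resnet", num_classes: int = 2) -> dict:
    """With fire, `--classifier_name mobilenet` sets this argument."""
    try:
        builder = CLASSIFIER_BUILDERS[classifier_name]
    except KeyError:
        raise ValueError(
            f"Unknown classifier_name {classifier_name!r}; "
            f"choose one of {sorted(CLASSIFIER_BUILDERS)}"
        ) from None
    classifier = builder(num_classes)
    # ... StylEx training with the frozen classifier would start here ...
    return classifier


if __name__ == "__main__":
    import fire  # python-fire, as used by cli.py

    fire.Fire(train_from_folder)
```

With this layout, supporting a new architecture means registering one more builder instead of editing the training loop.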

Training one of the supported classifiers

Natively we support the MobileNet V2 and ResNet architectures. Of the two, ResNet gave much better results than MobileNet on small images upscaled to 224px. The MobileNet classifier training code is included, but to reiterate, we advise training a ResNet classifier when using small images. We also found that unfreezing the layers iteratively by editing a Python file is inconvenient.

Therefore we have also included the notebook we used to train the ResNet-18 CelebA gender classifier. Following the original paper, this classifier was then explained by the StyleGAN model trained on the FFHQ dataset. The notebook can also be used to train a MobileNet classifier.
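Instead of editing a file between stages, a gradual-unfreezing schedule can be derived from the epoch number. This is a minimal sketch, assuming the classifier's layers are split into ordered groups (the names used in the example match torchvision's ResNet-18 submodules: layer1 to layer4 and fc); `trainable_groups` and `epochs_per_stage` are illustrative, not part of the repository.

```python
from typing import List, Sequence


def trainable_groups(groups: Sequence[str], epoch: int, epochs_per_stage: int = 1) -> List[str]:
    """Return the layer groups that should be trainable at `epoch`.

    Groups are ordered from input to output. Training starts with only the
    last group (the classification head) unfrozen; every `epochs_per_stage`
    epochs one more group is unfrozen, moving backwards through the network.

    With torch, one would then apply it roughly like this:
        active = trainable_groups(groups, epoch)
        for name, param in model.named_parameters():
            param.requires_grad = any(name.startswith(g) for g in active)
    """
    if epochs_per_stage < 1:
        raise ValueError("epochs_per_stage must be >= 1")
    num_unfrozen = min(len(groups), 1 + epoch // epochs_per_stage)
    return list(groups[len(groups) - num_unfrozen:])
```

Because the schedule is a pure function of the epoch, it can live in the training loop or notebook and be rerun without manual edits.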

User study

The files of the user study, which has been discussed in the paper, have been included in this repository in the /all_user_studies folder.

Limitations

  1. We likely do not support multi-GPU training. This was present in the original lucidrains repository, however we stripped out some parts for ease of programming. We would highly appreciate someone with a multi-GPU setup adding the functionality back in!

For more information, please see both the open and closed issues on the GitHub issues page.

License

MIT

Acknowledgements

Our repository is based on the PyTorch StyleGAN2 training code from the amazing stylegan2-pytorch repository by GitHub user lucidrains. We added the StylEx training code on top of their training code.

The authors' original TensorFlow notebook, including their AttFind algorithm, has been translated to PyTorch. We also used it to run their pretrained age StylEx model to extract experimental results. Both notebooks have been included.

Contributors

meilinar, noahvl, rtadijar, tmabraham, yourmediumprogrammer


explaining-in-style-reproducibility-study's Issues

Trouble running with `multi_gpus=True`

Hello,

I'm fairly new to pytorch so am not sure if this is a bug in the code or something that I'm doing incorrectly.

I've got 8 GPUs available (world_size=8) and in order to make use of them I passed multi_gpus=True to cli.py. When I do that I get an error (shown below). I expect that the code should be working with multiple gpus since I've seen others commenting about training with multiple GPUs on other issues.

[UPDATE]: Some other info that might be useful: a) I'm training using stylex_train_new because my images are 256x256 and training would not converge using stylex_train; b) GPUs are A100s

Traceback (most recent call last):
  File "cli.py", line 263, in <module>
    main()
  File "cli.py", line 259, in main
    fire.Fire(train_from_folder)
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "cli.py", line 252, in train_from_folder
    mp.spawn(run_training,
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 
-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/cnvrg/stylex/cli.py", line 72, in run_training
    retry_call(model.train, tries=3, exceptions=NanException)
  File "/opt/conda/lib/python3.8/site-packages/retry/api.py", line 101, in retry_call
    return __retry_internal(partial(f, *args, **kwargs), exceptions, tries, delay, max_delay, backoff, jitter, logger)
  File "/opt/conda/lib/python3.8/site-packages/retry/api.py", line 33, in __retry_internal
    return f()
  File "/cnvrg/stylex/stylex_train_new.py", line 1478, in train
    rec_loss = self.rec_scaling * reconstruction_loss(image_batch, generated_images,
  File "/cnvrg/stylex/stylex_train_new.py", line 415, in reconstruction_loss
    loss = 0.1 * lpips_loss(encoder_batch_norm, generated_images_norm).mean() +\
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/lpips/lpips.py", line 118, in forward
    in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/lpips/lpips.py", line 154, in forward
    return (inp - self.shift) / self.scale
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
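The final RuntimeError says the LPIPS network's scaling-layer buffers (`shift`, `scale`) live on cuda:0 while process 1's images are on cuda:1. This typically happens when a model created once (for example as a module-level global) is moved to the default device instead of each spawned process's own GPU. Below is a minimal sketch of a per-device cache, assuming each process knows its rank. `PerDeviceModelCache` and `device_for_rank` are hypothetical helpers, and `factory` stands in for something like `lambda: lpips.LPIPS(net='alex')`; the model only needs a `.to(device)` method.

```python
from typing import Any, Callable, Dict


class PerDeviceModelCache:
    """Create a model at most once per device and keep it on that device."""

    def __init__(self, factory: Callable[[], Any]):
        self._factory = factory
        self._models: Dict[str, Any] = {}

    def get(self, device: str) -> Any:
        if device not in self._models:
            # For torch modules, .to() moves parameters *and* registered buffers.
            self._models[device] = self._factory().to(device)
        return self._models[device]


def device_for_rank(rank: int) -> str:
    """With mp.spawn, process `rank` should use its own GPU."""
    return f"cuda:{rank}"


# Inside the spawned training function one would then do, roughly:
#   torch.cuda.set_device(rank)
#   lpips_model = cache.get(device_for_rank(rank))
# and make sure both image batches are on that same device.
```

The key point is that the perceptual-loss model and both image batches must all use the device of the process that calls it.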

Optimization procedure ambiguities

Notes on factors of variation during training

Quoted passages are the original authors' words.

Optimization procedure

  • Alternating training - Check influence on training time, image sharpness and stability

    We noticed this creates sharper looking images than just training an encoder-decoder network, however it's possible to use only an encoder, and the method would still work.

    Our results: Alternating seems to help on FFHQ according to Noah (but could also be due to better quality of the encoder or different weightings in losses. We will have to do two runs with just one value changed to make sure this actually helps). Inconclusive results for other datasets, most likely due to other optimization problems.

    TODO: Check influence on a simple dataset like MNIST. Perhaps compare on 64x64 on plants/faces once we've got other hyperparameters down

    Note: The functional mapping transforms noise to W, while the encoder transforms an image. Could training a GAN on W vectors coming from different networks, and thus possibly from different distributions, cause trouble?

    Note2: I have a feeling the PL regularization may serve to connect these two distributions, as it keeps track of W statistics and makes sure steps from the mean don't cause a lot of artifacts. Perhaps something to check.

  • Updating encoder-generator autoencoder using reconstruction and KL loss during discriminator training

    Waiting on author's response.

    Our results: Definitely makes MNIST train faster to legible images. Training on plants doesn't seem to suffer (need a good comparison with a baseline).

    Note: Usually, the generator is fixed while the discriminator is updated, so at face value it seems questionable to nudge the generator at this time. However, it may be reasonable, since the rec/KL loss differs from the adversarial loss: it should make the generator better without making it better at fooling the discriminator.

  • Encoder learning rate - Check influence on stability and training time

    Waiting on author's response.

    Background: The functional mapping that takes Z to W in the original StyleGAN has a learning rate of 0.01 times the base learning rate. Since our encoder also produces W, it might be fruitful to make its learning rate smaller as well.

    Our results: Training on the plant dataset with the default learning rate led to unstable training (large, frequent reconstruction loss spikes), with no long-term decrease in reconstruction loss. Scaling the encoder learning rate by a factor of 0.05 (i.e. making it 20 times smaller) produced more stable training, with the reconstruction loss actually going down. This is with KL+Rec during discriminator training=True.
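For Note2, it may help to pin down what StyleGAN2-style path-length (PL) regularisation actually tracks: a running average of the observed path lengths (the norm of the generator's Jacobian-vector product with respect to W), and a penalty on deviations from that average. The following is a minimal, framework-free sketch; the decay `beta=0.99` and plain-float inputs are assumptions for illustration, and computing the path lengths themselves (which needs autograd) is omitted.

```python
from typing import Optional, Sequence, Tuple


def ema_update(old: Optional[float], new: float, beta: float = 0.99) -> float:
    """Exponential moving average; the first observation initialises it."""
    if old is None:
        return new
    return beta * old + (1.0 - beta) * new


def path_length_penalty(
    pl_lengths: Sequence[float], pl_mean: Optional[float], beta: float = 0.99
) -> Tuple[float, float]:
    """Return (penalty, updated_pl_mean) for one regularisation step.

    pl_lengths: per-sample path lengths for the current batch.
    The penalty is the mean squared deviation from the running mean collected
    so far; it is zero on the first call, before any statistic exists.
    """
    batch_mean = sum(pl_lengths) / len(pl_lengths)
    if pl_mean is None:
        penalty = 0.0
    else:
        penalty = sum((length - pl_mean) ** 2 for length in pl_lengths) / len(pl_lengths)
    return penalty, ema_update(pl_mean, batch_mean, beta)
```

Note that the running statistic is a scalar path length, not a distribution over W, so any "connecting" effect between encoder and mapping-network W vectors would be indirect.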

Hyperparams of interest

  • lr - How does the sweet spot change depending on image resolution and dataset? Are there any interactions with the above-mentioned optimization procedure?

    We used a learning rate of 0.002

    Note: Their default learning rate is 10 times bigger than ours. Perhaps we can up the learning rate, assuming the encoder learning rate is lowered accordingly?

  • Loss scaling

    1. Reconstruction loss

      On the image (using LPIPS) and on the W-vector (using L1 loss), both with weight 0.1, plus L1 on images with weight 1.

      Results: Higher reconstruction loss scalings seem to make for faster training in our case. We generally scaled by 5 or 10.
      Note: 5/10 is of the order of the difference in learning rate. Perhaps we can afford a higher learning rate, resulting in exactly their scaling of the losses?

    2. KL loss

      Waiting on author's response.

      Results: No scaling (1) seems to work okay in our practice, with our default learning rate.
      (Noah: I wonder if a higher reconstruction loss would make it take more importance over the KL loss, which could in turn reduce the effects of classifier-specific training, which might affect our attribute find results down the line).

      Best scaling dependent on dataset?

      Waiting on author's response

  • ttur_mult - Multiplier on the discriminator learning rate (two time-scale update rule, as in lucidrains' stylegan2-pytorch), default 1.5
    Note: Might be worth fiddling with. A Google search suggests a value of 1 might work well for FFHQ.

  • and many more!
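The reconstruction loss described under Loss scaling can be written out as a weighted sum. Here is a framework-free sketch, assuming flattened images and W vectors as lists of floats and an injected `lpips_distance` callable in place of the real LPIPS network. `rec_scaling` mirrors the `self.rec_scaling * reconstruction_loss(...)` call visible in the traceback above; the other names and signatures are illustrative.

```python
from typing import Callable, Sequence


def l1(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean absolute error between two equally sized vectors."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def reconstruction_loss(
    image: Sequence[float],
    reconstruction: Sequence[float],
    w: Sequence[float],
    w_reconstructed: Sequence[float],
    lpips_distance: Callable[[Sequence[float], Sequence[float]], float],
    rec_scaling: float = 1.0,
) -> float:
    """rec_scaling * (0.1 * LPIPS(x, x_hat) + 0.1 * L1(w, w_hat) + 1.0 * L1(x, x_hat))."""
    loss = (
        0.1 * lpips_distance(image, reconstruction)  # perceptual term
        + 0.1 * l1(w, w_reconstructed)               # latent (W) term
        + 1.0 * l1(image, reconstruction)            # pixel term
    )
    return rec_scaling * loss
```

Multiplying the whole sum by `rec_scaling` (5 or 10 in the runs above) changes its weight relative to the adversarial and KL terms without changing the balance between the three reconstruction terms.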

Architecture questions

  • Softmaxed logits:

    1. Should we softmax the logits that we append to w? The paper says the unchanged logits are used, but the notebook shows that the softmaxed logits are used. Did the authors find that one of the two versions works better?

    Waiting on author's response.

    2. The notebook also seems to provide the softmaxed logits to the discriminator, which is not mentioned in the paper. Why is this done?

    Waiting on author's response.

    3. When transforming z to w using the mapping network (during alternating training), do the authors include the target class in z, or do they append it later to w as logits? Are these logits randomized, or what approach is used to generate and append them in this noise-sampling case?

    Waiting on author's response.
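The first question compares two ways of conditioning on the classifier output. A framework-free sketch of both variants, assuming a single W vector and raw classifier logits as lists of floats; `condition_w` and its `use_softmax` flag are hypothetical names, not functions from the paper or the notebook.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax (subtracting the max leaves the result unchanged)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def condition_w(w: Sequence[float], logits: Sequence[float], use_softmax: bool = True) -> List[float]:
    """Append the classifier output to w, as probabilities or as raw logits."""
    extra = softmax(logits) if use_softmax else list(logits)
    return list(w) + extra
```

One practical difference: softmaxed values lie in [0, 1] and sum to 1, so their scale does not grow with classifier confidence, while raw logits are unbounded and can dominate the appended part of the style vector.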

Still using centercrop for FFHQ

We have already decided with MNIST we would not use centercrop anymore as it might affect performance by cutting out essential info. I think this can also apply to FFHQ due to the crop possibly cutting out information at the edges such as chin or hair that can help with the gender classification.

Will test this.
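How much a resize-then-centre-crop pipeline throws away is easy to quantify. This is a sketch assuming the common recipe of resizing to a square of side `size` and then cropping a centred `crop`×`crop` window; the sizes in the example (256 and 224) are illustrative, not the repository's settings, and the rounding of the offset may differ slightly from torchvision's CenterCrop for odd size differences.

```python
from typing import Tuple


def center_crop_box(width: int, height: int, crop: int) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) of a centred crop x crop window."""
    if crop > width or crop > height:
        raise ValueError("crop is larger than the image")
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


def fraction_kept(size: int, crop: int) -> float:
    """Share of a square size x size image that survives the centre crop."""
    return (crop * crop) / (size * size)
```

For example, cropping 224 out of 256 removes a 16-pixel border on every side, roughly 23% of the pixels, which is where chin or hair details tend to sit in aligned face images.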

KL divergence NaN issue

Hi,

During GAN training I'm facing an issue. The KL divergence loss between the real image logits and the fake image logits turns out to be NaN after a couple of steps. I'm printing the KL divergence after every step, and I found that in the very first step the KL divergence is huge, and in the following step it becomes negative.

After Step 1:
G: 23949.90 | D: 30.51 | GP: 11.63 | Rec: 4.70 | KL: 2703973204759576721869805868380323840.00

Any idea why this is the case? @NoahVl
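A KL divergence is never negative, so a huge value followed by negative ones usually points at the inputs rather than the model. Two common causes are passing probabilities where log-probabilities are expected (PyTorch's `F.kl_div` expects its first argument in log-space and, by default, its target as probabilities) and taking the log of a probability that has underflowed to 0. Here is a framework-free sketch of a numerically stable KL(real || reconstructed) computed directly from logits via log-softmax; the function names are illustrative.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Log-softmax via the log-sum-exp trick, without taking log of a tiny probability."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def kl_from_logits(p_logits: Sequence[float], q_logits: Sequence[float]) -> float:
    """KL(P || Q) where P = softmax(p_logits) and Q = softmax(q_logits).

    Working in log-space keeps the result finite and non-negative (up to
    rounding) even when one of the distributions is extremely confident.
    """
    log_p = log_softmax(p_logits)
    log_q = log_softmax(q_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
```

Even for opposite, extremely confident predictions such as logits [100, -100] versus [-100, 100], this stays finite (about 200) instead of producing inf or NaN.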

IndexError: arrays used as indices must be of integer (or boolean) type


When I run the original source code, "img_ind" and "img_inx" correspond to each other as pairs. However, when I train the model on my own data and generate the hdf5 file used by this code, the following error is reported:
IndexError: arrays used as indices must be of integer (or boolean) type
I see that you left a comment before this code: "If the next code block gives an "arrays used as indices must be of integer (or boolean) type" you might want to run with more images." I understood this to mean using a larger num_images when generating the hdf5 file, so I set num_images to 100, but the same error still appears. How should I deal with this?

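NumPy raises exactly this IndexError when the index array has a floating-point dtype. The most common way to get one by accident is an empty selection, because `np.array([])` defaults to float64. That would fit the advice to run with more images: if a filtering step selects nothing, indexing with the empty result fails with this message. Below is a sketch of a guard, assuming the indices are collected before indexing; `select_rows` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def select_rows(data: np.ndarray, indices) -> np.ndarray:
    """Index `data` with `indices`, with a clear error for empty or fractional selections."""
    idx = np.asarray(indices)
    if idx.size == 0:
        raise ValueError(
            "No indices selected: the filtering step found no matching images. "
            "Try more images or relax the selection criterion."
        )
    if idx.dtype != np.bool_ and not np.issubdtype(idx.dtype, np.integer):
        # Whole-number floats (e.g. 3.0) are safe to cast; anything else is a bug.
        if not np.all(idx == np.round(idx)):
            raise ValueError("indices must be whole numbers")
        idx = idx.astype(np.intp)
    return data[idx]
```

If the empty-selection error fires even with many images, the criterion that builds the index list (for example a threshold on classifier outputs) is probably never satisfied for the new dataset, which is worth checking before increasing num_images further.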
Classifier acc

Hello, I am trying to train a ResNet classifier using your code, but the accuracy I get is about 0.58. Could you report the accuracy you obtained?
Furthermore, I think there is a bug in the CelebA class of the "classifier_training_celeba" notebook: the image is resized twice, first to img_size and then to 224.

KL Div 0

The KL Divergence loss seems to be 0 for the MNIST one vs all classifier for all timesteps. It used to be a bit higher than 0 when we used 10 classes instead of 2. It might just be caused by the high accuracy of the classification model, however it seems strange that it is 0 for all timesteps.
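A value printed as 0 at every step is not necessarily exactly 0. With a near-perfect one-vs-all classifier, real and reconstructed images both get probabilities very close to one-hot, so the KL between them is tiny, and a log line formatted with two decimals (like the `KL:` output in the NaN issue above) rounds it to 0.00. The following small illustration uses made-up probabilities, not values from training.

```python
import math
from typing import Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(P || Q) for discrete distributions, assuming q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Made-up outputs of a very confident binary classifier.
real = [0.999, 0.001]
reconstructed = [0.998, 0.002]
kl = kl_divergence(real, reconstructed)
print(f"KL: {kl:.2f}")  # prints KL: 0.00
print(f"KL: {kl:.2e}")  # the same small value in scientific notation
```

Logging the value in scientific notation is a cheap way to tell "exactly zero" apart from "very small".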

Reloading Alex LPIPS model every 777 timesteps

The LPIPS models seem to get reloaded every 777 timesteps, as if all the globals are reloaded for some reason. This might be fixed by placing these LPIPS loss models in the Trainer class instead of in the globals. It is also strange that they seem to be loaded 4 times every time; does this have to do with the number of workers?
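One plausible explanation (not verified here) is that module-level code runs again whenever a fresh process imports the module. DataLoader worker processes started with the "spawn" method are one example, which would also fit one load per worker. If the DataLoader iterator is recreated periodically, for instance after each pass over the dataset without `persistent_workers=True`, those reloads would follow a fixed schedule. Here is a sketch of the fix suggested above: the model is held on the trainer and created lazily on first use. `TrainerSketch` is a hypothetical, stripped-down stand-in for the repository's Trainer, and `lpips_factory` stands in for something like `lambda: lpips.LPIPS(net='alex')`.

```python
from typing import Any, Callable, Optional


class TrainerSketch:
    """Holds the perceptual-loss network instead of creating it at import time."""

    def __init__(self, lpips_factory: Callable[[], Any]):
        self._lpips_factory = lpips_factory
        self._lpips: Optional[Any] = None

    @property
    def lpips(self) -> Any:
        # Created on first access only, so merely importing the module
        # (e.g. in a worker process) no longer loads the network.
        if self._lpips is None:
            self._lpips = self._lpips_factory()
        return self._lpips
```

Keeping the model on the trainer also makes it straightforward to move it to the right device, which is what the multi-GPU issue above needs.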

Issue with reconstruction during training?

I was trying to train the StyleGAN, and with sample_from_encoder=True I originally got the following results in the middle of training:

Image: 595-from_encoder

This is at tick 595, with evaluation every 50 steps as in your default settings.

I then set sample_from_encoder=False, and it looks like the model is doing some sort of generation.

This is at tick 229, with evaluation every 50 steps (top half real images, bottom half fake).

So clearly the generation is fine (apart from maybe some mode collapse, but I can probably solve that with some tuning of the StyleGAN2 parameters).

So does that mean there is something wrong with the encoder training? How can I resolve this issue?

stylex_train_new

Hello,
Thank you for this implementation. I have several questions:

  • I am trying to train StylEx on the BDD dataset (driving scenes), and I was wondering what the difference is between stylex_train_new and stylex_train. I saw that in the new code you set a lower learning rate for the encoder, which helped me stabilize the training and avoid large loss values.

  • I have another question regarding the encoder. You seem to have tested encoder architectures other than the discriminator architecture. Did you get better results with other architectures? In my training, the encoder reconstructs almost the same image at each step. At first I thought it was mode collapse, but when I checked images generated without the encoder, they were more diverse.

  • In early training (iteration < 15k) I always get mode collapse. I am using a total batch size of 32 (4 GPUs), gradient accumulation = 4 and img_size = 64. Do you have any advice for preventing mode collapse?
    Generated images at iteration 3250 (not generated using the encoder)

  • Is using a batch size of 8 and gradient accumulation of 4 equivalent to a batch size of 32 and gradient accumulation = 1 in your implementation?
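On the last question: if each micro-batch loss is a mean over its samples and is divided by the number of accumulation steps before backward(), the accumulated gradient equals that of one large batch of the same total size. Below is a framework-free check with a one-parameter least-squares model, assuming that normalisation and equally sized micro-batches; whether this repository divides by the accumulation count is worth confirming in its training loop. Anything that depends on per-batch statistics (batch normalisation, per-batch running averages) can still make the two set-ups behave differently.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[float, float]  # (x, y)


def mse_grad(weight: float, batch: Sequence[Sample]) -> float:
    """d/dw of mean((w * x - y) ** 2) over the batch."""
    return sum(2.0 * (weight * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(weight: float, data: Sequence[Sample], micro_batch: int) -> float:
    """Sum of per-micro-batch gradients, each loss divided by the number of steps.

    Matches the full-batch gradient only when all micro-batches have equal size.
    """
    chunks: List[Sequence[Sample]] = [
        data[i:i + micro_batch] for i in range(0, len(data), micro_batch)
    ]
    steps = len(chunks)
    return sum(mse_grad(weight, chunk) / steps for chunk in chunks)
```

For example, 32 samples split into 4 micro-batches of 8 give the same gradient as a single batch of 32, up to floating-point rounding.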
