
DDPM-IP

This is the codebase for the paper Input Perturbation Reduces Exposure Bias in Diffusion Models.

This repository is heavily based on openai/guided-diffusion, with the addition of input perturbation during training.

Coming soon (results on higher-resolution datasets):

  • FFHQ 128x128
  • FFHQ 256x256

Input Perturbation is simple to implement in diffusion models

Our proposed Input Perturbation is an extremely simple plug-in method for general diffusion models; the implementation takes just two lines of code.

For instance, based on guided-diffusion, the only code modification is in the script guided_diffusion/gaussian_diffusion.py, at lines 765-766:

new_noise = noise + gamma * th.randn_like(noise)  # gamma=0.15 for CIFAR10, gamma=0.1 for other datasets
x_t = self.q_sample(x_start, t, noise=new_noise)
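
For context, here is a minimal sketch (not the repo's exact code) of how these two lines sit inside a standard DDPM training step. The helper names are illustrative; note that the loss still regresses the original, unperturbed noise:

import torch as th

def training_step_with_input_perturbation(model, q_sample, x_start, t, gamma=0.1):
    noise = th.randn_like(x_start)                    # ground-truth epsilon target
    new_noise = noise + gamma * th.randn_like(noise)  # input perturbation (the two added lines)
    x_t = q_sample(x_start, t, noise=new_noise)       # the model sees the perturbed x_t
    pred = model(x_t, t)                              # epsilon prediction
    return ((pred - noise) ** 2).mean()               # target remains the unperturbed noise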

NOTE THAT: change the parameter GPUS_PER_NODE = 4 in the script dist_util.py according to your GPU cluster configuration.

Installation

The installation is the same as for guided-diffusion:

git clone https://github.com/forever208/DDPM-IP.git
cd DDPM-IP
pip install -e .

Download ADM-IP pre-trained models

We have released checkpoints for the main models in the paper.

Here are the download links for each model checkpoint:

Sampling from pre-trained ADM-IP models

To sample from these models unconditionally, use the image_sample.py script. Sampling from DDPM-IP is no different from sampling from openai/guided-diffusion, since DDPM-IP does not change the sampling process.

For example, we sample 50K images from CIFAR10 by:

mpirun python scripts/image_sample.py \
--image_size 32 --timestep_respacing 100 \
--model_path PATH_TO_CHECKPOINT \
--num_channels 128 --num_head_channels 32 --num_res_blocks 3 --attention_resolutions 16,8 \
--resblock_updown True --use_new_attention_order True --learn_sigma True --dropout 0.3 \
--diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True --batch_size 256 --num_samples 50000
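
The script writes the generated images to an .npz archive. As a quick sanity check you can load it as follows (the filename is illustrative; guided-diffusion names the file after the sample count and resolution):

import numpy as np

samples = np.load("samples_50000x32x32x3.npz")["arr_0"]  # uint8 images, shape (N, H, W, C)
print(samples.shape, samples.dtype)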

Results

This table summarizes our input perturbation results against ADM baselines. Input perturbation brings substantial training acceleration and markedly better FID scores.

FID computation details:

  • All FIDs are computed using 50K generated samples (unconditional sampling).
  • For CIFAR10 and ImageNet 32x32, we use the whole training set as the reference batch.
  • For LSUN tower 64x64 and CelebA 64x64, we randomly pick 50K samples from the training set to form the reference batch (a sketch of this selection is shown below).
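
A minimal sketch of that random selection, assuming the training set is stored as a single .npz of uint8 images (the filenames and array layout are assumptions, not the repo's evaluation code):

import numpy as np

rng = np.random.default_rng(0)
train = np.load("celeba64_train.npz")["arr_0"]    # assumed shape (N, 64, 64, 3), dtype uint8
idx = rng.choice(len(train), size=50_000, replace=False)
np.savez("celeba64_ref_50k.npz", train[idx])      # 50K-image reference batch for FID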

This table summarizes our input perturbation results based on DDIM baselines.

Prepare datasets

Please refer to README.md for data preparation.
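
Before launching a long run, it is worth checking that the prepared directory loads correctly. A minimal sketch, assuming guided-diffusion's load_data generator and its usual argument names:

from guided_diffusion.image_datasets import load_data

# load_data yields (batch, cond) pairs; images are scaled to [-1, 1], shape (B, C, H, W)
data = load_data(data_dir="PATH_TO_DATASET", batch_size=4, image_size=32, class_cond=False)
batch, cond = next(data)
print(batch.shape, batch.min().item(), batch.max().item())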

Training ADM-IP

Training diffusion models is described in this repository.

Training ADM-IP only requires one extra argument, --input_pertub 0.1 (set --input_pertub 0.0 for the baseline).
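
The argument value is the gamma in the snippet above. Numerically, the perturbation simply inflates the variance of the input noise from 1 to 1 + gamma^2, which this quick check (illustrative) confirms:

import torch as th

gamma = 0.1
noise = th.randn(1_000_000)
new_noise = noise + gamma * th.randn_like(noise)
print(new_noise.std().item())  # ~ sqrt(1 + gamma ** 2) ≈ 1.005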

NOTE THAT: if you have problems with Slurm multi-node training, try the following settings. Say you are training with 16 GPUs on 2 nodes:

#SBATCH --nodes=2
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:8 # 8 gpus for each node

Instead of specifying mpiexec -n 16, run mpirun python scripts/image_train.py (more discussion can be found here).

We share the complete arguments for training ADM-IP on the four datasets:

CIFAR10

mpiexec -n 2  python scripts/image_train.py --input_pertub 0.15 \
--data_dir PATH_TO_DATASET \
--image_size 32 --use_fp16 True --num_channels 128 --num_head_channels 32 --num_res_blocks 3 \
--attention_resolutions 16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.3 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 64

ImageNet 32x32 (you can also choose dropout=0.1)

mpiexec -n 4  python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 32 --use_fp16 True --num_channels 128 --num_head_channels 32 --num_res_blocks 3 \
--attention_resolutions 16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.3 --diffusion_steps 1000 --noise_schedule cosine \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 128

LSUN tower 64x64

mpiexec -n 16  python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 64 --use_fp16 True --num_channels 192 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 16

CelebA 64x64

mpiexec -n 16  python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 64 --use_fp16 True --num_channels 192 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 16

Citation

If you find our work useful, please feel free to cite it as:

@article{ning2023input,
  title={Input Perturbation Reduces Exposure Bias in Diffusion Models},
  author={Ning, Mang and Sangineto, Enver and Porrello, Angelo and Calderara, Simone and Cucchiara, Rita},
  journal={arXiv preprint arXiv:2301.11706},
  year={2023}
}
