P2 weighting (CVPR 2022)

This is the codebase for Perception Prioritized Training of Diffusion Models.

This repository is heavily based on openai/guided-diffusion.

P2 modifies the weighting scheme of the training objective to improve sample quality. It encourages the diffusion model to focus on recovering signals from highly corrupted data, where the model learns global and perceptually rich concepts. The figure below shows the weighting schemes as a function of the signal-to-noise ratio (SNR).

[Figure: snr_weight, weighting schemes as a function of SNR]

Pre-trained models

All models are trained at 256x256 resolution.

Here are the models trained on FFHQ, CelebA-HQ, CUB, AFHQ-Dogs, Flowers, and MetFaces: drive

Requirements

We tested with PyTorch 1.7.1 on a single RTX 8000 GPU.

Sampling from pre-trained models

First, set the PYTHONPATH variable to point to the root of the repository. Do the same when training new models.

export PYTHONPATH=$PYTHONPATH:$(pwd)

Put the model checkpoints in a folder named models/.

Samples will be saved in samples/.

python scripts/image_sample.py --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_res_blocks 1 --num_head_channels 64 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 250 --model_path models/ffhq_p2.pt --sample_dir samples

The command above samples for 250 timesteps without DDIM (--timestep_respacing 250). If your command uses --timestep_respacing ddim25 and --use_ddim True instead, replace them with --timestep_respacing 250 and --use_ddim False.
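To make the effect of --timestep_respacing 250 with --diffusion_steps 1000 concrete, here is a minimal sketch of choosing 250 roughly evenly spaced timesteps out of 1000. The function name evenly_spaced_steps is hypothetical and the rounding rule is an assumption for illustration; this is not the repository's implementation.

```python
def evenly_spaced_steps(num_timesteps, num_samples):
    """Pick num_samples roughly evenly spaced indices from range(num_timesteps).

    Illustrative only: the name and the rounding rule are assumptions,
    not code taken from the repository.
    """
    if not 1 <= num_samples <= num_timesteps:
        raise ValueError("num_samples must be between 1 and num_timesteps")
    if num_samples == 1:
        return [0]
    # Fractional stride so the first index is 0 and the last is num_timesteps - 1.
    stride = (num_timesteps - 1) / (num_samples - 1)
    return [round(i * stride) for i in range(num_samples)]


steps = evenly_spaced_steps(1000, 250)
print(len(steps), steps[0], steps[-1])
```

With these arguments, the sampler would visit one in roughly every four of the original timesteps, which is why sampling takes about a quarter of the network evaluations.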

Training your models

--p2_gamma and --p2_k are two hyperparameters of P2 weighting. We used --p2_gamma 0.5 --p2_k 1 and --p2_gamma 1 --p2_k 1 in the paper.
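As a rough illustration of what the two hyperparameters do, the sketch below evaluates the P2 factor 1 / (k + SNR(t))^gamma, which multiplies the baseline per-timestep loss weight, with SNR(t) = alpha_bar_t / (1 - alpha_bar_t). It assumes the standard linear schedule with betas from 1e-4 to 0.02 over 1000 steps; the helper names are hypothetical, and this is not the repository's training code.

```python
def linear_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear beta schedule (assumed values)."""
    alphas_cumprod = []
    running = 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod


def p2_weight(alpha_bar, gamma, k):
    """P2 factor 1 / (k + SNR)^gamma, where SNR = alpha_bar / (1 - alpha_bar)."""
    snr = alpha_bar / (1.0 - alpha_bar)
    return 1.0 / (k + snr) ** gamma


alphas_cumprod = linear_alphas_cumprod()
for gamma in (0.5, 1.0):
    weights = [p2_weight(a, gamma, k=1.0) for a in alphas_cumprod]
    # Low-noise steps (t near 0, high SNR) are down-weighted;
    # high-noise steps (t near the end, low SNR) keep a factor close to 1.
    print(f"gamma={gamma}: t=0 -> {weights[0]:.2e}, t=999 -> {weights[-1]:.4f}")
```

A larger gamma suppresses the high-SNR (lightly corrupted) timesteps more strongly, while k keeps the factor bounded as the SNR approaches zero.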

Logs and models will be saved in logs/. Set --data_dir to the path of your dataset.

We used a lightweight version of ADM (93M parameters, compared with over 500M for the original) as the default model configuration. You may modify the model configuration.

python scripts/image_train.py --data_dir data/DATASET_NAME --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 1 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --lr 2e-5 --batch_size 8 --rescale_learned_sigmas True --p2_gamma 1 --p2_k 1 --log_dir logs
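If you want to launch both hyperparameter settings used in the paper, a small helper can assemble the command for each run. The sketch below only builds and prints the commands; the helper name build_train_command and the per-run log directory names are assumptions, while the flags and values are copied from the command above.

```python
import shlex

# Flags shared by every run, copied from the training command above.
BASE_FLAGS = {
    "attention_resolutions": "16",
    "class_cond": "False",
    "diffusion_steps": "1000",
    "dropout": "0.0",
    "image_size": "256",
    "learn_sigma": "True",
    "noise_schedule": "linear",
    "num_channels": "128",
    "num_head_channels": "64",
    "num_res_blocks": "1",
    "resblock_updown": "True",
    "use_fp16": "False",
    "use_scale_shift_norm": "True",
    "lr": "2e-5",
    "batch_size": "8",
    "rescale_learned_sigmas": "True",
}


def build_train_command(data_dir, p2_gamma, p2_k, log_dir="logs"):
    """Return the argument list for scripts/image_train.py (hypothetical helper)."""
    flags = dict(BASE_FLAGS, data_dir=data_dir, p2_gamma=p2_gamma,
                 p2_k=p2_k, log_dir=log_dir)
    args = ["python", "scripts/image_train.py"]
    for name, value in flags.items():
        args += [f"--{name}", str(value)]
    return args


# The two settings reported in the paper; log directory names are an assumption.
for gamma in (0.5, 1):
    cmd = build_train_command("data/DATASET_NAME", gamma, 1,
                              log_dir=f"logs/p2_gamma{gamma}_k1")
    print(shlex.join(cmd))
```

To actually start a run, the returned list can be passed to subprocess.run(cmd, check=True) from the repository root with PYTHONPATH set as described earlier.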

Contributors

jychoi118
