Discovery and Expansion of New Domains within Diffusion Models

This is the official PyTorch implementation of the paper Discovery and Expansion of New Domains within Diffusion Models.

Ye Zhu, Yu Wu, Duo Xu, Zhiwei Deng, Yan Yan, Olga Russakovsky

Paper

Updates:

  • (06/2024) This codebase is still under construction and will be further structured and improved to facilitate its use.

1. Project Overview and Takeaways

In this work, we explore the domain generalization ability of diffusion generative models in the few-shot scenario and introduce a novel paradigm, which we refer to as the latent discovery-based method, to achieve domain expansion. Unlike tuning-based methods, which seek to change the generative mapping trajectories, we propose to find additional qualified latent encodings given prior information from a small set of target out-of-distribution (OOD) samples.

The design philosophy of this tuning-free paradigm has been adopted in several of my previous projects for versatile DM-based applications, including BoundaryDiffusion (NeurIPS 2023) and COW (ICLR 2024). However, those two projects still mainly focus on working within the in-domain space of pre-trained diffusion models; in this work, we seek to step outside the original training space of DMs.

2. Setup

2.1 Environment

The environment setup follows my previous diffusion-based projects; an example of setting up the experimental environment is given below.

conda create --name diff python=3.9
conda activate diff
# Install PyTorch and the CUDA toolkit versions that match your own machine.
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
git clone https://github.com/L-YeZhu/DiscoveryDiff.git
cd DiscoveryDiff
pip install -r requirements.txt
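After installation, a quick sanity check can confirm that the key packages are importable and whether PyTorch sees a GPU. The snippet below is a minimal sketch and not part of the repository; the package list is an assumption based on the conda command above (requirements.txt may pin more packages), and the helper names are hypothetical.

```python
import importlib
import importlib.util


def check_packages(names):
    """Return {package_name: True/False} depending on whether it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


def report_torch():
    """Print the PyTorch version and CUDA availability, if PyTorch is installed."""
    if importlib.util.find_spec("torch") is None:
        print("torch: not installed")
        return
    torch = importlib.import_module("torch")
    print("torch:", torch.__version__)
    print("CUDA build:", torch.version.cuda)
    print("CUDA available:", torch.cuda.is_available())


if __name__ == "__main__":
    for pkg, found in check_packages(["torch", "torchvision", "torchaudio"]).items():
        print(f"{pkg}: {'found' if found else 'missing'}")
    report_torch()
```

If CUDA is reported as unavailable on a GPU machine, the pytorch-cuda version in the conda command likely does not match the installed driver.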

2.2 Datasets

The dataset paths can be modified in the ./configs/paths_configs.py file.

  • Natural images:

I have tested several natural image domains, including AFHQ-Dog, CelebA-HQ, and LSUN, as demonstrations. These are commonly used, open-access image datasets in computer vision; feel free to download them from their corresponding sources and experiment with them.

  • Astrophysical images:

In addition to natural images, I have tested several astrophysical datasets as unseen target domains in this project, including Galaxy Zoo and the radiation field of molecular clouds (from the paper referenced below). These datasets feature larger domain gaps with respect to the training domain of a given pre-trained DDPM, and they also yield better performance under our proposed method.

Among the tested astrophysical datasets, Galaxy Zoo is an open-access dataset that contains real galaxy images and annotations. The radiation field data were used in Dr. Duo Xu's previously published paper, Predicting the Radiation Field of Molecular Clouds Using Denoising Diffusion Probabilistic Models, in The Astrophysical Journal. If you are interested in further research with this astrophysical dataset, we recommend contacting me and/or Dr. Duo Xu for more details.

Note that the original radiation field data are not images but physical quantities; in this project, we process and visualize them as images for illustration purposes, for researchers outside the astrophysics field. If you are interested in how these radiation data can be interpreted in RGB, a short section in the appendices of the paper provides a brief clarification.
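As a generic illustration of this kind of preprocessing, the sketch below maps a 2D array of non-negative physical values to an 8-bit, 3-channel image using a log transform and min-max scaling. This mapping is an assumption for illustration only; the mapping actually used in the paper is the one described in its appendix, and the function name `field_to_rgb` is hypothetical.

```python
import numpy as np


def field_to_rgb(field, eps=1e-12):
    """Map a 2D array of non-negative physical values to an (H, W, 3) uint8 image.

    Hypothetical preprocessing: log10 compresses the large dynamic range typical of
    physical fields, then min-max scaling maps the result to [0, 255].
    """
    field = np.asarray(field, dtype=np.float64)
    logged = np.log10(np.clip(field, eps, None))  # clip to avoid log10(0)
    lo, hi = logged.min(), logged.max()
    scaled = (logged - lo) / (hi - lo) if hi > lo else np.zeros_like(logged)
    gray = np.round(scaled * 255.0).astype(np.uint8)
    return np.stack([gray, gray, gray], axis=-1)  # replicate into 3 channels


rng = np.random.default_rng(0)
demo = field_to_rgb(10.0 ** rng.uniform(-3, 3, size=(64, 64)))
print(demo.shape, demo.dtype, demo.min(), demo.max())
```

Replicating a single channel keeps the physical values monotonic in brightness; a colormap would be an alternative if color contrast matters more than interpretability.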

2.3 Base diffusion models

I used several different pre-trained diffusion generative models as base models for the demonstration experiments; different base models, along with their training datasets, may yield diverse effects. The paths to the pre-trained models can also be modified in the ./configs/paths_configs.py file.
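The layout of ./configs/paths_configs.py is not documented here, so the sketch below is purely hypothetical: the dictionary names, keys, and paths are placeholders that illustrate collecting dataset and checkpoint paths in one place and checking them before a run.

```python
import os

# Hypothetical layout: these names and paths are placeholders, not the repo's actual config.
DATASET_PATHS = {
    "afhq_dog": "/data/afhq/train/dog",
    "galaxy_zoo": "/data/galaxy_zoo/images",
}
MODEL_PATHS = {
    "afhq_dog_256": "/checkpoints/afhq_dog_256.pt",
}


def missing_paths(path_dict):
    """Return the keys whose configured path does not exist on this machine."""
    return [key for key, path in path_dict.items() if not os.path.exists(path)]


if __name__ == "__main__":
    for name, table in [("datasets", DATASET_PATHS), ("models", MODEL_PATHS)]:
        missing = missing_paths(table)
        print(f"{name}: {'all found' if not missing else 'missing ' + ', '.join(missing)}")
```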

Generic and unconditional DMs

You can find these model checkpoints in existing open-source releases, depending on the resolution and DDPM variant you want to experiment with. The choice of base model should be consistent with your experimental datasets in terms of resolution. For ease of use, I include links to my own model collection below (I did not train these base models; credit goes to previous researchers): iDDPM trained on AFHQ-Dog at 256×256 resolution, DDPM trained on CelebA at 256×256, DDPM trained on LSUN-Church at 256×256, and DDPM trained on LSUN-Bedroom at 256×256.
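Because the base model's resolution must match the data, target images usually need to be cropped, resized, and rescaled to the model's input range before inversion. The sketch below assumes a 256×256 model and the [-1, 1] pixel range conventionally used by DDPM-style models; it uses a center crop with nearest-neighbor resizing to stay dependency-free (an antialiased library resize would be preferable in practice), and `to_model_input` is a hypothetical name.

```python
import numpy as np


def to_model_input(img, size=256):
    """Center-crop an (H, W, 3) uint8 image to a square, resize it to size x size
    with nearest-neighbor sampling, and return a (3, size, size) float32 array in [-1, 1]."""
    img = np.asarray(img)
    h, w = img.shape[:2]
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    square = img[top:top + side, left:left + side]
    # Nearest-neighbor index map from the output grid back to the cropped image.
    idx = (np.arange(size) * side) // size
    resized = square[idx][:, idx]
    x = resized.astype(np.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    return np.transpose(x, (2, 0, 1))  # HWC -> CHW, the layout PyTorch models expect


raw = np.random.default_rng(0).integers(0, 256, size=(300, 400, 3)).astype(np.uint8)
demo = to_model_input(raw)
print(demo.shape, demo.min() >= -1.0, demo.max() <= 1.0)
```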

Large T2I models

I also tested two large T2I models, i.e., StableDiffusion_v2.1 and GPT4, using the same text prompts below for the astrophysical tasks above. We observe that existing T2I models are not (yet) appropriate for generating such specific astrophysical data, given that they are trained to model overall data distributions.

1. Generate a realistic image of galaxy.
2. Generate a realistic image about the radiation field of molecular clouds in astronomy.

3. Analytical Experiments on Representation Ability

Our work is based on a key observation: a DM trained even on a small, single-domain dataset already has sufficient representation ability to accurately reconstruct arbitrary unseen images from their inverted latent encodings by following a relatively deterministic denoising trajectory. This raises the question of how we can leverage this powerful representation ability for novel, creative, and useful applications.

3.1 Unseen reconstruction

Given a diffusion model pre-trained on a single-domain dataset (e.g., dog faces from AFHQ-Dog-256), we aim to show that the model can reconstruct an arbitrary image with deterministic inversion and denoising processes [1].
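For reference, with eta = 0 the DDIM update of [1] is deterministic. Writing alpha-bar_t for the cumulative noise schedule and epsilon_theta for the trained noise predictor, a denoising step from t to an earlier step s < t and the corresponding inversion step are written out below. Inversion reuses the same formula in the opposite direction but must evaluate epsilon_theta at the current, less noisy point, which is an approximation that becomes accurate as the step size shrinks.

```latex
\begin{aligned}
\hat{x}_0(x_t) &= \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}, \\
x_s &= \sqrt{\bar{\alpha}_s}\,\hat{x}_0(x_t) + \sqrt{1-\bar{\alpha}_s}\,\epsilon_\theta(x_t, t) \quad (\text{denoising},\ s < t), \\
x_t &\approx \sqrt{\bar{\alpha}_t}\,\frac{x_s - \sqrt{1-\bar{\alpha}_s}\,\epsilon_\theta(x_s, s)}{\sqrt{\bar{\alpha}_s}} + \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_s, s) \quad (\text{inversion}).
\end{aligned}
```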

To test the reconstruction ability, you can use the following command:

python main.py --config {DATASET}.yml --unseen_reconstruct --exp ./runs/ --n_inv_step 80 --t_0 900 --eta 0.0
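To illustrate why this round trip is nearly lossless, the NumPy sketch below implements deterministic DDIM inversion followed by reconstruction, with a toy linear noise predictor standing in for the trained network. The linear beta schedule, the evenly spaced timestep grid, and all function names are assumptions for illustration; they only loosely mirror the `--n_inv_step 80 --t_0 900 --eta 0.0` settings and are not the repository's code.

```python
import numpy as np


def make_alphas_cumprod(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) for a linear beta schedule (assumed here)."""
    betas = np.linspace(beta_start, beta_end, num_timesteps, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def ddim_step(x, eps, a_from, a_to):
    """Deterministic (eta = 0) DDIM move from noise level a_from to a_to.

    Denoising when a_to > a_from, inversion when a_to < a_from."""
    x0_pred = (x - np.sqrt(1.0 - a_from) * eps) / np.sqrt(a_from)
    return np.sqrt(a_to) * x0_pred + np.sqrt(1.0 - a_to) * eps


def timestep_grid(t_0, n_steps):
    """Evenly spaced integer timesteps 0 = t[0] < ... < t[n_steps] = t_0."""
    return np.linspace(0, t_0, n_steps + 1).astype(int)


def invert(x0, eps_model, alphas_cumprod, t_0=900, n_steps=80):
    """Map a clean image to the t_0 latent by running DDIM steps forward in time."""
    x = x0
    ts = timestep_grid(t_0, n_steps)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        eps = eps_model(x, t_cur)
        x = ddim_step(x, eps, alphas_cumprod[t_cur], alphas_cumprod[t_next])
    return x


def reconstruct(x_t0, eps_model, alphas_cumprod, t_0=900, n_steps=80):
    """Denoise the t_0 latent back to an image by running the same grid in reverse."""
    x = x_t0
    ts = timestep_grid(t_0, n_steps)[::-1]
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        eps = eps_model(x, t_cur)
        x = ddim_step(x, eps, alphas_cumprod[t_cur], alphas_cumprod[t_next])
    return x


def toy_eps_model(x, t):
    """Toy stand-in for a trained noise predictor eps_theta(x, t); purely illustrative."""
    return 0.05 * x


alphas_cumprod = make_alphas_cumprod()
x0 = np.random.default_rng(0).uniform(-1.0, 1.0, size=(3, 64, 64))
x_t0 = invert(x0, toy_eps_model, alphas_cumprod)
x_rec = reconstruct(x_t0, toy_eps_model, alphas_cumprod)
rel_err = np.linalg.norm(x_rec - x0) / np.linalg.norm(x0)
print(f"relative reconstruction error: {rel_err:.4f}")
```

With a real network, the reconstruction error depends on how smoothly epsilon_theta varies along the trajectory and on the number of steps; the toy model only shows the mechanics.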

3.2 Inverted unseen priors

For the inversion part, the --inversion option converts the raw OOD data into the t_0 latent space.

python main.py --config {DATASET}.yml --inversion --exp ./runs/ --n_inv_step 80 --t_0 900 --eta 0.0
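Once the OOD samples are inverted, one quick diagnostic is to compare the inverted latents with samples from the standard Gaussian prior: in high dimensions, samples of N(0, I) have a per-element mean near 0, a standard deviation near 1, and a norm near sqrt(d). The helper below is a hypothetical sketch, not part of the repository, that summarizes a batch of latents in these terms.

```python
import numpy as np


def latent_stats(latents):
    """Summarize a batch of latents with shape (N, ...) against N(0, I) expectations."""
    flat = np.asarray(latents, dtype=np.float64).reshape(len(latents), -1)
    d = flat.shape[1]
    norms = np.linalg.norm(flat, axis=1)
    return {
        "mean": float(flat.mean()),
        "std": float(flat.std()),
        # For N(0, I) in d dimensions, this ratio concentrates around 1.
        "norm_over_sqrt_d": float(norms.mean() / np.sqrt(d)),
    }


gaussian = np.random.default_rng(0).standard_normal((8, 3, 64, 64))
print(latent_stats(gaussian))
```

Running the same helper on inverted OOD latents shows how far they sit from these reference values.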

3.3 Baseline of vanilla fine-tuning

We mainly compare against vanilla fine-tuning as the baseline in this project. In fact, we found that vanilla tuning with only image supervision is extremely hard for pre-trained DDPMs and has almost no impact on the synthesized outputs, even for close domain transitions. You can test vanilla fine-tuning using the command below:

python main.py --config {DATASET}.yml --finetune --exp ./runs/ 
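For context, vanilla fine-tuning of a DDPM typically continues training with the standard noise-prediction objective: noise a clean image to a random timestep and regress the added noise. The NumPy sketch below only computes that loss for one batch (no gradients or optimizer). Whether the `--finetune` option uses exactly this objective is an assumption, and `ddpm_loss` is a hypothetical name.

```python
import numpy as np


def ddpm_loss(x0, eps_model, alphas_cumprod, rng):
    """Simple DDPM objective, mean of (eps - eps_model(x_t, t))^2, for one batch."""
    n = x0.shape[0]
    t = rng.integers(0, len(alphas_cumprod), size=n)          # one timestep per sample
    a = alphas_cumprod[t].reshape(n, *([1] * (x0.ndim - 1)))  # broadcast over C, H, W
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps            # forward (noising) process
    return float(np.mean((eps - eps_model(x_t, t)) ** 2))


rng = np.random.default_rng(0)
alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
x0 = rng.uniform(-1.0, 1.0, size=(16, 3, 32, 32))
# A predictor that always outputs zero has a loss close to E[eps^2] = 1.
print(ddpm_loss(x0, lambda x, t: np.zeros_like(x), alphas_cumprod, rng))
```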

4. Latent Sampling Methods

Latent sampling is the key technical challenge of this project, due to the mode interference issue described in the paper. Latent sampling can be performed with various techniques; below I list one of them, which uses latent directions and some geometric optimizations.

python main.py --config {DATASET}.yml --unseen_sample --exp ./runs/ --n_inv_step 80 --t_0 900 --eta 0.0
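The sampling procedure behind `--unseen_sample` is described in the paper. As a much simpler, hypothetical illustration of geometric latent sampling, the sketch below draws new candidate latents by spherical linear interpolation (slerp) between random pairs of inverted OOD latents, which keeps candidates near the spherical shell where Gaussian latents concentrate. This is a swapped-in baseline technique, not the authors' method, and all names are hypothetical.

```python
import numpy as np


def slerp(z0, z1, alpha):
    """Spherical linear interpolation between two latents of the same shape."""
    a, b = z0.ravel(), z1.ravel()
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    theta = np.arccos(np.clip(cos, -1.0, 1.0))
    if theta < 1e-6:  # nearly parallel: fall back to linear interpolation
        return (1.0 - alpha) * z0 + alpha * z1
    return (np.sin((1.0 - alpha) * theta) * z0 + np.sin(alpha * theta) * z1) / np.sin(theta)


def sample_between_priors(priors, n_samples, rng):
    """Draw new latents by slerping between random pairs of inverted priors."""
    out = []
    for _ in range(n_samples):
        i, j = rng.choice(len(priors), size=2, replace=False)
        out.append(slerp(priors[i], priors[j], rng.uniform(0.0, 1.0)))
    return np.stack(out)


rng = np.random.default_rng(0)
priors = rng.standard_normal((5, 3, 64, 64))  # stand-ins for inverted OOD latents
samples = sample_between_priors(priors, n_samples=4, rng=rng)
print(samples.shape)  # (4, 3, 64, 64)
```

Each sampled latent would then be denoised with the same deterministic trajectory used for reconstruction.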

5. Evaluations and Applications

Evaluation should generally follow the protocols of the corresponding end tasks. In other words, if your objective is to generate new natural images, you can evaluate the results with commonly used scores such as FID. However, if your end task is more specific, as in our showcase of astrophysical data simulation, you may want to follow the evaluation practices of the target application area.
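FID compares Gaussian fits of Inception features from real and generated images via the Fréchet distance ||mu_1 - mu_2||^2 + Tr(Sigma_1 + Sigma_2 - 2 (Sigma_1 Sigma_2)^(1/2)). The NumPy sketch below computes only this distance from given feature arrays; Inception feature extraction is assumed to happen elsewhere, and an established FID implementation should be preferred for reported numbers. It relies on the identity Tr((Sigma_1 Sigma_2)^(1/2)) = Tr((Sigma_1^(1/2) Sigma_2 Sigma_1^(1/2))^(1/2)) so that only symmetric eigendecompositions are needed.

```python
import numpy as np


def _sqrtm_psd(m):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(m)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    s1_half = _sqrtm_psd(sigma1)
    middle = s1_half @ sigma2 @ s1_half
    middle = 0.5 * (middle + middle.T)  # symmetrize against round-off
    # Tr((S1 S2)^{1/2}) equals the sum of square-rooted eigenvalues of S1^{1/2} S2 S1^{1/2}.
    tr_covmean = np.sum(np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)))
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def fid_from_features(feats_real, feats_fake):
    """Fit Gaussians to (N, D) feature arrays and return their Frechet distance."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_fake, rowvar=False)
    return frechet_distance(mu1, s1, mu2, s2)
```

For example, N(0, I) and N(1, 4I) in three dimensions are at distance 3 + 3 + 12 - 2 * 6 = 6.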

6. Citation

If you find this work useful and interesting, please consider citing the paper.

@article{zhu2024discovery,
  title = {Discovery and Expansion of New Domains within Diffusion Models},
  author = {Zhu, Ye and Wu, Yu and Xu, Duo and Deng, Zhiwei and Russakovsky, Olga and Yan, Yan},
  journal = {arXiv preprint arXiv:2310.09213},
  year = {2024},
}

References

[1] Song, Jiaming, Chenlin Meng, and Stefano Ermon. “Denoising diffusion implicit models.” In ICLR 2021.
