
DisDiff: Unsupervised Disentanglement of Diffusion Probabilistic Models

This is the official PyTorch implementation of the NeurIPS 2023 paper DisDiff: Unsupervised Disentanglement of Diffusion Probabilistic Models (arXiv | OpenReview).

DisDiff: Unsupervised Disentanglement of Diffusion Probabilistic Models
Tao Yang, Yuwang Wang, Yan Lu, Nanning Zheng
arXiv preprint arXiv:2102.10303
NeurIPS 2023

We connect disentangled representation learning to Diffusion Probabilistic Models (DPMs) to take advantage of the remarkable modeling ability of DPMs. We propose a new task, disentanglement of DPMs: given a pre-trained DPM, without any annotations of the factors, the task is to automatically discover the inherent factors behind the observations and disentangle the gradient fields of the DPM into sub-gradient fields, as shown in Figure 1.

Figure 1: Illustration of disentanglement of DPMs. (a) is the diagram of the Probabilistic Graphical Model (PGM). (b) is the diagram of the image space. (c) shows samples of the images in (b).
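In symbols (our shorthand for the task statement above, not necessarily the paper's exact notation): an encoder maps an observation to factor representations, and the pre-trained gradient field is split into per-factor conditional sub-gradient fields,

x_0 \xrightarrow{E} \{z^1, \dots, z^N\}, \qquad \nabla_{x_t} \log p(x_t) \;\longrightarrow\; \{\nabla_{x_t} \log p(x_t \mid z^c)\}_{c=1}^{N},

where each z^c is intended to capture a single underlying factor of variation.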

To tackle this task, we devise an unsupervised approach as shown in Figure 2, named DisDiff, achieving disentangled representation learning in the framework of DPMs.

Figure 2: Illustration of DisDiff. (a) Image x0 is first encoded into representations {z1, z2, ..., zN} by the encoder. (b) We sample a factor c and decode the representation zc to obtain the predicted x0. At the same time, we obtain the predicted x0 of the original pre-trained DPM. We then calculate the disentangling loss based on the two predictions.

Requirements

A suitable conda environment named disdiff can be created and activated with:

conda env create -f environment.yaml
conda activate disdiff
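After activation, a quick way to confirm the core dependencies are importable (a minimal sketch; the exact package set is defined by environment.yaml, and pytorch_lightning is an assumption based on the LDM-style main.py workflow):

# Sanity check of the freshly created environment.
import torch
import pytorch_lightning as pl   # assumed to be part of the LDM-style environment
from omegaconf import OmegaConf  # used for config handling (see Usage below)

print("torch:", torch.__version__)
print("pytorch_lightning:", pl.__version__)
print("CUDA available:", torch.cuda.is_available())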

Usage

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Config file(s) to build the model from, plus optional "key=value" dot-list overrides.
base_configs = ["configs/latent-diffusion/shapes3d-vq-4-16-dis.yaml"]
overrides = []

cli = OmegaConf.from_dotlist(overrides)
configs = [OmegaConf.load(cfg) for cfg in base_configs]
config = OmegaConf.merge(*configs, cli)

# Instantiate the model described by the merged config.
model = instantiate_from_config(config.model)
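To use one of the pretrained checkpoints with the instantiated model, the weights can then be loaded in the usual PyTorch way (a minimal sketch; the path is a placeholder, and the "state_dict" key is an assumption based on Lightning-style checkpoints):

import torch

ckpt_path = "/my/ldm_checkpoint/path"  # placeholder: a downloaded or self-trained checkpoint
state = torch.load(ckpt_path, map_location="cpu")
# Lightning-style checkpoints usually keep the weights under "state_dict"; fall back to the raw dict otherwise.
model.load_state_dict(state.get("state_dict", state), strict=False)
model.eval()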

Datasets

Download datasets for training DisDiff (Shapes3D as an example)

Manually set "/path/to/your/datasets/" in the dataset classes in ldm/data/dis.py. For Shapes3D, for example, the dataset path is set as follows:

class Shapes3DTrain(Shapes3D):
    def __init__(self, **kwargs):
        super().__init__(path='/path/to/your/datasets/',
                original_resolution=None,
                **kwargs)

Dataset path structure:

/path/to/your/datasets/
├──shapes3d/3dshapes.h5
├──mpi_toy/mpi3d_toy.npz
├──...
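A quick check that the expected files are in place under the chosen root (a minimal sketch; the root must match the path set in ldm/data/dis.py):

from pathlib import Path

root = Path("/path/to/your/datasets/")  # same root as in ldm/data/dis.py
for rel in ["shapes3d/3dshapes.h5", "mpi_toy/mpi3d_toy.npz"]:
    print(rel, "found" if (root / rel).exists() else "missing")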

Training of DisDiff

As in LDM, the experiments are run through the main.py script; the models are trained with this script using different configs. Since DisDiff is built on an LDM, we first need to train the LDM in the following two stages.

Train VQVAE/VAE for LDM

The main.py script is also used to train the VQVAE/VAE; we choose a config file for it, for example configs/autoencoder/{dataset}_vq_4_16.yaml. A VQVAE/VAE can be trained with the following command:

python main.py --base configs/autoencoder/{dataset}_vq_4_16.yaml -t --gpus 0,

Here {dataset} can be one of {cars3d, shapes3d, mpi3d, celeba}. As an example, choosing shapes3d gives:

python main.py --base configs/autoencoder/shapes3d-vq-4-16.yaml -t --gpus 0,

Train LDM

main.py is also the script for training an LDM on top of the pretrained VQVAE/VAE. After the VQVAE/VAE is trained, set the checkpoint path of the pretrained model as "ckpt_path: /my/vae_checkpoint/path" in the corresponding config file, e.g., configs/latent-diffusion/shapes3d-vq-4-16.yaml. An LDM for the pretrained VQVAE/VAE can then be trained with the following command:

python main.py --base configs/latent-diffusion/{dataset}_vq_4_16.yaml -t --gpus 0,

Train DisDiff

main.py is also the script for training DisDiff on a pretrained LDM. After the LDM is trained, set the checkpoint paths of the pretrained models, e.g., "ckpt_path: /my/ldm_checkpoint/path", in the corresponding config file, e.g., configs/latent-diffusion/shapes3d-vq-4-16-dis.yaml.

Note that we need to set the checkpoint paths for two pretrained models: the first is "ckpt_path" in "first_stage_config"; the second is "ckpt_path" in "unet_config". DisDiff can then be trained on the pretrained LDM with the following command:

python main.py --base configs/latent-diffusion/{dataset}-vq-4-16-dis.yaml -t --gpus 0, -dt Z -dw 0.05 -l exp_vq_shapes3d -n {exp-name} -s 0

where {exp-name} is the experiment name for this setting. An example name is "s0", meaning the random seed is 0.
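Instead of editing the YAML by hand, the two checkpoint paths can also be set programmatically with OmegaConf (a sketch only; the nested key paths under model.params are assumptions based on the usual LDM-style config layout):

from omegaconf import OmegaConf

cfg_file = "configs/latent-diffusion/shapes3d-vq-4-16-dis.yaml"
config = OmegaConf.load(cfg_file)
# Assumed locations of the two "ckpt_path" fields mentioned above.
config.model.params.first_stage_config.params.ckpt_path = "/my/vae_checkpoint/path"
config.model.params.unet_config.params.ckpt_path = "/my/ldm_checkpoint/path"
OmegaConf.save(config, cfg_file)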

Evaluation of DisDiff

The run_para_metrics.py script evaluates DisDiff in parallel. We only need to set the experiment path with -l; the script loads the configurations and checkpoints automatically. Note that -p sets the number of processes.

python run_para_metrics.py -l exp_vq_shapes3d -p 10

Pretrained Models

Pretrained Autoencoder Download

Pretrained LDM Download

Pretrained DisDiff Download

Acknowledgement

Note that this project is built upon LDM and PDAE. The evaluation code is built upon disentanglement_lib.

Citation

@article{Tao2023DisDiff,
  title={DisDiff: Unsupervised Disentanglement of Diffusion Probabilistic Models},
  author={Yang, Tao and Wang, Yuwang and Lu, Yan and Zheng, Nanning},
  journal={NeurIPS},
  year={2023}
}
