
ddib's Introduction

Dual Diffusion Implicit Bridges (ICLR 2023)

Dual Diffusion Implicit Bridges for Image-to-Image Translation
Xuan Su, Jiaming Song, Chenlin Meng, Stefano Ermon
ICLR '23 | GitHub | arXiv | Colab | Project Page

Overview

Common image-to-image translation methods rely on joint training over data from both source and target domains. The training process requires concurrent access to both datasets, which hinders data separation and privacy protection; and existing models cannot be easily adapted to translation of new domain pairs. We present Dual Diffusion Implicit Bridges (DDIBs), an image translation method based on diffusion models, that circumvents training on domain pairs. Image translation with DDIBs relies on two diffusion models trained independently on each domain, and is a two-step process: DDIBs first obtain latent encodings for source images with the source diffusion model, and then decode such encodings using the target model to construct target images. Both steps are defined via ordinary differential equations (ODEs), thus the process is cycle consistent only up to discretization errors of the ODE solvers. Theoretically, we interpret DDIBs as concatenation of source to latent, and latent to target Schrödinger Bridges, a form of entropy-regularized optimal transport, to explain the efficacy of the method. Experimentally, we apply DDIBs on synthetic and high-resolution image datasets, to demonstrate their utility in a wide variety of translation tasks and their inherent optimal transport properties.
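
Concretely, translation is two ODE solves chained through a shared latent space. Below is a minimal conceptual sketch in the style of this repository's guided_diffusion API (ddim_reverse_sample_loop and ddim_sample_loop exist in the codebase, but the exact arguments here are illustrative, not the repository's script):

def ddib_translate(diffusion, source_model, target_model, x_source):
    # Step 1: encode source images into latents by running the
    # probability-flow ODE (reverse DDIM) under the source model.
    latent = diffusion.ddim_reverse_sample_loop(
        source_model, x_source, clip_denoised=False
    )
    # Step 2: decode the latents with the target model by running the
    # ODE in the generative direction (DDIM with eta=0 is deterministic).
    x_target = diffusion.ddim_sample_loop(
        target_model, x_source.shape, noise=latent, clip_denoised=True, eta=0.0
    )
    return x_target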

Installation

Installation follows the same procedures as in the guided-diffusion and improved-diffusion repositories (see References and Acknowledgements below).

We first install the current repository, and then other libraries such as numpy and matplotlib. A known-good installation uses Python 3.9 with the following version numbers:

pip install -e .
pip install numpy==1.24.0 matplotlib==3.6.2 scikit-image==0.19.3 scikit-learn==1.2.0 gdown==4.6.0
conda install -c conda-forge mpi4py openmpi

Synthetic Models

We release pretrained checkpoints for the 2D synthetic models in the paper.

Installation

Downloading via script: In your repository, run python download.py --exp synthetic to download the pretrained synthetic models. The script will create a directory models/synthetic and automatically download the checkpoints to the directory.

Downloading manually: As an alternative, you can also download the checkpoint manually. Here is the download link for the model checkpoints: Synthetic Models

Indexes. We use indexes 0-5 to refer to the 6 synthetic types, in: [Moons, Checkerboards, Concentric Rings, Concentric Squares, Parallel Rings, Parallel Squares].
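
For reference, the same mapping written as a Python dictionary (the constant name is ours, not the repository's):

SYNTHETIC_TASKS = {
    0: "Moons",
    1: "Checkerboards",
    2: "Concentric Rings",
    3: "Concentric Squares",
    4: "Parallel Rings",
    5: "Parallel Squares",
}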

How are the datasets generated? The key file to look at is: guided_diffusion/synthetic_datasets.py. We implement the data generation and sampling processes for various 2D modalities.
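
As an illustration of what such a sampler looks like, here is a rough numpy sketch of a Moons-style generator; it is not the exact code in synthetic_datasets.py:

import numpy as np

def sample_moons(n, noise=0.05, seed=None):
    # Two interleaving half-circles, perturbed with Gaussian noise.
    rng = np.random.default_rng(seed)
    theta = rng.uniform(0.0, np.pi, size=n)
    upper = rng.random(n) < 0.5
    x = np.where(upper, np.cos(theta), 1.0 - np.cos(theta))
    y = np.where(upper, np.sin(theta), 0.5 - np.sin(theta))
    points = np.stack([x, y], axis=1)
    return points + noise * rng.standard_normal(points.shape)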

After running the download script, we can run the cycle consistency, synthetic translation and sampling experiments below.

Training Synthetic Models

python scripts/synthetic_train.py --num_res_blocks 3 --diffusion_steps 4000 --noise_schedule linear --lr 1e-4 --batch_size 20000 --task 0

--task is an integer in {0, 1, 2, 3, 4, 5} corresponding to one of the synthetic dataset types listed above.

Training each model typically takes only a few (3-4) hours on a single GPU.

The models and logs are saved to the directory specified by OPENAI_LOGDIR. If you want to save the model files to a folder of your choice, set the variable via export OPENAI_LOGDIR=...
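
For example (the log directory below is just a placeholder):

export OPENAI_LOGDIR=/path/to/logdir
python scripts/synthetic_train.py --num_res_blocks 3 --diffusion_steps 4000 --noise_schedule linear --lr 1e-4 --batch_size 20000 --task 0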

Cycle Consistency

python scripts/synthetic_cycle.py --num_res_blocks 3 --diffusion_steps 4000 --batch_size 30000 --source 0 --target 1

The above command runs the cycle-consistent translation experiment in the paper: between datasets Moons (0) and Checkerboards (1). The generated experiment plots are saved under the new directory experiments/images.
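
Conceptually, the experiment translates source points to the target domain and back, then measures how far the round trip drifts from the original points. Reusing the ddib_translate sketch from the Overview (again a rough sketch under assumed APIs, not the repository's script):

import torch

def cycle_consistency_error(diffusion, source_model, target_model, x_source):
    # Source -> target -> source round trip; with exact ODE solutions this
    # would recover x_source, so the error reflects solver discretization.
    x_target = ddib_translate(diffusion, source_model, target_model, x_source)
    x_cycle = ddib_translate(diffusion, target_model, source_model, x_target)
    return torch.linalg.norm(x_cycle - x_source, dim=-1).mean()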

Synthetic Translation

python scripts/synthetic_translation.py --num_res_blocks 3 --diffusion_steps 4000 --batch_size 30000 --source 0 --target 3

The above command performs translation between the two synthetic domains and saves the resulting plots to experiments/images.

Sample from Synthetic Models

python scripts/synthetic_sample.py --num_res_blocks 3 --diffusion_steps 4000 --batch_size 20000 --num_samples 80000 --task 1

Miscellaneous

ImageNet Translation

Installation

Download the model weights. Similarly, run python download.py --exp imagenet to download the pretrained, class-conditional ImageNet models from guided-diffusion. The script will create a directory models/imagenet and put the classifier and diffusion model weights there.

Copy the validation dataset. We use the ImageNet validation set for domain translation. Two steps:

  • Download the ImageNet validation set from ILSVRC2012. Unzip it. This will create a folder containing image files named like "ILSVRC2012_val_000XXXXX.JPG".
  • Remember the path to the folder containing the validation set as val_dir, as we'll need it for the next command.

We are now ready to translate the images.

Translation between ImageNet Classes

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python scripts/imagenet_translation.py $MODEL_FLAGS --classifier_scale 1.0 --source 260,261,282,283 --target 261,262,283,284 --val_dir path_to_val_set

The above command copies the validation images for the source classes to ./experiments/imagenet. Translated images are placed in the same folder, with the target class appended in the filename.

We can update source and target to translate between other ImageNet classes. The corresponding val images are copied automatically.

We translate the domain pairs in the specified order. For example, in the above command, we translate from class 260 to 261, 261 to 262, 282 to 283, and 283 to 284.

We can experiment with classifier_scale to control how strongly classifier guidance steers the denoising process towards the target class.
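
For instance, re-running the command with a different (illustrative) scale:

python scripts/imagenet_translation.py $MODEL_FLAGS --classifier_scale 2.0 --source 260 --target 261 --val_dir path_to_val_set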

We can prepend the Python command with mpiexec -n N to run it over multiple GPUs. For details, refer to guided-diffusion.
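
For example, to run on 4 GPUs (the GPU count is illustrative):

mpiexec -n 4 python scripts/imagenet_translation.py $MODEL_FLAGS --classifier_scale 1.0 --source 260,261,282,283 --target 261,262,283,284 --val_dir path_to_val_set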

References and Acknowledgements

@inproceedings{
      su2022dual,
      title={Dual Diffusion Implicit Bridges for Image-to-Image Translation},
      author={Su, Xuan and Song, Jiaming and Meng, Chenlin and Ermon, Stefano},
      booktitle={International Conference on Learning Representations},
      year={2023},
}

This implementation is based on and inspired by OpenAI's openai/guided-diffusion and openai/improved-diffusion.

To-do List

  • Release pretrained models on AFHQ and yosemite datasets
  • Add color translation experiments
  • Add scripts to translate between AFHQ, yosemite images

ddib's People

Contributors

chenlin9, suxuann


ddib's Issues

Some confusion about the reproduction results

I'm interested in this paper of yours. For the image translation task, I found that DDIB did not perform as well as CycleGAN when I compared the reproduced results. In addition, the model described in the paper preserved the original image content poorly. Could you explain the reason for this, or could the results be improved by adjusting the hyperparameters?

Looking forward to your reply, thanks!

About the DDIM reconstruction error.

Hi there,

I've been attempting to use your code to reconstruct the original image with DDIM sampling functions, but I'm having some difficulty. Specifically, the image below is the result of my attempt to reconstruct the FFHQ dataset using a pre-trained diffusion model trained on the CelebA-HQ dataset.

Unfortunately, the reconstruction error is much larger than I expected. I was wondering if there might be something I'm overlooking or not taking into account? Any insights you could provide would be greatly appreciated.

Thank you.

[Attached image: recon_ffhq_clip=TF]

import torch
import torchvision

def main():
    # ...

    for batch_idx, data in enumerate(dataloader):
        img_batch = data[0].to(device)
        # Encode the images into latents with the reverse (encoding) DDIM ODE.
        noise = diffusion.ddim_reverse_sample_loop(
            model,
            img_batch,
            clip_denoised=False,
            device=device,
            progress=True,
        )
        # Decode the latents back into images with the forward DDIM ODE.
        recon_sample = diffusion.ddim_sample_loop(
            model,
            (args.batch_size, 3, args.image_size, args.image_size),
            noise=noise,
            clip_denoised=True,
            device=device,
            eta=args.eta,
            progress=True,
        )
        # Save originals and reconstructions side by side, mapped from [-1, 1] to [0, 1].
        torchvision.utils.save_image(
            (torch.cat([img_batch, recon_sample], dim=0) + 1.) / 2.,
            "results/recon_ffhq_clip=FT.png",
            nrow=4,
        )

    # ...

Why is cycle consistency not robust?

Dear @suxuann ,

Thanks for sharing the awesome work.

I tried DDIB for modality transfer: CT image to MR image.

Based on guided-diffusion, I trained a CT model and an MR model on my own dataset, and I have verified that both can generate good samples.

Then, I tried cycle consistency and modality transfer CT -> MR. However, the cycle consistency is not robust, and the transferred MR images have very different structures.

Here are some examples:

[Example translation results: ct_0, ct_1, ct_2, ct_3]

What could be the possible reason? Any comments are highly appreciated.

Where is my model?

I'm a newbie in machine learning and I'm using the following script to train on my own dataset:
python scripts/synthetic_train.py --num_res_blocks 3 --diffusion_steps 4000 --noise_schedule linear --lr 1e-4 --batch_size 2000 --task 0
But the model has been training for three days with no sign of stopping, and the step count has reached 1e+06. Where can I find the model I trained?

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

I have the following error, can you help to take a look?

(/home/ec2-user/SageMaker/env/test) sh-4.2$ python synthetic_train.py --num_res_blocks 3 --diffusion_steps 4000 --noise_schedule linear --lr 1e-4 --batch_size 20000 --task 1

Logging to /tmp/openai-2022-11-08-05-25-53-353605
args: Namespace(task=1, schedule_sampler='uniform', lr=0.0001, weight_decay=0.0, lr_anneal_steps=1000, batch_size=20000, microbatch=-1, ema_rate='0.9999', log_interval=10, save_interval=10000, resume_checkpoint='', use_fp16=False, fp16_scale_growth=0.001, num_channels=256, num_res_blocks=3, dropout=0.2, use_checkpoint=False, in_channels=2, learn_sigma=False, diffusion_steps=4000, noise_schedule='linear', timestep_respacing='', use_kl=False, predict_xstart=False, rescale_timesteps=False, rescale_learned_sigmas=False)
[W ProcessGroupGloo.cpp:694] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
Logging to /tmp/openai-2022-11-08-05-25-53-357098
creating 2d model and diffusion...
creating 2d data loader...
training 2d model...
Traceback (most recent call last):
  File "/home/ec2-user/SageMaker/DDIB/ddib/synthetic_train.py", line 82, in <module>
    main()
  File "/home/ec2-user/SageMaker/DDIB/ddib/synthetic_train.py", line 40, in main
    TrainLoop(
  File "/home/ec2-user/SageMaker/DDIB/ddib/guided_diffusion/train_util.py", line 67, in __init__
    self._load_and_sync_parameters()
  File "/home/ec2-user/SageMaker/DDIB/ddib/guided_diffusion/train_util.py", line 122, in _load_and_sync_parameters
    dist_util.sync_params(self.model.parameters())
  File "/home/ec2-user/SageMaker/DDIB/ddib/guided_diffusion/dist_util.py", line 83, in sync_params
    dist.broadcast(p, 0)
  File "/home/ec2-user/SageMaker/env/test/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1408, in broadcast
    work.wait()
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

The noise image obtained by inverse DDIM is not like a Gaussian distribution!

I am trying to use DDIB for MR-to-CT image translation. I have trained diffusion models for unconditional generation of MR and CT images, respectively. Given an MR image, I use the MR diffusion model with inverse DDIM to obtain the corresponding noise image in latent space (steps=1000; the total number of steps of the trained diffusion model is also 1000). However, this noise image does not look like an isotropic Gaussian (as shown in Fig. 1 in the paper):

[Image 1]

Taking the obtained noise image as the input to the CT diffusion model and sampling with forward DDIM, the generated CT image is not ideal.

Eagerly awaiting the color translation experiments!

Hello, I really appreciate this work and thank you for making the experimental code public! I am eager to see the color translation experiments added to the repository. When will they be released? Please forgive me if I am being impatient.

Translation failed when reproducing with the ImageNet dataset

Hi,
I trained a model for 150k iterations with the parameters given in the codebase on the ImageNet dataset, along with a classifier model, but the translation failed. What could be the cause?

Diffusion model training:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
TRAIN_FLAGS="--batch_size 1 --lr 1e-4"
python3 scripts/image_train.py --data_dir /opt/data/private/ImageNet_ILSVRC2012/train $MODEL_FLAGS $TRAIN_FLAGS

Classifier training:

TRAIN_FLAGS="--iterations 300000 --anneal_lr True --batch_size 256 --lr 3e-4 --save_interval 10000 --weight_decay 0.05"
CLASSIFIER_FLAGS="--image_size 256 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True"
python3 scripts/classifier_train.py --data_dir /opt/data/private/ImageNet_ILSVRC2012/train $TRAIN_FLAGS $CLASSIFIER_FLAGS

Translation:
python3 scripts/imagenet_translation.py $MODEL_FLAGS --classifier_scale 1.0 --source 260 --target 284 --val_dir /opt/data/private/ImageNet_ILSVRC2012/val --model_path /opt/data/private/ddib/models/test/ILSVRCModel/resume/model150000.pt --classifier_path /opt/data/private/ddib/models/test/classifier/model299999.pt

Dog to cat (both runs use the trained classifier, which shows that the classifier itself works):

Results obtained using a pre-trained diffusion model: [images 260_1, 260_1_translated_284_true]

Results from the model I trained: [images 260_1, 260_1_translated_284]

Code Release

Hi, thank you so much for releasing your code for translation between synthetic datasets. Do you know when you will be able to release code for the tasks mentioned in your to-do list?

When do you plan to release the pretrained models on AFHQ?

Hi Xuan, just wondering, do you have an approximate timeline for releasing the pretrained models on AFHQ from your to-do list? Thanks very much!

  • Release pretrained models on AFHQ and yosemite datasets
  • Add color translation experiments
  • Add scripts to translate between AFHQ, yosemite images

Training on my own datasets

Hi there, I have paired image datasets A and B, where B is a synthetic dataset derived from A. I want to train a DDIB model to translate B to A. Which scripts do I need to run? I'm a little confused by the README. Could you explain this for me? Thanks a lot!

Instructions on creating my own dataset?

Are there any instructions on creating my own dataset and training from scratch (rather than loading a pretrained model)? Also, do you think image-to-image translation (style transfer) would work well on medical images? I'm trying to use histopathological (H&E-stained) images.
Also, like others have said, when will the color translation model be released? I'm really looking forward to applying this.

I would appreciate a response. Thank you!

How to train this model on my own dataset?

Thanks for your work! For my research I need to train this model on my own dataset. I found that synthetic_train.py is the only training script, and I tried to change the dataset file path, but there is nothing about a file path in that file. Please tell me how to solve this problem; I would appreciate it!

Reproduce Cycle Consistency poorly

I used the provided pre-trained models and script to reproduce the cycle consistency experiment, but the results are very poor; my L2 distance is more than 10. How can I get the results reported in the paper?

[Plots: scatter_source, scatter_target, scatter_source_2]
