boschresearch / aldm

Official implementation of "Adversarial Supervision Makes Layout-to-Image Diffusion Models Thrive" (ICLR 2024)

Home Page: https://yumengli007.github.io/ALDM/

License: GNU Affero General Public License v3.0

Shell 0.30% Python 22.14% Jupyter Notebook 77.56%
aigc controlnet diffusion-models iclr2024 semantic-image-synthesis stable-diffusion generative-ai generative-model layout-to-image bcai

aldm's Introduction

Adversarial Supervision Makes Layout-to-Image Diffusion Models Thrive (ALDM)

🔥 Official implementation of "Adversarial Supervision Makes Layout-to-Image Diffusion Models Thrive" (ICLR 2024)


(Figures: method overview and qualitative results)


Getting Started

Our environment is built on top of ControlNet:

conda env create -f environment.yaml  
conda activate aldm
pip install mit-semseg # for segmentation network UperNet

Pretrained Models

Pretrained models ade20k_step9.ckpt and cityscapes_step9.ckpt can be downloaded from here. They should be stored in the checkpoint folder.
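
As a sketch of loading one of these checkpoints for inference, assuming the ControlNet-style helpers create_model and load_state_dict from cldm/model.py (which this codebase builds upon; verify the paths against your checkout):

    from cldm.model import create_model, load_state_dict

    # Build the model from its YAML config, then load the ALDM weights.
    model = create_model('models/cldm_seg_cityscapes_multi_step_D.yaml').cpu()
    model.load_state_dict(load_state_dict('checkpoint/cityscapes_step9.ckpt', location='cpu'))
    model = model.cuda().eval()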

Dataset Preparation

Datasets should be structured as follows to enable ALDM training. Dataset paths should be adjusted accordingly in dataloader/cityscapes.py, dataloader/ade20k.py, and dataloader/coco_stuff.py. Check convert_coco_stuff_id.ipynb for converting COCO-Stuff labels; a rough sketch of that remapping pattern follows the directory tree below.

datasets
├── cityscapes
│   ├── gtFine
│   │   ├── train
│   │   └── val
│   └── leftImg8bit
│       ├── train
│       └── val
├── ADE20K
│   ├── annotations
│   │   ├── train
│   │   └── val
│   └── images
│       ├── train
│       └── val
├── COCOStuff
│   ├── train_img
│   ├── val_img
│   ├── train_label
│   ├── val_label
│   ├── train_label_convert # New: after converting
│   └── val_label_convert # New: after converting
└── ...
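
The actual ID lookup table lives in convert_coco_stuff_id.ipynb; the snippet below is only an illustrative sketch of the remap-and-save pattern that produces the *_label_convert folders (the id_map shown is hypothetical):

    import glob, os
    import numpy as np
    from PIL import Image

    # Hypothetical ID mapping; the real table is defined in convert_coco_stuff_id.ipynb.
    id_map = {0: 255, 1: 0, 2: 1}
    lut = np.full(256, 255, dtype=np.uint8)  # unmapped IDs fall back to the ignore label
    for src, dst in id_map.items():
        lut[src] = dst

    out_dir = 'datasets/COCOStuff/train_label_convert'
    os.makedirs(out_dir, exist_ok=True)
    for path in glob.glob('datasets/COCOStuff/train_label/*.png'):
        label = np.array(Image.open(path))       # (H, W) uint8 label map
        converted = Image.fromarray(lut[label])  # remap every pixel via the LUT
        converted.save(os.path.join(out_dir, os.path.basename(path)))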

Inference

We provide three ways of testing: (1) Jupyter notebook, (2) Gradio demo, (3) bash scripts.

  1. Jupyter notebook: we provide one sample layout for a quick test, without requiring any dataset setup.

  2. Gradio demo: run the following command after the dataset preparation.

    gradio gradio_demo/gradio_seg2image_cityscapes.py

    (Demo animation)

  3. Bash scripts: we provide bash scripts to enable large-scale generation over the whole dataset. The synthesized data can be further used for training downstream models, e.g., semantic segmentation networks.

Training

Example training bash scripts for Cityscapes and ADE20K training can be found here: bash_script/train_cityscapes.sh, bash_script/train_ade20k.sh.

The main entry script is train_cldm_seg_pixel_multi_step.py, and the YAML configuration files can be found under the models folder, e.g., models/cldm_seg_cityscapes_multi_step_D.yaml.
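
For example, training can be launched via the provided scripts:

    bash bash_script/train_cityscapes.sh
    bash bash_script/train_ade20k.sh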

How to train on a new dataset?

To train on a new customized dataset, one may need to change the following places:

  1. Define a new dataset class and add it to dataloader/__init__.py, cf. dataloader/cityscapes.py, where the semantic classes need to be defined accordingly. The class language embedding, e.g., class_embeddings_cityscapes.pth, can be generated using the CLIP text encoder with a pre-defined prompt template, e.g., "A photo of {class_name}", which produces embeddings of shape (N, 768), where N is the number of semantic classes (see the sketch after the note below).

Note that the class language embedding is not mandatory for training and does not impact the final performance; however, we observe that it can accelerate training convergence compared to simple RGB color coding.
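
A minimal sketch of generating such an embedding file, assuming OpenAI's CLIP package and the ViT-L/14 text encoder (whose text features are 768-dimensional); the class names below are placeholders:

    import clip
    import torch

    # Hypothetical class names; replace with your dataset's N semantic classes.
    class_names = ["road", "sidewalk", "building"]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _ = clip.load("ViT-L/14", device=device)  # text features have dim 768

    tokens = clip.tokenize([f"A photo of {name}" for name in class_names]).to(device)
    with torch.no_grad():
        embeddings = model.encode_text(tokens).float()  # shape (N, 768)

    torch.save(embeddings.cpu(), "class_embeddings_custom.pth")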

  2. The captions of images, e.g., dataloader/ade20k_caption_train.json, can be obtained with vision-language models such as BLIP or LLaVA (a captioning sketch appears at the end of this section).
  3. Adjust the segmenter-based discriminator, cf. cldm_seg/seg/ade_upernet101_20cls.yaml. Similar to the initialization in ControlNet here, one needs to manually match the semantic classes between the customized dataset and the pretrained segmenter. For new classes on which the pretrained segmenter was not trained, one can simply initialize the corresponding weights randomly. Check out the example code snippet below, where an ADE20K-pretrained UperNet is adapted for Cityscapes.

Note that we essentially update the generator and the discriminator jointly during training; using a pretrained segmenter as initialization helps make the adversarial training more stable. That is why the segmenter does not have to be trained on the same dataset.

    ### Adapt an ADE20K-pretrained UperNet (150 classes) to Cityscapes (20 classes)
    import torch

    # ADESegDiscriminator is defined in this repo (cf. cldm_seg).
    model = ADESegDiscriminator(segmenter_type='upernet101_20cls')

    # For each of the 20 Cityscapes classes, the index of the matching ADE20K class.
    select_index = torch.tensor([6, 11, 1, 0, 32, 93, 136, 43, 72, 9, 2, 12, 150, 20, 83, 80, 38, 116, 128, 150]).long()

    # ADE20K-pretrained segmenter whose weights we transfer.
    old_model = ADESegDiscriminator(segmenter_type='upernet101')
    old_model.load_pretrained_segmenter()

    target_dict = {}
    for k, v in old_model.state_dict().items():
        if 'conv_last.1.' in k:
            # Final classification layer: keep only the rows matching Cityscapes classes.
            new_v = torch.index_select(v, dim=0, index=select_index)
            # Overwrite class 12, for which no pretrained row is reused, with random init.
            new_v[12] = torch.randn_like(new_v[12])
            target_dict[k] = new_v
        else:
            target_dict[k] = v

    model.load_state_dict(target_dict, strict=True)
    output_path = './pretrained/ade20k_semseg/upernet101/decoder_epoch_50_20cls.pth'
    torch.save(model.state_dict(), output_path)

If an error occurs due to the segmenter, e.g., "got an unexpected keyword argument 'is_inference'", check the corresponding issue here.
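
Regarding step 2 above, here is a minimal captioning sketch using BLIP through Hugging Face transformers (the model checkpoint, image path, and JSON layout are assumptions; mirror the structure of dataloader/ade20k_caption_train.json for the output):

    import glob, json, os
    from PIL import Image
    from transformers import BlipProcessor, BlipForConditionalGeneration

    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

    captions = {}
    for path in glob.glob("datasets/MyDataset/images/train/*.jpg"):  # hypothetical path
        image = Image.open(path).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        out = model.generate(**inputs, max_new_tokens=30)
        captions[os.path.basename(path)] = processor.decode(out[0], skip_special_tokens=True)

    # The exact key/value layout should follow dataloader/ade20k_caption_train.json.
    with open("mydataset_caption_train.json", "w") as f:
        json.dump(captions, f, indent=2)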

The above might not be a complete list of items that need to be adjusted. Please don't hesitate to open an issue in case of doubt; I will update the instructions accordingly to make them clearer.

Citation

If you find our work useful, please star this repo and cite:

@inproceedings{li2024aldm,
  title={Adversarial Supervision Makes Layout-to-Image Diffusion Models Thrive},
  author={Li, Yumeng and Keuper, Margret and Zhang, Dan and Khoreva, Anna},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024}
}

License

This project is open-sourced under the AGPL-3.0 license. See the LICENSE file for details.

For a list of other open source components included in this project, see the file 3rd-party-licenses.txt.

Purpose of the project

This software is a research prototype, solely developed for and published as part of the publication cited above.

Contact

Please feel free to open an issue or contact me personally if you have questions, need help, or need explanations. Don't hesitate to write to the following email address: [email protected]


aldm's Issues

Training on new data guidance

Hi! Great work!
I want to train the model on my dataset.
Can you provide instructions on how this could be done?
What should we change to change the number of classes?

Thank you!

Out-of-distribution generation

Is it possible to generate out-of-distribution samples using custom colormaps without the need to finetune (similar to the FreestyleNet implementation here: https://github.com/essunny310/FreestyleNet)?

According to my understanding of the codebase, the inference seems to be limited to only the colormaps defined in the dataloader modules for the respective datasets (coco-stuff/ADE20K/Cityscapes).

If yes, could you provide an example, or outline the steps needed to implement this feature?

Thanks in advance!

Releasing the data

Thanks for publishing the code.

I was wondering if you could also release the generated data that was used to train the semantic segmentation models?

Manual definition of classes for new dataset

For the definition of classes for a new dataset, do I need a dataset_info.json file as for Cityscapes?
Should the new dataloader have the structure of cityscapes.py or ade20k.py? Note that my new dataset has some classes from ADE20K and some new classes.

Can you also explain why, in the function convert_labels in the ade20k.py file, we subtract 1 from the label: label = label - 1?
Why don't we keep the IDs as they are?

IndexError: index out of range in self

/ALDM/cldm_seg/util.py", line 41, in forward
seg_emb = torch.index_select(self.class_embeddings, 0, seg_map_.cpu())
IndexError: index out of range in self

I am using the structure of cityscapes.py for my new dataset, where I am using self.lb_map = {el['id']: el['trainId'] for el in labels_info}
I created a JSON file "newdataset_info.json" where I packed:

{
    "name": "terrain",
    "ignoreInEval": false,
    "id": 30,
    "color": [112, 9, 255],
    "trainId": 30
}
for each class. Note that I included all my classes and I am not ignoring any of them (no trainId = 255).
When I run the code, I receive the error above.
However, when I added 3 fictive classes to the JSON file, named them randomly but set their trainId to 255, it worked. I don't understand what the problem is here, and I don't know whether this workaround is actually correct and does not affect the training.

got an unexpected keyword argument 'is_inference'

(Screenshot: error traceback)

When I remove is_inference from:

class NewSegmentationModule(SegmentationModule):
    def forward(self, image, label, is_inference=None):
        segSize = (label.shape[-2], label.shape[-1])
        pred = self.decoder(self.encoder(image, return_feature_maps=True), segSize=segSize, is_inference=is_inference)
        return pred

The error is gone but I get another error:

(Screenshot: second error traceback)

Inquiry about GPU Training Time

Thank you for your incredible work! I would like to ask about the GPU training time required for the model. Additionally, could you please provide guidance on how to find the evaluation functions for mIoU and FID?

Evaluation Metrics.

Hi author, could you please provide the code for the evaluation calculation?

Need for captions for new dataset?

I want to train the model on another urban dataset. One folder should contain the segmentation masks and the other the images, right?
Also, do I need the caption.json for my new dataset?
Thank you :)

Training code / weight for COCO

Hi,
could you please provide the code for training COCO (e.g., dataloader, segmenter config) or weight file of COCO?

Thank you.

Cannot Train with Released Code

Thanks for your contributions and open-source code.

When I try to reproduce training with the command:

bash bash_script/train_ade20k.sh

these errors occurred:

FileNotFoundError: [Errno 2] No such file or directory: 
'/fs/scratch/rng_cr_bcai_dl/lyu7rng/0_project_large_models/code_repo/0_ControlNet/checkpoint/control_seg_enc_scratch.ckpt'
FileNotFoundError: [Errno 2] No such file or directory: 
'/fs/scratch/rng_cr_bcai_dl/lyu7rng/0_project_large_models/pretrained/ade20k_semseg/upernet101/encoder_epoch_50.pth'

Could you please provide these checkpoints and instructions on how to train ALDM?

How to warm up the segmentation model UperNet?

Great work! I would like to train your model on my own data, but need a warm-up model for the segmentation network (UperNet). Can you provide the code for training the segmentation model during the warm-up stage? Thanks
