
Stable Diffusion Implementation in PyTorch

This repository implements Stable Diffusion. As of today the repo provides code to do the following:

  • Training and Inference on Unconditional Latent Diffusion Models
  • Training a Class Conditional Latent Diffusion Model
  • Training a Text Conditioned Latent Diffusion Model
  • Training a Semantic Mask Conditioned Latent Diffusion Model
  • Any combination of the above three conditioning types

For the autoencoder, the repo provides code for both a VAE and a VQ-VAE, but both stages of training use the VQ-VAE only. One can easily switch to the VAE if needed.

For the diffusion part, the repo currently implements only DDPM with a linear noise schedule.
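
For reference, a linear schedule simply interpolates the noise variance between a small start value and a larger end value over the diffusion timesteps. Below is a minimal illustrative sketch; the default values and names are assumptions, so check the scheduler code and config in this repo for the actual ones.

    import torch

    # Illustrative linear DDPM noise schedule (names and defaults are assumptions,
    # not necessarily those used in this repo).
    def linear_beta_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        betas = torch.linspace(beta_start, beta_end, num_timesteps)   # linearly increasing variances
        alphas = 1.0 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)                  # used to noise x_0 directly to timestep t
        return betas, alphas, alpha_cumprod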

Stable Diffusion Tutorial Videos

  • Stable Diffusion Tutorial
  • Stable Diffusion Conditioning Tutorial

Sample Output for Autoencoder on CelebHQ

Image - Top, Reconstructions - Below

Sample Output for Unconditional LDM on CelebHQ (not fully converged)

Sample Output for Conditional LDM

Sample Output for Class Conditioned on MNIST


Sample Output for Text (using CLIP) and Mask Conditioned on CelebHQ (not converged)


Text - She is a woman with blond hair
Text - She is a woman with black hair

Setup


Data Preparation

MNIST

For setting up the MNIST dataset, follow - https://github.com/explainingai-code/Pytorch-VAE#data-preparation

Ensure the directory structure is the following

StableDiffusion-PyTorch
    -> data
        -> mnist
            -> train
                -> images
                    -> *.png
            -> test
                -> images
                    -> *.png

CelebHQ

Unconditional

For setting up CelebHQ for unconditional training, simply download the images from the official CelebAMask-HQ repo here.

Ensure the directory structure is the following

StableDiffusion-PyTorch
    -> data
        -> CelebAMask-HQ
            -> CelebA-HQ-img
                -> *.jpg

Mask Conditional

For mask conditional LDM on CelebHQ, additionally do the following:

Ensure the directory structure is the following

StableDiffusion-PyTorch
    -> data
        -> CelebAMask-HQ
            -> CelebA-HQ-img
                -> *.jpg
            -> CelebAMask-HQ-mask-anno
                -> 0/1/2/3.../14
                    -> *.png
            
  • Run python -m utils.create_celeb_mask from repo root to create the mask images from mask annotations

Ensure the directory structure is the following

StableDiffusion-PyTorch
    -> data
        -> CelebAMask-HQ
            -> CelebA-HQ-img
                -> *.jpg
            -> CelebAMask-HQ-mask-anno
                -> 0/1/2/3.../14
                    -> *.png
            -> CelebAMask-HQ-mask
                  -> *.png

Text Conditional

For text conditional LDM on CelebHQ, additionally do the following:

Ensure the directory structure is the following

StableDiffusion-PyTorch
    -> data
        -> CelebAMask-HQ
            -> CelebA-HQ-img
                -> *.jpg
            -> CelebAMask-HQ-mask-anno
                -> 0/1/2/3.../14
                    -> *.png
            -> CelebAMask-HQ-mask
                -> *.png
            -> celeba-caption
                -> *.txt

Configuration

The config files allow you to play with different components of the DDPM and autoencoder training.

  • config/mnist.yaml - Small autoencoder and LDM which can even be trained on CPU
  • config/celebhq.yaml - Configuration used for the CelebHQ dataset

Relevant configuration parameters

Most parameters are self-explanatory, but below are a couple that are specific to this repo.

  • autoencoder_acc_steps : For accumulating gradients when the image size is too large to train with larger batch sizes (see the sketch below)
  • save_latents : Enable this to save the latents during autoencoder inference, so that DDPM training is faster
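
To illustrate what autoencoder_acc_steps does, here is a hypothetical gradient accumulation loop; it is a sketch of the general technique under assumed names (model, optimizer, data_loader), not the repo's actual training code.

    # Hypothetical sketch of gradient accumulation controlled by a value such as
    # autoencoder_acc_steps; not the repo's actual training loop.
    def train_one_epoch(model, data_loader, optimizer, acc_steps):
        optimizer.zero_grad()
        for step, batch in enumerate(data_loader):
            loss = model(batch)                 # assume the model returns its training loss
            (loss / acc_steps).backward()       # scale so accumulated grads match one large batch
            if (step + 1) % acc_steps == 0:
                optimizer.step()                # update only every acc_steps mini-batches
                optimizer.zero_grad()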

Training

The repo provides training and inference for MNIST (Unconditional and Class Conditional) and CelebHQ (Unconditional, Text and/or Mask Conditional).

For working on your own dataset:

  • Create your own config and have the path in config point to images (look at celebhq.yaml for guidance)
  • Create your own dataset class which will just collect all the filenames and return the image in its getitem method (a minimal sketch is shown below). Look at mnist_dataset.py or celeb_dataset.py for guidance
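
As a starting point, such a dataset class can be as small as the hypothetical sketch below; the class name, file pattern and transform are illustrative, and the repo's own mnist_dataset.py / celeb_dataset.py remain the reference.

    import glob
    import os
    from PIL import Image
    import torchvision.transforms as T
    from torch.utils.data import Dataset

    # Hypothetical minimal dataset: collect image paths once, return one tensor per item.
    class MyImageDataset(Dataset):
        def __init__(self, im_path, im_size=256):
            self.images = sorted(glob.glob(os.path.join(im_path, '*.png')))
            self.transform = T.Compose([
                T.Resize(im_size),
                T.CenterCrop(im_size),
                T.ToTensor(),                   # values in [0, 1]
                T.Lambda(lambda x: 2 * x - 1),  # scale to [-1, 1]
            ])

        def __len__(self):
            return len(self.images)

        def __getitem__(self, index):
            im = Image.open(self.images[index]).convert('RGB')
            return self.transform(im)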

Once the config and dataset are set up:

Training AutoEncoder for LDM

  • For training the autoencoder on MNIST, ensure the right path is mentioned in mnist.yaml
  • For training the autoencoder on CelebHQ, ensure the right path is mentioned in celebhq.yaml
  • For training the autoencoder on your own dataset
    • Create your own config and have the path point to images (look at celebhq.yaml for guidance)
    • Create your own dataset class, similar to celeb_dataset.py without the conditioning parts
  • Map the dataset name to the right class in the training code here
  • For training the autoencoder run python -m tools.train_vqvae --config config/mnist.yaml for training the vqvae with the desired config file
  • For inference using the trained autoencoder run python -m tools.infer_vqvae --config config/mnist.yaml for generating reconstructions with the right config file. Use save_latents in config to save the latent files

Training Unconditional LDM

Train the autoencoder first and set up the dataset accordingly.

For training the unconditional LDM, map the dataset to the right class in train_ddpm_vqvae.py

  • python -m tools.train_ddpm_vqvae --config config/mnist.yaml for training the unconditional ddpm using the right config
  • python -m tools.sample_ddpm_vqvae --config config/mnist.yaml for generating images using the trained ddpm

Training Conditional LDM

For training conditional models we need two changes:

  • Dataset classes must provide the additional conditional inputs (see below)
  • Config must be changed with additional conditioning config added

Specifically the dataset getitem will return the following:

  • image_tensor for unconditional training
  • tuple of (image_tensor, cond_input) for conditional training, where cond_input is a dictionary consisting of the keys {class/text/image} (see the sketch below)
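
Schematically, a conditional getitem could look like the hypothetical sketch below; the attributes (condition_types, labels, captions, get_mask) are illustrative and not the repo's exact implementation.

    # Hypothetical __getitem__ for conditional training: returns (image_tensor, cond_input),
    # where cond_input only contains the keys configured in condition_types.
    def __getitem__(self, index):
        im = Image.open(self.images[index]).convert('RGB')
        image_tensor = self.transform(im)

        cond_input = {}
        if 'class' in self.condition_types:
            cond_input['class'] = self.labels[index]    # integer class id
        if 'text' in self.condition_types:
            cond_input['text'] = self.captions[index]   # raw caption string
        if 'image' in self.condition_types:
            cond_input['image'] = self.get_mask(index)  # NUM_CLASSES x MASK_H x MASK_W tensor

        if len(cond_input) == 0:
            return image_tensor                         # unconditional case
        return image_tensor, cond_input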

Training Class Conditional LDM

The repo provides class conditional latent diffusion model training code for the MNIST dataset, which one can follow for their own dataset

  • Use the mnist_class_cond.yaml config file as a guide to create your class conditional config file. Specifically, the following new keys need to be modified according to your dataset within ldm_params.
  • condition_config:
      condition_types: ['class']
      class_condition_config:
        num_classes: <number of classes: 10 for mnist>
        cond_drop_prob: <probability of dropping class labels>
    
  • Create a dataset class similar to the MNIST one, where the getitem method now returns a tuple of image_tensor and a dictionary of conditional_inputs.
  • For class, the conditional input will ONLY be the integer class
  •   (image_tensor, {
                      'class' : {0/1/.../num_classes}
                      })
    
    

For training the class conditional LDM, map the dataset to the right class in train_ddpm_cond and run the below commands using the desired config

  • python -m tools.train_ddpm_cond --config config/mnist_class_cond.yaml for training class conditional on mnist
  • python -m tools.sample_ddpm_class_cond --config config/mnist_class_cond.yaml for generating images using the class conditional trained ddpm

Training Text Conditional LDM

The repo provides text conditional latent diffusion model training code for the CelebHQ dataset, which one can follow for their own dataset

  • Use the celebhq_text_cond.yaml config file as a guide to create your config file. Specifically, the following new keys need to be modified according to your dataset within ldm_params.
  •   condition_config:
          condition_types: [ 'text' ]
          text_condition_config:
              text_embed_model: 'clip' or 'bert'
              text_embed_dim: 512 or 768
              cond_drop_prob: 0.1
    
  • Create a dataset class similar to the CelebHQ one, where the getitem method now returns a tuple of image_tensor and a dictionary of conditional_inputs.
  • For text, the conditional input will ONLY be the caption
  •   (image_tensor, {
                      'text' : 'a sample caption for image_tensor'
                      })
    
    

For training the text conditional LDM, map the dataset to the right class in train_ddpm_cond and run the below commands using the desired config

  • python -m tools.train_ddpm_cond --config config/celebhq_text_cond.yaml for training text conditioned ldm on celebhq
  • python -m tools.sample_ddpm_text_cond --config config/celebhq_text_cond.yaml for generating images using text conditional trained ddpm

Training Text and Mask Conditional LDM

The repo provides text and mask conditional latent diffusion model training code for the CelebHQ dataset, which one can follow for their own dataset, and can even use to train a mask-only conditional LDM

  • Use the celebhq_text_image_cond.yaml config file as a guide to create your config file. Specifically, the following new keys need to be modified according to your dataset within ldm_params.
  •   condition_config:
          condition_types: [ 'text', 'image' ]
          text_condition_config:
              text_embed_model: 'clip' or 'bert'
              text_embed_dim: 512 or 768
              cond_drop_prob: 0.1
          image_condition_config:
              image_condition_input_channels: 18
              image_condition_output_channels: 3
              image_condition_h: 512
              image_condition_w: 512
              cond_drop_prob: 0.1
    
  • Create a dataset class similar to the CelebHQ one, where the getitem method now returns a tuple of image_tensor and a dictionary of conditional_inputs.
  • For text and mask, the conditional inputs will be the caption and the mask image (see the helper sketch after this example for one way to build the mask tensor)
  •   (image_tensor, {
                      'text' : 'a sample caption for image_tensor',
                      'image' : NUM_CLASSES x MASK_H x MASK_W
                      })
    
    
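One possible way to build that NUM_CLASSES x MASK_H x MASK_W tensor from an integer-valued mask image is sketched below; the helper name is hypothetical, and the class count and size are assumptions chosen to match the sample config above (18 channels, 512x512), not the repo's exact code.

    import numpy as np
    import torch
    from PIL import Image

    # Hypothetical helper: turn an integer label mask (H x W, values 0..num_classes-1)
    # into a float tensor of shape (num_classes, H, W) for the 'image' conditioning input.
    def mask_to_onehot(mask_path, num_classes=18, size=(512, 512)):
        mask = Image.open(mask_path).resize(size, Image.NEAREST)
        mask = torch.from_numpy(np.array(mask)).long()             # H x W integer labels
        onehot = torch.nn.functional.one_hot(mask, num_classes)    # H x W x num_classes
        return onehot.permute(2, 0, 1).float()                     # num_classes x H x W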

For training the text and mask conditional LDM, map the dataset to the right class in train_ddpm_cond and run the below commands using the desired config

  • python -m tools.train_ddpm_cond --config config/celebhq_text_image_cond.yaml for training text and mask conditioned ldm on celebhq
  • python -m tools.sample_ddpm_text_image_cond --config config/celebhq_text_image_cond.yaml for generating images using text and mask conditional trained ddpm

Output

Outputs will be saved according to the configuration present in the yaml files.

For every run, a folder named after the task_name key in the config will be created

During training of the autoencoder the following outputs will be saved:

  • Latest Autoencoder and discriminator checkpoint in task_name directory
  • Sample reconstructions in task_name/vqvae_autoencoder_samples

During inference of the autoencoder the following outputs will be saved:

  • Reconstructions for random images in task_name
  • Latents will be saved in task_name/vqvae_latent_dir_name if mentioned in config

During training and inference of the DDPM the following outputs will be saved:

  • During training of unconditional or conditional DDPM, the latest checkpoint is saved in the task_name directory
  • During sampling, the unconditionally sampled image grid for all timesteps is saved in task_name/samples/*.png
  • During sampling, the class conditionally sampled image grid for all timesteps is saved in task_name/cond_class_samples/*.png
  • During sampling, the text conditionally sampled image grid for all timesteps is saved in task_name/cond_text_samples/*.png
  • During sampling, the text and mask conditionally sampled image grid for all timesteps is saved in task_name/cond_text_image_samples/*.png
  • In all cases the final decoded generated image will be x0_0.png. Images from x0_999.png to x0_1.png are the latent image predictions of the denoising process from T=999 to T=1; the generated image is at T=0


stablediffusion-pytorch's Issues

Why are the MNIST images generated by your model noise?

Dear explainingai-code:
Your codebase for ddpm is so detailed and helpful that I would like to thank you very much for your great work!
I have downloaded this codebase and followed your instructions carefully; luckily I got a good VQVAE result as follows:
[VQVAE reconstruction image]

So I continued to train unconditional and class conditional ddpm models.
I ran python tools/train_ddpm_vqvae.py to train the unconditional ddpm model and it seems to have converged as follows:

[training loss log]
The loss decreases from 0.2833 to 0.0886. Then I ran python tools/sample_ddpm_vqvae.py to check the model output, and I get mnist/samples/x0_0.png to x0_999.png, which all seem like PURE NOISE as follows:
[sample grid image]

Similarly I trained and tested the mnist class conditional ddpm and all results seem like PURE NOISE as follows:

[sample grid image]

My code structure is as follows:

[directory structure screenshot]
I am sure that I just followed your instructions and did NOT modify any model related code, but the results seem weird. I would appreciate it if you could give some help.

Sincerely,
CatLoves

Losses for conditional diffusion models

Hey, thanks for the videos and code, I am experimenting with conditional LDMs.

Do you happen to have loss plots or logs of the loss? I have a feeling that the loss is decreasing really slowly or not decreasing at all.
Could you let me know if you had a similar loss decrease? Here is a screenshot for your reference.

[screenshot of training loss values]

Unexpected output after sampling using Conditional LDM

I am getting a bunch of noodle-like waves after sampling for conditional LDM instead of proper digits. The unconditional LDM works fine. I am using the MNIST dataset that the Torchvision library has (torchvision.datasets.MNIST).

Can you tell what could be wrong in this scenario?

I have attached x0_0.png outputs for both Unconditional_LDM and Conditional_LDM.
[attached images: Unconditional_LDM_x0_0, Conditional_LDM_x0_0]

Running out of memory... best number of samples for custom data sets?

Hello!

I was wondering if you have any intuition on how many training samples are required to get good results/how much memory is required to train the unconditional VQVAE?

I have about 200k grayscale images at 256x256... which was obviously too much, so I scaled back to 70 images just to see if it would start training, but it didn't... it threw an out-of-memory error.

Is this something batch size can fix or do I need to mess with a bunch of other parameters? I only changed the im_channels and save_latent parameters from their defaults.

Thank you!

The training loss is not decreasing

Hi,
Thank you for your code.

I am training an LDM model with the config file I have attached.
I have trained with multiple datasets and settings, and the training loss always stops converging after a certain number of epochs, usually when the loss is somewhere around 0.1. The loss does go down consistently, but very slowly.
As I am using MSE loss, 0.1 is large for image generation.

Once I continued training until 400 epochs; at that point the model was overfitted but the loss reached a minimum of around 0.02.
Maybe you could share your insights, or has anyone else faced this issue?
tuned_class_cond_bdd_1.zip

Bug when saving Latent information?

Hello,
When you run infer_vqvae.py you save the latent information (encoded information) but you do not clamp it (torch.clamp(encoded_output, -1., 1.)).

I also checked when you read it from the dataset, and when the use_latents variable is equal to True you don't clamp it either.

Maybe it's a bug?

Thank you!

How to condition based on multiple features?

I would like to condition the model using multiple features. In my case, I have a lot of columns, say A, B, C and D, and some of the columns are categorical while some are numerical. Now I want to implement stable diffusion by conditioning on all the columns together. Please advise what modifications I need to make.

Thanks.

It would be greatly appreciated if model ckpts could be provided

Hello @explainingai-code!
I have read your code carefully and can train the unconditional and text-conditioned CelebHQ models now, but I only have a single V100 GPU and it's expected to take ~110 hours of training to get one result, which is very time consuming. Renting a GPU cluster is also very expensive and time-consuming. If you have already trained the models, could you provide download links for the model ckpts for convenience? I think this would be very helpful for people to quickly use your codebase.
Moreover, thank you again for your great codebase; it would be greatly appreciated if model ckpts could be provided.

Sincerely,
CatLoves

Unable to run

Hello there, I am trying to run the conditional text part and have followed all the instructions, but at the end I am facing the following error (screenshot attached below). It says "Model checkpoint celebhq/ddpm_ckpt_text_cond_clip.pth not found"
[screenshot of error]

Question about sampling data

Hello,
First of all thanks again for sharing the code and good work on that!

I have one doubt about sampling data (after the training phase).

When "t" is not equal to zero you are using x0 into "im" variable (instead using mean + sigma * z) but when t is equal to zero you are using the mean variable. Why?

I guess you should use always mean + sigma * z every time when t is not equal to zero right?
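
For context, a generic DDPM reverse step (a sketch of the usual formulation, not necessarily this repo's exact code) only adds noise while t > 0; at t = 0 the mean is returned directly, since no further denoising steps follow.

    import torch

    # Generic DDPM reverse step sketch: noise is only added for t > 0.
    def reverse_step(mean, sigma, t):
        if t > 0:
            z = torch.randn_like(mean)
            return mean + sigma * z
        return mean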
