

Transformer Latent Diffusion

Text to Image Latent Diffusion using a Transformer core in PyTorch.

Try it with your own inputs: Open In Colab

Below are some random examples (at 256 resolution) from a 100MM model trained from scratch for 260k iterations (about 32 hours on 1 A100):


Note that the model has not converged yet and could use more training.

The main goal of this project is to build an accessible diffusion model in PyTorch that is:

  • fast (close to real time generation)
  • small (~100MM params)
  • reasonably good (of course not SOTA)
  • trainable in a reasonable amount of time on a single GPU (under 50 hours on an A100 or equivalent)
  • simple and self-contained (model + train loop is ~400 lines of PyTorch with few dependencies)
  • trained on ~1 million images, with a focus on data quality over quantity

This is part II of a previous project in which I trained a pixel-level diffusion model in Keras. Even though this model outputs images with 4x the resolution (256px vs 64px), it is actually faster both to train and to sample from, which shows the power of training in latent space and the speed of transformer architectures.


Codebase:

The code is written in pure PyTorch with as few dependencies as possible.

  • transformer_blocks.py - basic transformer building blocks relevant to the transformer denoiser
  • denoiser.py - the architecture of the denoiser transformer
  • train.py - the train loop. It uses accelerate, so training can scale to multiple GPUs if needed.
  • diffusion.py - class to generate images from noise using reverse diffusion. Short (~60 lines) and self-contained.
  • data.py - data utils to download images/text and compute the features the diffusion model needs.

Usage:

If you have your own dataset of URLs + captions, training a model on it means first running data.py and then running train.py with the correct configs. (TODO: still working on adding parameterization here, since a lot of values are hardcoded. I'll add a notebook with a full run here.)

Dependencies:

  • PyTorch, numpy, einops for model building
  • wandb, tqdm for logging + progress bars
  • accelerate for the train loop and multi-GPU support
  • img2dataset, webdataset, torchvision for data downloading and image processing
  • diffusers, clip for the pretrained VAE and CLIP text model

Codebases used for inspiration:

Speed:

I try to speed up training and inference as much as possible by:

  • using mixed precision + SDPA (scaled dot-product attention) for training
  • precomputing all latent and text embeddings
  • using float16 precision for inference
  • using native SDPA for attention + torch.compile() (compile doesn't always work)
  • using a deterministic denoising process (DDIM) so fewer steps are needed
  • TODO: would distillation or something like LCM work here?

The time to generate 16 images on a

  • T4:
  • A100:

Examples:

More examples generated with the 100MM model (click a photo to see the prompt and other params like cfg and seed):

Data Processing:

In data.py, I have some helper functions to process images and captions. The flow is as follows:

  • Use img2dataset to download images from a dataframe containing URLs and captions.
  • Use CLIP to encode the prompts and the VAE to encode images to latents on a webdataset data generator.
  • Save the latents and text embeddings for future training.

There are two advantages to this approach. One is that the VAE encoding is somewhat expensive, so doing it every epoch would affect training times. The other is that we can discard the images after processing. For 3*256*256 images, the latent dimension is 4*32*32, so every latent is around 4KB (when quantized in uint8; see here). This means that 1 million latents will be "only" 4GB in size, which is easy to handle even in RAM. Storing the raw images would have been 48x larger in size.
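The size arithmetic above can be checked in a few lines of Python. This sketch assumes 1 byte per value for both the uint8-quantized latents and the raw RGB pixels, i.e. raw images stored uncompressed (compressed JPEGs would narrow the gap).

```python
import math

IMAGE_SHAPE = (3, 256, 256)   # RGB image: channels x height x width
LATENT_SHAPE = (4, 32, 32)    # VAE latent: 4 channels, 8x smaller spatially
NUM_SAMPLES = 1_000_000

# Assumption: 1 byte per value for both (uint8 latents, uncompressed uint8 pixels).
latent_bytes = math.prod(LATENT_SHAPE)
image_bytes = math.prod(IMAGE_SHAPE)

total_latent_gb = latent_bytes * NUM_SAMPLES / 1e9
ratio = image_bytes / latent_bytes

print(latent_bytes, image_bytes)    # 4096 196608
print(f"{total_latent_gb:.2f} GB")  # 4.10 GB
print(f"{ratio:.0f}x")              # 48x
```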

Architecture:

See here for the denoiser class.

The denoiser is a Transformer-based model built on the architecture in DiT and PixArt-Alpha, albeit with quite a few modifications and simplifications. Using a Transformer as the denoiser differs from most diffusion models, which use a CNN-based U-Net as the denoising backbone. I decided to use a Transformer for a few reasons. First, I wanted to experiment and learn how to build and train Transformers from the ground up. Second, Transformers are fast both to train and to run inference on, and they are likely to benefit most from future performance advances in both hardware and software.

Transformers are not natively built for spatial data, and at first I found many of the outputs to be very "patchy". To remedy that, I added a depth-wise convolution in the FFN layer of the transformer (introduced in the LocalViT paper). This allows the model to mix pixels that are close to each other with very little added compute cost.
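A minimal NumPy sketch of the idea: reshape the patch tokens back into their 2D grid, apply a depthwise 3x3 convolution between the FFN's two linear layers, then flatten again. Biases, normalization, and the exact position of the activation are omitted or assumed here; the names `conv_ffn` and `depthwise_conv3x3` are hypothetical, not the repo's classes. The usage example checks the locality property: a single non-zero token only affects its 3x3 neighborhood.

```python
import numpy as np


def gelu(x):
    """tanh approximation of the GELU activation."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def depthwise_conv3x3(x, kernels):
    """Depthwise 3x3 convolution (cross-correlation, zero padding) on an (H, W, C) grid.

    kernels has shape (3, 3, C): each channel gets its own 3x3 filter and channels
    are never mixed, so the cost is only 9 multiply-adds per value.
    """
    h, w, _ = x.shape
    padded = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros_like(x)
    for dy in range(3):
        for dx in range(3):
            out += padded[dy:dy + h, dx:dx + w, :] * kernels[dy, dx, :]
    return out


def conv_ffn(tokens, w_in, kernels, w_out, grid):
    """FFN with a depthwise conv between its two linear layers.

    tokens: (grid * grid, D) patch tokens in row-major order.
    w_in: (D, D_ff), kernels: (3, 3, D_ff), w_out: (D_ff, D).
    """
    hidden = tokens @ w_in                             # (N, D_ff)
    hidden = hidden.reshape(grid, grid, -1)            # tokens -> 2D grid of patches
    hidden = gelu(depthwise_conv3x3(hidden, kernels))  # mix each patch with its 8 neighbors
    hidden = hidden.reshape(grid * grid, -1)           # 2D grid -> tokens
    return hidden @ w_out                              # (N, D)


rng = np.random.default_rng(0)
grid, d_model, d_ff = 16, 8, 32
w_in = rng.normal(size=(d_model, d_ff)) / np.sqrt(d_model)
kernels = rng.normal(size=(3, 3, d_ff)) / 3.0
w_out = rng.normal(size=(d_ff, d_model)) / np.sqrt(d_ff)

# Only the token at grid position (5, 5) is non-zero ...
tokens = np.zeros((grid * grid, d_model))
tokens[5 * grid + 5] = rng.normal(size=d_model)
out = conv_ffn(tokens, w_in, kernels, w_out, grid)

# ... so only its 3x3 neighborhood can receive a non-zero output.
touched = [divmod(int(i), grid) for i in np.flatnonzero(np.abs(out).sum(axis=1) > 0)]
print(touched)
```

In PyTorch the same depthwise step is usually an `nn.Conv2d` with `groups` equal to the number of channels.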

Img+Text+Noise Encoding:

The image latent inputs are 4*32*32, and we use a patch size of 2 to build 256 (a 16x16 grid) flattened 4*2*2=16-dimensional input "pixels". These are then projected to the embedding dimension and fed through the transformer blocks.
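The patching step can be sketched with plain reshapes. This is an illustrative NumPy version, not the repo's code: the patch ordering (row-major over the grid, then channel, then pixel within the patch) and the random projection matrix `w_embed` are assumptions. With einops the same patchify is `rearrange(latent, "c (h p1) (w p2) -> (h w) (c p1 p2)", p1=2, p2=2)`.

```python
import numpy as np


def patchify(latent, patch_size=2):
    """(C, H, W) latent -> (num_patches, C * p * p) tokens, patches in row-major order."""
    c, h, w = latent.shape
    p = patch_size
    x = latent.reshape(c, h // p, p, w // p, p)  # (C, gh, p, gw, p)
    x = x.transpose(1, 3, 0, 2, 4)               # (gh, gw, C, p, p)
    return x.reshape((h // p) * (w // p), c * p * p)


def unpatchify(tokens, channels=4, grid=16, patch_size=2):
    """Inverse of patchify: (grid * grid, C * p * p) tokens -> (C, grid * p, grid * p)."""
    p = patch_size
    x = tokens.reshape(grid, grid, channels, p, p)  # (gh, gw, C, p, p)
    x = x.transpose(2, 0, 3, 1, 4)                  # (C, gh, p, gw, p)
    return x.reshape(channels, grid * p, grid * p)


rng = np.random.default_rng(0)
latent = rng.normal(size=(4, 32, 32))
tokens = patchify(latent)                   # (256, 16)
w_embed = rng.normal(size=(16, 768)) / 4.0  # hypothetical linear projection 16 -> 768
embedded = tokens @ w_embed                 # (256, 768), fed to the transformer blocks
print(tokens.shape, embedded.shape)         # (256, 16) (256, 768)
```

The inverse `unpatchify` is what turns the transformer's per-token outputs back into a 4*32*32 latent for the VAE decoder.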

The text and noise conditioning is very simple: we concatenate a pooled CLIP text embedding (ViT-L/14, 768-dimensional) and the sinusoidal noise embedding and feed them as input to the cross-attention layer in each transformer block. No unpooled CLIP embeddings are used.
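A rough NumPy sketch of this conditioning path, under several assumptions: the text and noise embeddings are each projected to the model width and stacked as a 2-token context sequence; the 256-dimensional noise embedding and its x1000 scaling are hypothetical choices; and the cross-attention drops the learned Q/K/V projections for brevity. All weights are random stand-ins.

```python
import numpy as np


def sinusoidal_embedding(noise_level, dim=256, max_period=10_000.0):
    """Map a scalar noise level in [0, 1] to a dim-sized vector of sines and cosines."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)  # geometric frequencies
    angles = 1000.0 * noise_level * freqs  # hypothetical x1000 scale so [0, 1] spans many periods
    return np.concatenate([np.sin(angles), np.cos(angles)])


def conditioning_context(clip_pooled, noise_level, w_text, w_noise):
    """Project text and noise embeddings to the model width and stack them as 2 context tokens."""
    text_token = clip_pooled @ w_text                          # (768,) -> (D,)
    noise_token = sinusoidal_embedding(noise_level) @ w_noise  # (256,) -> (D,)
    return np.stack([text_token, noise_token])                 # (2, D)


def cross_attention(queries, context):
    """Single-head cross-attention with the learned Q/K/V projections omitted."""
    scores = queries @ context.T / np.sqrt(queries.shape[-1])  # (N, 2)
    scores -= scores.max(axis=-1, keepdims=True)               # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)             # softmax over the 2 context tokens
    return weights @ context                                   # (N, D)


rng = np.random.default_rng(0)
d_model = 768
clip_pooled = rng.normal(size=768)              # stand-in for a pooled CLIP ViT-L/14 embedding
w_text = rng.normal(size=(768, d_model)) / np.sqrt(768)
w_noise = rng.normal(size=(256, d_model)) / np.sqrt(256)
image_tokens = rng.normal(size=(256, d_model))  # 256 patch tokens from the 32x32 latent

context = conditioning_context(clip_pooled, 0.3, w_text, w_noise)
out = image_tokens + cross_attention(image_tokens, context)  # residual connection
print(context.shape, out.shape)  # (2, 768) (256, 768)
```

Because the context is only two tokens long, the cross-attention cost is negligible next to the self-attention over the 256 image tokens.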

Training:

The base model has 101MM parameters, 12 layers, and an embedding dimension of 768. I train it on an A100 with a batch size of 256, a learning rate of 3e-4, and 1000 warmup steps. Due to computational constraints I did not run any ablations for this configuration.
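A sketch of the learning-rate schedule, assuming a linear warmup to 3e-4 over the first 1000 steps and a constant rate afterwards (the post-warmup shape is an assumption; the text only gives the peak rate and warmup length).

```python
def learning_rate(step, base_lr=3e-4, warmup_steps=1000):
    """Linear warmup from ~0 to base_lr over warmup_steps, then constant (assumed)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr


for step in (0, 499, 999, 5000):
    print(step, learning_rate(step))
```

In PyTorch the same curve could be expressed as a multiplier with `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: learning_rate(step) / 3e-4)`, calling `scheduler.step()` once per iteration.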

Train and Diffusion Setup:

We train a denoising transformer that takes the following three inputs:

  • noise_level (sampled from 0 to 1 with more values concentrated close to 0 - I use a beta distribution)
  • Image latent (x) corrupted with a level of random noise
    • For a given noise_level between 0 and 1, the corruption is as follows:
      • x_noisy = x*(1-noise_level) + eps*noise_level where eps ~ np.random.normal(0, 1)
  • CLIP embeddings of a text prompt
    • You can think of this as a numerical representation of a text prompt.
    • We use the pooled text embedding here (768-dimensional for ViT-L/14)
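A NumPy sketch of how one training batch could be corrupted. The Beta parameters `a=1.0, b=2.5` are hypothetical (any `a < b` skews the noise levels toward 0; this choice has mean 1/3.5), and the batch of random "latents" is a stand-in for real VAE latents.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_noise_levels(batch_size, a=1.0, b=2.5):
    """Beta(a, b) noise levels in [0, 1]; a < b concentrates them near 0 (a, b hypothetical)."""
    return rng.beta(a, b, size=batch_size)


def corrupt(x, noise_level, eps):
    """x_noisy = x * (1 - noise_level) + eps * noise_level, one noise level per sample."""
    s = noise_level.reshape(-1, *([1] * (x.ndim - 1)))  # (B, 1, 1, 1) for (B, C, H, W) latents
    return x * (1.0 - s) + eps * s


x = rng.normal(size=(8, 4, 32, 32))   # stand-in batch of image latents
noise_level = sample_noise_levels(8)
eps = rng.standard_normal(x.shape)    # eps ~ N(0, 1)
x_noisy = corrupt(x, noise_level, eps)
print(x_noisy.shape, noise_level.round(2))
```

At noise_level = 0 the input is the clean latent and at noise_level = 1 it is pure noise, so the model sees the whole range in between.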

The output is a prediction of the denoised image latent - call it f(x_noisy).

The model is trained to minimize the mean squared error |f(x_noisy) - x|^2 between the prediction and the actual image latent (you can also use the absolute error |f(x_noisy) - x| here). Note that I don't reparameterize the loss in terms of the noise here to keep things simple.
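A small NumPy sketch of the x0-prediction loss, plus the noise estimate that an x0 prediction implies under x_noisy = (1 - s) x + s eps. Since eps_pred - eps = -(1 - s)/s * (x0_pred - x), an MSE on the noise equals the x0 MSE scaled by ((1 - s)/s)^2, so the two objectives weight noise levels differently. The "imperfect model output" is a hypothetical stand-in.

```python
import numpy as np


def prediction_loss(pred, target, kind="mse"):
    """MSE (default) or absolute-error loss between a prediction and its target."""
    diff = pred - target
    if kind == "mse":
        return float(np.mean(diff ** 2))
    if kind == "mae":
        return float(np.mean(np.abs(diff)))
    raise ValueError(f"unknown loss kind: {kind}")


def implied_eps(x_noisy, x0_pred, noise_level):
    """Noise estimate implied by an x0 prediction (requires noise_level > 0)."""
    return (x_noisy - (1.0 - noise_level) * x0_pred) / noise_level


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 32, 32))
eps = rng.standard_normal(x.shape)
s = 0.25
x_noisy = (1 - s) * x + s * eps
x0_pred = x + 0.1 * rng.standard_normal(x.shape)  # stand-in for an imperfect model output

x0_mse = prediction_loss(x0_pred, x)
eps_mse = prediction_loss(implied_eps(x_noisy, x0_pred, s), eps)
print(x0_mse, eps_mse)  # eps_mse is ((1 - s) / s) ** 2 = 9 times larger at s = 0.25
```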

Using this model, we then iteratively generate an image from random noise as follows:

     for i in range(len(self.noise_levels) - 1):
         curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]

         # Predict the original denoised image latent:
         x0_pred = self.predict_x_zero(new_img, label, curr_noise)

         # The new image at next_noise is a weighted average of the current image and the predicted x0:
         new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise

The predict_x_zero method uses classifier free guidance by combining the conditional and unconditional prediction: x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional
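A self-contained NumPy sketch of the sampling loop above together with the classifier-free guidance step. The `model(img, label, noise_level)` callable, the use of `None` as the unconditional label, the default `class_guidance=3.0`, and the linearly spaced noise schedule in the usage example are hypothetical stand-ins, not the repo's API.

```python
import numpy as np


def predict_x_zero(model, img, label, noise_level, class_guidance=3.0):
    """Classifier-free guidance: blend conditional and unconditional x0 predictions."""
    x0_cond = model(img, label, noise_level)
    x0_uncond = model(img, None, noise_level)  # hypothetical: None stands in for "no prompt"
    return class_guidance * x0_cond + (1.0 - class_guidance) * x0_uncond


def ddim_sample(model, label, shape, noise_levels, class_guidance=3.0, seed=0):
    """Deterministic sampling for x_noisy = (1 - s) * x0 + s * eps, from s = 1 down to 0."""
    rng = np.random.default_rng(seed)
    new_img = rng.standard_normal(shape)  # pure noise at noise level 1
    for curr_noise, next_noise in zip(noise_levels[:-1], noise_levels[1:]):
        x0_pred = predict_x_zero(model, new_img, label, curr_noise, class_guidance)
        new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
    return new_img


# Toy "oracle" denoiser that always predicts the same clean latent, ignoring its inputs.
target = np.full((4, 32, 32), 0.5)
oracle = lambda img, label, noise_level: target

noise_levels = np.linspace(1.0, 0.0, 21)  # assumed schedule: 20 evenly spaced steps
sample = ddim_sample(oracle, "a photo of a cat", (4, 32, 32), noise_levels)
print(np.abs(sample - target).max())
```

With a perfect oracle every intermediate image stays on the straight line between the starting noise and the target, and the last step (next_noise = 0) returns the x0 prediction itself.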

A bit of math: the approach above falls within the VDM parametrization (see Section 3.1 in Kingma et al.):

$$z_t = \alpha_t x + \sigma_t \epsilon, \epsilon \sim \mathcal{N}(0,1)$$

Where $z_t$ is the noisy version of $x$ at time $t$.

Generally, $\alpha_t$ is chosen to be $\sqrt{1-\sigma_t^2}$ so that the process is variance preserving. Here, I chose $\alpha_t=1-\sigma_t$ so that we linearly interpolate between the image and random noise. Why? For one, it simplifies the update equation quite a bit, and it makes the signal-to-noise ratio easier to reason about. I also found that the model produces sharper images faster (this needs more validation). The update equation above is the DDIM update for this parametrization, which reduces to a simple weighted average. Note that DDIM deterministically maps random normal noise to images. This has two benefits: we can interpolate in the random normal latent space, and it generally takes fewer steps to reach decent image quality.
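The weighted-average update follows in a few lines from $\alpha_t = 1-\sigma_t$. This is a sketch of the standard DDIM argument, writing $\hat{x}$ for the model's prediction of $x$, $\sigma$ for curr_noise, and $\sigma'$ for next_noise: estimate the noise from the current sample, then re-noise $\hat{x}$ to level $\sigma'$ with that same noise.

$$
\begin{aligned}
z_{\sigma} &= (1-\sigma)\,x + \sigma\,\epsilon
\quad\Rightarrow\quad \hat{\epsilon} = \frac{z_{\sigma} - (1-\sigma)\,\hat{x}}{\sigma} \\
z_{\sigma'} &= (1-\sigma')\,\hat{x} + \sigma'\,\hat{\epsilon}
= \left[(1-\sigma') - \frac{\sigma'(1-\sigma)}{\sigma}\right]\hat{x} + \frac{\sigma'}{\sigma}\,z_{\sigma} \\
&= \frac{(\sigma - \sigma')\,\hat{x} + \sigma'\,z_{\sigma}}{\sigma}
\end{aligned}
$$

The last line is exactly `new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise`, and since no fresh noise is drawn, the map from the initial noise to the final image is deterministic.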

TODOS:

  • better config in the train file
  • how to speed up generation even more - LCMs or other sampling strategies?
  • add script to compute FID

Contributors:

  • apapiu
