Text to Image Latent Diffusion using a Transformer core in PyTorch.
Below are some random examples (at 256px resolution) from a 100MM model trained from scratch for 260k iterations (about 32 hours on a single A100):
Note that the model has not converged yet and could use more training.
The main goal of this project is to build an accessible diffusion model in PyTorch that is:
- fast (close to real time generation)
- small (~100MM params)
- reasonably good (of course not SOTA)
- trainable in a reasonable amount of time on a single GPU (under 50 hours on an A100 or equivalent)
- simple and self-contained (model + train loop is about 400 lines of PyTorch with few dependencies)
- data-efficient: uses ~1 million images, with a focus on data quality over quantity
This is part II of a previous project where I trained a pixel-level diffusion model in Keras. Even though this model outputs images at 4x higher resolution (256px vs 64px), it's actually faster both to train and to sample from, which shows the power of training in latent space and the speed of transformer architectures.
The code is written in pure PyTorch with as few dependencies as possible.
- transformer_blocks.py - basic transformer building blocks relevant to the transformer denoiser
- denoiser.py - the architecture of the denoiser transformer
- train.py - the train loop, which uses accelerate so training can scale to multiple GPUs if needed.
- diffusion.py - class to generate images from noise using reverse diffusion. Short (~60 lines) and self-contained.
- data.py - data utils to download images/text and process the features needed by the diffusion model.
If you have your own dataset of URLs + captions, you can train a model on it by first running data.py and then running train.py with the correct configs. (TODO: still working on adding parameterizations here; a lot of values are hardcoded. I'll add a notebook with a full run here.)
Dependencies:
- PyTorch, numpy, einops for model building
- wandb, tqdm for logging + progress bars
- accelerate for the train loop and multi-GPU support
- img2dataset, webdataset, torchvision for data downloading and image processing
- diffusers, clip for the pretrained VAE and CLIP text model
I try to speed up training and inference as much as possible by:
- using mixed precision for training
- precomputing all latent and text embeddings
- using float16 precision for inference
- using PyTorch's native scaled dot-product attention (SDPA) for attention + torch.compile() (compile doesn't always work)
- using a deterministic denoising process (DDIM) for fewer steps
- TODO: would distillation or something like LCM work here?
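The sketch below shows how mixed precision and SDPA can be combined. It assumes a plain multi-head self-attention layer. The class name `SelfAttention` and the sizes are illustrative and are not taken from the repo's `transformer_blocks.py`. `F.scaled_dot_product_attention` dispatches to a fused kernel (flash or memory-efficient attention on supported GPUs). `torch.autocast` runs the matmuls in half precision. On CPU the sketch falls back to bfloat16 autocast so that it still runs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Multi-head self-attention that delegates the attention math to PyTorch's SDPA."""

    def __init__(self, embed_dim: int, n_heads: int):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, n, d) -> (b, heads, n, d_head), the layout SDPA expects
        q, k, v = (t.reshape(b, n, self.n_heads, d // self.n_heads).transpose(1, 2) for t in (q, k, v))
        # PyTorch picks a flash / memory-efficient / math kernel based on device and dtype
        y = F.scaled_dot_product_attention(q, k, v)
        return self.proj(y.transpose(1, 2).reshape(b, n, d))


device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 autocast on GPU; bfloat16 is the widely supported autocast dtype on CPU
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

attn = SelfAttention(embed_dim=768, n_heads=12).to(device)
x = torch.randn(2, 256, 768, device=device)  # 256 tokens = a 16x16 grid of latent patches

with torch.autocast(device_type=device, dtype=amp_dtype):
    y = attn(x)

# Optional: attn = torch.compile(attn)  (compile doesn't always work)
```

For float16 inference on a GPU, the whole module can also be cast with `attn.half()` rather than relying on autocast.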
The time to generate 16 images on:
- T4:
- A100:
More examples generated with the 100MM model (each image links to its prompt and other params like cfg and seed):
In data.py, I have some helper functions to process images and captions. The flow is as follows:
- Use img2dataset to download images from a dataframe containing URLs and captions.
- Use CLIP to encode the prompts and the VAE to encode images to latents on a webdataset data generator.
- Save the latents and text embeddings for future training.
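The sketch below outlines the "encode once, save, discard the images" step. The function and argument names (`precompute_features`, `encode_images`, `encode_texts`) are hypothetical. A plain list of `(images, captions)` batches stands in for the webdataset generator, and dummy encoders stand in for the real VAE and CLIP models so that nothing needs to be downloaded. The comments show real calls that could be plugged in. Their exact use in data.py is an assumption.

```python
import os
import tempfile
from typing import Callable, Iterable, Sequence, Tuple

import numpy as np
import torch


@torch.no_grad()
def precompute_features(
    batches: Iterable[Tuple[torch.Tensor, Sequence[str]]],
    encode_images: Callable[[torch.Tensor], torch.Tensor],
    encode_texts: Callable[[Sequence[str]], torch.Tensor],
    out_prefix: str,
) -> Tuple[np.ndarray, np.ndarray]:
    """Encode every (images, captions) batch once and save the results to disk."""
    latents, text_embs = [], []
    for images, captions in batches:
        # e.g. diffusers AutoencoderKL: vae.encode(images).latent_dist.sample() * vae.config.scaling_factor
        latents.append(encode_images(images).float().cpu().numpy())
        # e.g. OpenAI clip: clip_model.encode_text(clip.tokenize(captions, truncate=True).to(device))
        text_embs.append(encode_texts(captions).float().cpu().numpy())
    latents_arr = np.concatenate(latents)
    text_arr = np.concatenate(text_embs)
    # After saving, the raw images are no longer needed for training
    np.save(f"{out_prefix}_latents.npy", latents_arr)
    np.save(f"{out_prefix}_text_emb.npy", text_arr)
    return latents_arr, text_arr


# Stand-in data and encoders with the right output shapes (4x32x32 latents, 768-dim text embeddings)
fake_batches = [(torch.randn(8, 3, 256, 256), ["a caption"] * 8) for _ in range(2)]
prefix = os.path.join(tempfile.mkdtemp(), "demo")
lat, txt = precompute_features(
    fake_batches,
    encode_images=lambda x: torch.randn(x.shape[0], 4, 32, 32),
    encode_texts=lambda c: torch.randn(len(c), 768),
    out_prefix=prefix,
)
```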
There are two advantages to this approach. First, VAE encoding is somewhat expensive, so doing it every epoch would add to training time. Second, we can discard the images after processing. For 3*256*256 images, the latent dimension is 4*32*32, so every latent is around 4KB when quantized in uint8. This means that 1 million latents will be "only" 4GB in size, which is easy to handle even in RAM. Storing the raw images would have taken 48x more space (3*256*256 / (4*32*32) = 48).
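The size arithmetic can be checked directly. The sketch also shows one possible uint8 quantization. The clipping range of ±4 is a hypothetical choice made for illustration and is not necessarily the scheme the repo uses. Values are clipped to a fixed range and mapped linearly onto 0..255. This costs at most half a quantization step of error.

```python
import numpy as np

LATENT_SHAPE = (4, 32, 32)   # VAE latent for a 3x256x256 image
IMAGE_SHAPE = (3, 256, 256)  # raw RGB image, 1 byte per value
CLIP_RANGE = 4.0             # hypothetical: assumes latent values mostly lie in [-4, 4]


def quantize(latents: np.ndarray) -> np.ndarray:
    """Map float latents in [-CLIP_RANGE, CLIP_RANGE] linearly onto 0..255."""
    x = np.clip(latents, -CLIP_RANGE, CLIP_RANGE)
    return np.round((x + CLIP_RANGE) / (2 * CLIP_RANGE) * 255).astype(np.uint8)


def dequantize(q: np.ndarray) -> np.ndarray:
    """Inverse of quantize, up to half a quantization step of rounding error."""
    return q.astype(np.float32) / 255 * (2 * CLIP_RANGE) - CLIP_RANGE


latent_bytes = int(np.prod(LATENT_SHAPE))  # 4096 bytes, i.e. ~4KB per latent
image_bytes = int(np.prod(IMAGE_SHAPE))    # 196608 bytes per raw image
print(latent_bytes * 1_000_000 / 1e9)      # GB needed for 1 million latents
print(image_bytes // latent_bytes)         # how much larger raw images are

z = np.random.randn(*LATENT_SHAPE).astype(np.float32)
max_err = np.abs(dequantize(quantize(z)) - np.clip(z, -CLIP_RANGE, CLIP_RANGE)).max()
```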
See denoiser.py for the denoiser class.
The denoiser model is a Transformer-based model built on the architecture in DiT and Pixart-Alpha, albeit with quite a few modifications and simplifications. Using a Transformer as the denoiser differs from most diffusion models, which use a CNN-based U-Net as the denoising backbone. I decided to use a Transformer for a few reasons. First, I wanted to experiment and learn how to build and train Transformers from the ground up. Second, Transformers are fast both to train and to run inference on, and they are likely to benefit most from future performance advances in both hardware and software.
Transformers are not natively built for spatial data, and at first I found a lot of the outputs to be very "patchy". To remedy that, I added a depth-wise convolution in the FFN layer of the transformer (this was introduced in the LocalViT paper). This allows the model to mix "pixels" that are close to each other with very little added compute cost.
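Below is a minimal sketch of an FFN with a depth-wise convolution, loosely following the LocalViT idea. The name `ConvFFN`, the GELU activations and the 4x hidden multiplier are assumptions, not the repo's exact block. The token sequence is reshaped back into its 2D grid. A 3x3 convolution with `groups=hidden` then mixes each channel only with its spatial neighbors, which adds very few parameters.

```python
import torch
import torch.nn as nn


class ConvFFN(nn.Module):
    """Transformer FFN with a depth-wise 3x3 conv between the two linear layers."""

    def __init__(self, embed_dim: int, grid_size: int, mlp_multiplier: int = 4):
        super().__init__()
        hidden = embed_dim * mlp_multiplier
        self.grid_size = grid_size
        self.fc1 = nn.Linear(embed_dim, hidden)
        # groups=hidden makes the conv depth-wise: one 3x3 filter per channel, so it is cheap
        self.dwconv = nn.Conv2d(hidden, hidden, kernel_size=3, padding=1, groups=hidden)
        self.fc2 = nn.Linear(hidden, embed_dim)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        assert n == self.grid_size ** 2, "tokens must form a square grid"
        h = self.act(self.fc1(x))
        # (b, n, c) -> (b, c, H, W) so the conv can mix spatially neighboring tokens
        h = h.transpose(1, 2).reshape(b, -1, self.grid_size, self.grid_size)
        h = self.act(self.dwconv(h))
        # back to a (b, n, c) token sequence
        h = h.flatten(2).transpose(1, 2)
        return self.fc2(h)


ffn = ConvFFN(embed_dim=64, grid_size=16)  # 16x16 grid = 256 tokens
tokens = torch.randn(2, 256, 64)
out = ffn(tokens)
```

The depth-wise conv here has only 256 * 9 weights + 256 biases. A full 3x3 conv over the same 256 channels would need 256 * 256 * 9 weights.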
The image latent inputs are 4*32*32, and we use a patch size of 2 to build (32/2)*(32/2) = 256 flattened 4*2*2=16 dimensional input "pixels". These are then projected into the embedding dimension and fed through the transformer blocks.
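The patchify step can be sketched with plain reshapes. The function name `patchify` is illustrative. It turns a 4x32x32 latent into 256 tokens of 16 values each, and then a linear layer projects each token to the embedding dimension. This sketch does not add any positional encoding.

```python
import torch
import torch.nn as nn


def patchify(latents: torch.Tensor, patch_size: int = 2) -> torch.Tensor:
    """(B, C, H, W) latents -> (B, (H/p)*(W/p), C*p*p) sequence of flattened patches."""
    b, c, h, w = latents.shape
    p = patch_size
    x = latents.reshape(b, c, h // p, p, w // p, p)
    # bring the two grid axes forward, keep each patch's (C, p, p) values together
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // p) * (w // p), c * p * p)


latents = torch.randn(8, 4, 32, 32)
patches = patchify(latents)    # (8, 256, 16)
to_embed = nn.Linear(16, 768)  # project each 16-dim "pixel" to the embedding dim
tokens = to_embed(patches)     # (8, 256, 768)
```

Tokens come out in row-major grid order, so token 0 is the top-left 2x2 patch and token 1 is the patch to its right.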
The text and noise conditioning is very simple: we concatenate a pooled CLIP text embedding (ViT-L/14, 768-dimensional) and the sinusoidal noise embedding and feed the result as input to the cross-attention layer in each transformer block. No unpooled CLIP embeddings are used.
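Here is one way to read "concatenate and feed into cross-attention". The sketch assumes that the text and noise embeddings are each projected to the model width and stacked as a 2-token context. The image tokens attend to this context. The names, the frequency range of the sinusoidal embedding and the use of `nn.MultiheadAttention` are assumptions, not the repo's exact code.

```python
import math

import torch
import torch.nn as nn


def sinusoidal_embedding(noise_level: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """(B,) noise levels in [0, 1] -> (B, dim) sin/cos features at geometric frequencies."""
    freqs = torch.exp(torch.linspace(0.0, math.log(1000.0), dim // 2))
    angles = 2 * math.pi * noise_level[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class ConditionedCrossAttention(nn.Module):
    def __init__(self, embed_dim=768, clip_dim=768, noise_dim=256, n_heads=12):
        super().__init__()
        self.noise_dim = noise_dim
        self.text_proj = nn.Linear(clip_dim, embed_dim)
        self.noise_proj = nn.Linear(noise_dim, embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)

    def forward(self, x, clip_pooled, noise_level):
        # conditioning "sequence" of 2 tokens: [text token, noise token]
        cond = torch.stack(
            [self.text_proj(clip_pooled),
             self.noise_proj(sinusoidal_embedding(noise_level, self.noise_dim))],
            dim=1,
        )  # (B, 2, embed_dim)
        # image tokens are the queries; the conditioning tokens supply keys and values
        out, _ = self.attn(x, cond, cond, need_weights=False)
        return out


layer = ConditionedCrossAttention()
x = torch.randn(4, 256, 768)  # image tokens
out = layer(x, torch.randn(4, 768), torch.rand(4))
```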
The base model has 101MM parameters, 12 layers and an embedding dimension of 768. I train it with a batch size of 256 on an A100, a learning rate of 3e-4 and 1000 warmup steps. Due to computational constraints, I did not do any ablations for this configuration.
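The learning-rate setup can be written as a short scheduler sketch. AdamW, the linear warmup shape and the constant learning rate after warmup are all assumptions. The source only gives the peak rate of 3e-4 and the 1000 warmup steps. A tiny linear layer stands in for the denoiser.

```python
import torch
import torch.nn as nn

LEARNING_RATE = 3e-4
WARMUP_STEPS = 1000


def warmup_factor(step: int) -> float:
    """Linear warmup from ~0 to 1 over WARMUP_STEPS, then a constant learning rate."""
    return min(1.0, (step + 1) / WARMUP_STEPS)


model = nn.Linear(16, 16)  # stand-in for the 101MM-parameter denoiser
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)

lrs = []
for step in range(1500):
    lrs.append(optimizer.param_groups[0]["lr"])
    loss = model(torch.randn(4, 16)).pow(2).mean()  # dummy loss, just to drive the loop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
```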
We train a denoising transformer that takes the following three inputs:
- noise_level (sampled from 0 to 1 with more values concentrated close to 0; I use a beta distribution)
- Image latent (x) corrupted with a level of random noise
  - For a given noise_level between 0 and 1, the corruption is: x_noisy = x*(1-noise_level) + eps*noise_level, where eps ~ np.random.normal(0, 1)
- CLIP embeddings of a text prompt
  - You can think of this as a numerical representation of a text prompt.
  - We use the pooled text embedding here (768-dimensional for ViT-L/14)
The output is a prediction of the denoised image latent; call it f(x_noisy).
The model is trained to minimize the mean squared error |f(x_noisy) - x|^2 between the prediction and the actual image latent (you can also use absolute error here). Note that I don't reparameterize the loss in terms of the noise here to keep things simple.
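The sketch below puts the three inputs and the loss together into one training step. The Beta(1, 2.5) parameters are hypothetical. The source only says that a beta distribution is used, with mass concentrated near 0. `TinyDenoiser` is a toy stand-in for the transformer, defined here only so the step runs end to end.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# hypothetical Beta parameters: with a < b the mass sits closer to 0
NOISE_DIST = torch.distributions.Beta(torch.tensor(1.0), torch.tensor(2.5))


def corrupt(x: torch.Tensor, noise_level: torch.Tensor) -> torch.Tensor:
    """x_noisy = x * (1 - noise_level) + eps * noise_level, with eps ~ N(0, 1)."""
    eps = torch.randn_like(x)
    nl = noise_level.view(-1, 1, 1, 1)  # one level per sample, broadcast over (C, H, W)
    return x * (1 - nl) + eps * nl


class TinyDenoiser(nn.Module):
    """Toy stand-in: f(x_noisy, noise_level, text_emb) -> prediction of the clean latent x."""

    def __init__(self, text_dim: int = 768):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, kernel_size=3, padding=1)
        self.cond = nn.Linear(text_dim + 1, 4)

    def forward(self, x_noisy, noise_level, text_emb):
        c = self.cond(torch.cat([text_emb, noise_level[:, None]], dim=-1))
        return self.conv(x_noisy) + c[:, :, None, None]


def train_step(model, optimizer, x, text_emb):
    noise_level = NOISE_DIST.sample((x.shape[0],))
    x_noisy = corrupt(x, noise_level)
    pred = model(x_noisy, noise_level, text_emb)
    loss = F.mse_loss(pred, x)  # target is the clean latent x itself, not the noise eps
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


model = TinyDenoiser()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
x = torch.randn(8, 4, 32, 32)  # batch of (precomputed) image latents
text_emb = torch.randn(8, 768)  # matching pooled CLIP embeddings
loss = train_step(model, optimizer, x, text_emb)
```

Swapping `F.mse_loss` for `F.l1_loss` gives the absolute-error variant.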
Using this model, we then iteratively generate an image from random noise as follows:
for i in range(len(self.noise_levels) - 1):
    curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]
    # Predict original denoised image:
    x0_pred = self.predict_x_zero(new_img, label, curr_noise)
    # New image at next_noise level is a weighted average of old image and predicted x0:
    new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
The predict_x_zero method uses classifier-free guidance by combining the conditional and unconditional predictions: x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional
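Here is a self-contained version of the sampling loop with classifier-free guidance folded in. Several details are assumptions: the evenly spaced schedule from 1 to 0, the use of an all-zeros text embedding as the "unconditional" input, and the `sample` / `predict_x0` names. `predict_x0` can be any callable that maps (noisy image, text embedding, noise level) to an x0 prediction. The toy predictor below returns a constant for the conditional and unconditional cases, so the final output equals the guided combination exactly.

```python
import torch


def sample(predict_x0, shape, text_emb, class_guidance=4.0, n_steps=50, seed=0):
    """Deterministic sampler: repeatedly predict x0 and move to the next (lower) noise level."""
    g = torch.Generator().manual_seed(seed)
    new_img = torch.randn(shape, generator=g)  # start from pure noise (noise_level = 1)
    # hypothetical schedule: evenly spaced noise levels from 1 down to 0
    noise_levels = torch.linspace(1.0, 0.0, n_steps + 1)
    for i in range(n_steps):
        curr_noise, next_noise = noise_levels[i], noise_levels[i + 1]
        # classifier-free guidance: combine conditional and unconditional x0 predictions
        x0_cond = predict_x0(new_img, text_emb, curr_noise)
        x0_uncond = predict_x0(new_img, torch.zeros_like(text_emb), curr_noise)
        x0_pred = class_guidance * x0_cond + (1 - class_guidance) * x0_uncond
        # same weighted-average update as the loop above
        new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
    return new_img


def toy_predict_x0(x, text_emb, noise_level):
    """Toy model: predicts all ones when conditioned, all minus ones when unconditioned."""
    return torch.full_like(x, 1.0 if text_emb.abs().sum() > 0 else -1.0)


out = sample(toy_predict_x0, (2, 4, 8, 8), torch.ones(2, 768), class_guidance=4.0)
# guided prediction = 4 * 1 + (1 - 4) * (-1) = 7, and the last step (next_noise = 0) returns x0_pred
```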
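A bit of math: the approach above falls within the VDM parametrization (see Section 3.1 in Kingma et al.): $z_t = \alpha_t x + \sigma_t \epsilon$, with $\epsilon \sim \mathcal{N}(0, I)$,
where $z_t$ is the noisy latent at time $t$. With the corruption used for training, $\sigma_t$ = noise_level and $\alpha_t$ = 1 - noise_level.
Generally, $\alpha_t$ is chosen as $\sqrt{1 - \sigma_t^2}$, which makes the process variance preserving. Here $\alpha_t = 1 - \sigma_t$ instead, which linearly interpolates between the image and random noise.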
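A short derivation shows why the sampling update is the deterministic (DDIM-style) step for the linear parametrization. Write the corruption as $z_t = (1 - \sigma_t)x + \sigma_t\epsilon$, where $\sigma_t$ is the current noise_level. Given the model's estimate $\hat{x}_0$, the deterministic step infers the implied noise and re-noises to the next level $\sigma_s < \sigma_t$:

```latex
\begin{aligned}
\hat{\epsilon} &= \frac{z_t - (1-\sigma_t)\,\hat{x}_0}{\sigma_t} \\
z_s &= (1-\sigma_s)\,\hat{x}_0 + \sigma_s\,\hat{\epsilon} \\
    &= \hat{x}_0\left[(1-\sigma_s) - \frac{\sigma_s(1-\sigma_t)}{\sigma_t}\right] + \frac{\sigma_s}{\sigma_t}\,z_t \\
    &= \frac{(\sigma_t - \sigma_s)\,\hat{x}_0 + \sigma_s\,z_t}{\sigma_t}
\end{aligned}
```

Substitute curr_noise $= \sigma_t$, next_noise $= \sigma_s$, new_img $= z_t$ and x0_pred $= \hat{x}_0$. The last line is then exactly the new_img update in the sampling loop. At the final step ($\sigma_s = 0$), it returns $\hat{x}_0$.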
TODOs:
- better config in the train file
- how to speed up generation even more - LCMs or other sampling strategies?
- add script to compute FID