sayakpaul / stable-diffusion-keras-ft

Fine-tuning Stable Diffusion using Keras.

Home Page: https://keras.io/examples/generative/finetune_stable_diffusion/

License: Apache License 2.0

Python 62.64% Jupyter Notebook 37.36%
fine-tuning generative-ai keras keras-cv stable-diffusion tensorflow text2image transfer-learning

stable-diffusion-keras-ft's Introduction

Fine-tuning Stable Diffusion using Keras

This repository provides code for fine-tuning Stable Diffusion in Keras. It is adapted from this script by Hugging Face. The pre-trained model used for fine-tuning comes from KerasCV. To learn more about the original model, check out this documentation.

The code provided in this repository is for research purposes only. Please check out this section to learn more about the potential use cases and limitations.

By loading this model you accept the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE.

If you're just looking for the accompanying resources for this repository, here are the links:

Table of contents:

This repository has a sister repository (keras-sd-serving) that covers various deployment patterns for Stable Diffusion.

Update (January 13, 2023): This project secured 2nd place in the first-ever Keras Community Prize Competition organized by Google.

Dataset

Following the original script from Hugging Face, this repository also uses the Pokemon dataset. However, it was regenerated to work better with tf.data. The regenerated version of the dataset is hosted here; check out that link for more details.

Training

Fine-tuning code is provided in finetune.py. Before running training, ensure you have the dependencies (refer to requirements.txt) installed.

You can launch training with the default arguments by running python finetune.py. Run python finetune.py -h to know about the supported command-line arguments. You can enable mixed-precision training by passing the --mp flag.

When you launch training, a diffusion-model checkpoint is saved at the end of each epoch, but only if the current loss is lower than the previously recorded one.
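
For reference, here is a minimal sketch of that behaviour. It is not the repository's actual implementation; the class and argument names are made up for illustration.

import tensorflow as tf

# Minimal sketch (not the repository's actual callback): track the lowest
# epoch-end loss seen so far and save the diffusion-model weights only when
# the loss improves on it.
class BestLossCheckpoint(tf.keras.callbacks.Callback):
    def __init__(self, diffusion_model, ckpt_path):
        super().__init__()
        self.diffusion_model = diffusion_model
        self.ckpt_path = ckpt_path
        self.best_loss = float("inf")

    def on_epoch_end(self, epoch, logs=None):
        current_loss = (logs or {}).get("loss")
        if current_loss is not None and current_loss < self.best_loss:
            self.best_loss = current_loss
            self.diffusion_model.save_weights(self.ckpt_path)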

To avoid OOM errors and to speed up training, we recommend using at least a V100 GPU. We used an A100.

Some important details to note:

  • Distributed training is not yet supported. Gradient accumulation and gradient checkpointing are also not supported.
  • Only the diffusion model is fine-tuned; the VAE and the text encoder are kept frozen (a rough sketch follows this list).
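
As a rough illustration (not the repository's Trainer code), keeping only the diffusion model trainable with KerasCV could look like this:

import keras_cv

# Rough illustration only: keep the diffusion model trainable and freeze the
# text encoder and the VAE components exposed by KerasCV's StableDiffusion.
sd = keras_cv.models.StableDiffusion(img_height=256, img_width=256)
sd.diffusion_model.trainable = True
sd.text_encoder.trainable = False
sd.image_encoder.trainable = False  # VAE encoder (used to produce latents)
sd.decoder.trainable = False        # VAE decoder (used to turn latents into images)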

Training details:

We fine-tuned the model at two different resolutions: 256x256 and 512x512. Only the batch size and the number of epochs were varied between the two resolutions. Since we didn't use gradient accumulation, we derived the number of epochs with a snippet along the lines of the sketch below.
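
The original snippet isn't reproduced here; the sketch below only captures the idea, and the concrete numbers (dataset size, target step budget) are illustrative assumptions rather than values from the repository.

import math

# Sketch: derive the epoch count needed to cover a target number of
# optimization steps when gradient accumulation is not used.
dataset_size = 833                  # image-caption pairs (assumption)
batch_size = 4                      # value passed via --batch_size
target_optimization_steps = 15_000  # target step budget (assumption)

steps_per_epoch = math.ceil(dataset_size / batch_size)
num_epochs = math.ceil(target_optimization_steps / steps_per_epoch)
print(f"{steps_per_epoch=}, {num_epochs=}")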

  • 256x256: python finetune.py --batch_size 4 --num_epochs 577
  • 512x512: python finetune.py --img_height 512 --img_width 512 --batch_size 1 --num_epochs 72 --mp

For the 256x256 resolution, we intentionally reduced the number of epochs to save compute time.

Fine-tuned weights:

You can find the fine-tuned diffusion model weights here.

Training with custom data

The default Pokemon dataset used in this repository comes with the following structure:

pokemon_dataset/
    data.csv
    image_24.png   
    image_3.png    
    image_550.png  
    image_700.png
    ...

data.csv looks like so:
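
The preview of data.csv isn't reproduced here. Conceptually, each row maps an image file to its caption; the column names and captions in the sketch below are illustrative assumptions, not values taken from the repository.

import pandas as pd

# Hypothetical layout of data.csv: one row per image, with its caption.
df = pd.DataFrame(
    {
        "image_path": ["image_24.png", "image_3.png", "image_550.png"],
        "caption": [
            "a cartoon bird with a yellow body",
            "a green bug-like creature with big eyes",
            "a blue dragon with large wings",
        ],
    }
)
df.to_csv("pokemon_dataset/data.csv", index=False)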

As long as your custom dataset follows this structure, you don't need to change anything in the current codebase except for the dataset_archive.

If your dataset has multiple captions per image, you can randomly select one caption from the pool for each image during training, as shown in the sketch below.
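
A minimal tf.data sketch of that idea, assuming each image comes with a list of candidate captions (this is not part of the repository's pipeline):

import tensorflow as tf

# Pick one caption uniformly at random for each image, so different training
# steps can see different captions for the same image.
def pick_random_caption(image_path, captions):
    idx = tf.random.uniform((), maxval=tf.shape(captions)[0], dtype=tf.int32)
    return image_path, captions[idx]

dataset = tf.data.Dataset.from_tensor_slices(
    (
        tf.constant(["image_24.png", "image_3.png"]),
        tf.ragged.constant(
            [["a yellow bird", "a small chubby bird"], ["a green caterpillar"]]
        ),
    )
).map(pick_random_caption)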

Based on the dataset, you might have to tune the hyperparameters.

Inference

import keras_cv
import matplotlib.pyplot as plt
from tensorflow import keras

IMG_HEIGHT = IMG_WIDTH = 512


def plot_images(images, title):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.title(title)
        plt.imshow(images[i])
        plt.axis("off")


# We just have to load the fine-tuned weights into the diffusion model.
weights_path = keras.utils.get_file(
    origin="https://huggingface.co/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"
)
pokemon_model = keras_cv.models.StableDiffusion(
    img_height=IMG_HEIGHT, img_width=IMG_WIDTH
)
pokemon_model.diffusion_model.load_weights(weights_path)

# Generate images.
generated_images = pokemon_model.text_to_image("Yoda", batch_size=3)
plot_images(generated_images, "Fine-tuned on the Pokemon dataset")

You can bring in your own weights_path (it should be compatible with the diffusion_model) and reuse the code snippet.

Check out this Colab Notebook to play with the inference code.

Results

Initially, we fine-tuned the model at a resolution of 256x256. Here are some results, along with comparisons to the outputs of the original model.

[Image grid comparing the original and fine-tuned models on the prompts "Yoda", "robotic cat with wings", and "Hello Kitty", with a link to the corresponding weights.]

We can see that the fine-tuned model produces more stable outputs than the original model. Even though the results leave plenty of room for aesthetic improvement, the effects of fine-tuning are clearly visible. Note that for the 256x256 resolution we followed the same hyperparameters as Hugging Face's script (apart from the number of epochs and batch size). With better hyperparameters, the results would likely improve.

For the 512x512 resolution, we observed something similar. We therefore experimented with the unconditional_guidance_scale parameter and noticed that setting it to 40 (while keeping the other arguments fixed) produced better results.
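
For illustration, reusing pokemon_model and plot_images from the inference snippet above:

# A higher guidance scale was found to work better for the 512x512 checkpoint.
images = pokemon_model.text_to_image(
    "Yoda", batch_size=3, unconditional_guidance_scale=40
)
plot_images(images, "unconditional_guidance_scale = 40")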

[Image grid for the 512x512 fine-tuned model on the prompts "Yoda", "robotic cat with wings", and "Hello Kitty", with a link to the corresponding weights.]

Note: Fine-tuning at the 512x512 resolution is still in progress as of this writing. Without distributed training and gradient accumulation, it takes a long time to complete a single epoch. The above results are from the checkpoint obtained after the 60th epoch.

With a similar recipe (but trained for more optimization steps), Lambda Labs demonstrates impressive results.

Acknowledgements

  • Thanks to Hugging Face for providing the fine-tuning script. It's very readable and easy to understand.
  • Thanks to the ML Developer Programs team at Google for providing GCP credits.

stable-diffusion-keras-ft's People

Contributors

deep-diver, sayakpaul

stable-diffusion-keras-ft's Issues

Got ValueError: Unknown layer: 'PaddedConv2D' when running the script.

I got the following error message when running the finetune.py script:

ValueError: Unknown layer: 'PaddedConv2D'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

The line that triggers this issue is this line in the code: link
Based on my understanding, the built-in copy function cannot copy the KerasCV PaddedConv2D layer. After I modified that line to the following:

self.ema_diffusion_model = keras.models.clone_model(self.diffusion_model)

The script runs now.
Here are my machine specs:

  • macOS
  • Python 3.8.10
  • tensorflow 2.11.0
  • keras-cv 0.4.0
