
filipbasara0 / simple-diffusion


A minimal implementation of a denoising diffusion model in PyTorch.

License: MIT License

attention-mechanism computer-vision deep-learning pytorch stable-diffusion unet image-generation diffusion


Simple Denoising Diffusion

A minimal implementation of an unconditional denoising diffusion model for image generation in PyTorch. The goal was to test how well a very small model performs on the Oxford Flowers dataset.

Includes a DDIM scheduler and a UNet architecture with residual connections and attention layers.
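To make the scheduler's role concrete, here is a minimal NumPy sketch of one deterministic DDIM update (eta = 0). It is not the repository's code: the linear beta schedule, its endpoints, and the names `linear_alpha_bars` and `ddim_step` are illustrative assumptions. The step first estimates the clean image from the noisy sample and the predicted noise, then re-noises that estimate to the noise level of an earlier timestep.

```python
import numpy as np

def linear_alpha_bars(num_train_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule (assumed)."""
    betas = np.linspace(beta_start, beta_end, num_train_steps)
    return np.cumprod(1.0 - betas)

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0) from timestep t to an earlier timestep."""
    # Estimate the clean image x0 from the noisy sample and the predicted noise.
    x0_pred = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_bar_t)
    # Move the estimate to the noise level of the earlier timestep, reusing eps_pred.
    return np.sqrt(alpha_bar_prev) * x0_pred + np.sqrt(1.0 - alpha_bar_prev) * eps_pred

# Usage: noise a toy "image", then take one DDIM step using the true noise.
rng = np.random.default_rng(0)
alpha_bars = linear_alpha_bars()
x0 = rng.uniform(-1.0, 1.0, size=(3, 8, 8))
eps = rng.standard_normal(x0.shape)
t, t_prev = 999, 979
x_t = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps
x_prev = ddim_step(x_t, eps, alpha_bars[t], alpha_bars[t_prev])
```

With a perfect noise prediction, the step lands exactly on the forward-noised image at `t_prev`; in practice the UNet supplies `eps_pred`.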

Oxford Flowers

[Figure: samples generated by the model trained on Oxford Flowers]

So far, the model has only been tested on the Oxford Flowers dataset; the results can be seen in the image above. Images were generated with 50 DDIM steps.
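Sampling with 50 DDIM steps means visiting only a subset of the training timesteps. The sketch below assumes 1000 training timesteps, evenly spaced ("leading") selection, and a linear beta schedule; none of these are confirmed by the source, and `ddim_timesteps`, `ddim_sample`, and the `oracle` predictor are hypothetical names. The oracle stands in for a trained UNet so the loop can be checked end to end.

```python
import numpy as np

def ddim_timesteps(num_train_steps=1000, num_inference_steps=50):
    """Evenly spaced training timesteps, ordered from noisiest to cleanest."""
    stride = num_train_steps // num_inference_steps
    return list(range(0, num_train_steps, stride))[:num_inference_steps][::-1]

def ddim_sample(predict_noise, shape, alpha_bars, num_inference_steps=50, seed=0):
    """Deterministic DDIM sampling loop starting from pure Gaussian noise."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)
    timesteps = ddim_timesteps(len(alpha_bars), num_inference_steps)
    for i, t in enumerate(timesteps):
        # alpha_bar of the next (less noisy) timestep; 1.0 means fully denoised.
        prev_ab = alpha_bars[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        eps = predict_noise(x, t)
        x0_pred = (x - np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alpha_bars[t])
        x = np.sqrt(prev_ab) * x0_pred + np.sqrt(1.0 - prev_ab) * eps
    return x

# Usage with an assumed linear schedule and a hypothetical perfect noise predictor.
alpha_bars = np.cumprod(1.0 - np.linspace(1e-4, 2e-2, 1000))
target = np.full((3, 4, 4), 0.5)

def oracle(x, t):
    # Returns the exact noise that separates x from the known clean image.
    return (x - np.sqrt(alpha_bars[t]) * target) / np.sqrt(1.0 - alpha_bars[t])

sample = ddim_sample(oracle, target.shape, alpha_bars)
```

Because each step reuses the predicted noise instead of drawing fresh noise, the sampler can skip 19 of every 20 timesteps here; a real model's prediction errors are what make fewer steps trade off against quality.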

The results were surprisingly decent and training was unexpectedly smooth, considering the model size.

Training ran for 40k steps with a batch size of 64, a learning rate of 1e-3, and a weight decay of 5e-2. It took ~6 hours on a GTX 1070 Ti.
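A quick way to compare training budgets is to count samples seen and the effective batch size per optimizer update. This small sketch assumes one optimizer update per reported step; the helper names are illustrative, and `dataset_size` is left as a parameter rather than a specific figure.

```python
def effective_batch_size(per_device_batch, grad_accum_steps=1, num_devices=1):
    """Number of samples that contribute to one optimizer update."""
    return per_device_batch * grad_accum_steps * num_devices

def samples_seen(num_steps, batch_size):
    """Total training samples processed, assuming one update per step."""
    return num_steps * batch_size

def epochs_equivalent(num_steps, batch_size, dataset_size):
    """Number of passes over a dataset of `dataset_size` images."""
    return samples_seen(num_steps, batch_size) / dataset_size

# Reported run: 40k steps at batch size 64.
print(samples_seen(40_000, 64))
# A per-device batch of 16 needs 4 accumulation steps to reach an effective batch of 64.
print(effective_batch_size(16, grad_accum_steps=4))
```

The same arithmetic makes it easy to check how far a run is from the 300K+ iteration budgets mentioned under Improvements.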

Hidden dims of [16, 32, 64, 128] were used, resulting in a total of 2,346,835 parameters (~2.3M).
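For a real PyTorch model the total is usually obtained with `sum(p.numel() for p in model.parameters())`. The sketch below uses plain arithmetic instead to show why widths matter so much: a convolution's weight count grows with the product of its input and output channels. The 3x3 kernel and the helper name `conv2d_params` are illustrative assumptions, not a breakdown of this repository's UNet.

```python
def conv2d_params(in_ch, out_ch, kernel_size=3, bias=True):
    """Weights plus optional bias of a square-kernel 2D convolution."""
    return in_ch * out_ch * kernel_size * kernel_size + (out_ch if bias else 0)

# A 3x3 conv at the narrowest (16 -> 16) vs. the widest (128 -> 128) hidden dim.
narrow = conv2d_params(16, 16)
wide = conv2d_params(128, 128)
print(narrow, wide, round(wide / narrow, 1))

# Expressing the reported total in millions.
print(f"{2_346_835 / 1e6:.2f}M")
```

Going from 16 to 128 channels multiplies a single conv's parameters by roughly 64, so in a UNet with these widths the deepest levels likely account for most of the count.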

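To train the model, run the following command. Note that it uses a batch size of 16 and a learning rate of 1e-4, not the batch size of 64 and learning rate of 1e-3 reported above: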

  python train.py \
    --dataset_name="huggan/flowers-102-categories" \
    --resolution=64 \
    --output_dir="trained_models/ddpm-ema-64.pth" \
    --train_batch_size=16 \
    --num_epochs=121 \
    --gradient_accumulation_steps=1 \
    --learning_rate=1e-4 \
    --lr_warmup_steps=300
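The command passes `--lr_warmup_steps=300`. A minimal sketch of one common warmup shape follows, assuming a linear ramp from 0 to the base rate followed by a constant rate. `train.py` may use a different curve or decay afterwards, and `warmup_lr` is a hypothetical helper.

```python
def warmup_lr(step, base_lr=1e-4, warmup_steps=300):
    """Learning rate at optimizer step `step` under an assumed linear warmup.

    Ramps linearly from 0 to base_lr over warmup_steps, then stays constant.
    """
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step / warmup_steps)
    return base_lr

for step in (0, 150, 300, 5000):
    print(step, warmup_lr(step))
```

In PyTorch, the same multiplier (`warmup_lr(step) / base_lr`) could be supplied as the `lr_lambda` of `torch.optim.lr_scheduler.LambdaLR`.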

Conclusions

  • Skip and residual connections are a must: training doesn't converge without them
  • Attention speeds up convergence and improves the quality of generated samples
  • Normalizing images to N(0,1) didn't yield improvements compared to the standard -1 to 1 normalization
  • A learning rate of 1e-3 resulted in faster convergence for the smaller models than 1e-4, which is commonly used in the literature
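To illustrate how attention and a residual connection fit together in a UNet block, here is a simplified single-head NumPy sketch of self-attention over spatial positions. It is an assumption-laden stand-in, not the repository's PyTorch module: it omits normalization, multiple heads, and learned biases, and the projection matrices `w_q`, `w_k`, `w_v`, `w_o` are hypothetical parameters.

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_block(x, w_q, w_k, w_v, w_o):
    """Single-head spatial self-attention with a residual connection.

    x: feature map of shape (C, H, W); each w_*: a (C, C) projection matrix.
    """
    c, h, w = x.shape
    tokens = x.reshape(c, h * w).T            # (H*W, C): one token per pixel
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    weights = softmax(q @ k.T / np.sqrt(c))   # (H*W, H*W); each row sums to 1
    out = (weights @ v) @ w_o                 # (H*W, C)
    # Residual connection: add the attention output back onto the input.
    return x + out.T.reshape(c, h, w)

# Usage on a toy 8-channel, 4x4 feature map.
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4, 4))
w_q, w_k, w_v, w_o = (0.1 * rng.standard_normal((8, 8)) for _ in range(4))
y = attention_block(x, w_q, w_k, w_v, w_o)
```

Every pixel can attend to every other pixel, which gives the small UNet a global view that its convolutions lack. With `w_o` set to zero the block reduces to the identity, one intuition for why residual paths make deep stacks easier to optimize.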

Improvements

  • Training longer: these models require many iterations. For example, in Diffusion Models Beat GANs on Image Synthesis, training ranged between 300K and 4360K iterations
  • Using bigger models
  • Exploring the impact of more diverse augmentations

Future steps

  • Training a bigger model on the huggan/pokemons dataset, which proved too difficult for the 2M-parameter model
  • Training a model on a custom task
