
Super-Resolution GAN (SRGAN)

Overview

Welcome to the Super-Resolution Generative Adversarial Network (SRGAN) repository! This code implements an image-enhancement model that uses GANs to generate high-resolution images from low-resolution inputs. SRGANs are particularly useful for tasks such as upscaling images with improved visual quality or reducing image noise.

Architecture

(Figure: SRGAN generator and discriminator architecture.)

Code Explanation

Training:

The training process is encapsulated in the train function within the SRGAN.ipynb notebook. Here's a brief breakdown:

Data Loading:

  • High-resolution (HR) and low-resolution (LR) images are loaded using PyTorch DataLoader.
  • The number of images in the HR and LR folders is verified for consistency (see the sketch below).
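
The snippet below is a minimal sketch of this data-loading step; the PairedImageDataset class, the data/HR and data/LR folder names, and the batch size are illustrative assumptions rather than values taken from the notebook.

```python
import os

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class PairedImageDataset(Dataset):
    """Loads (LR, HR) image pairs whose filenames match across two folders."""

    def __init__(self, hr_dir, lr_dir):
        self.hr_dir, self.lr_dir = hr_dir, lr_dir
        self.names = sorted(os.listdir(hr_dir))
        # Consistency check: both folders must contain the same number of images.
        assert len(self.names) == len(os.listdir(lr_dir)), "HR/LR image counts differ"
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        hr = self.to_tensor(Image.open(os.path.join(self.hr_dir, name)).convert("RGB"))
        lr = self.to_tensor(Image.open(os.path.join(self.lr_dir, name)).convert("RGB"))
        return lr, hr


loader = DataLoader(PairedImageDataset("data/HR", "data/LR"),
                    batch_size=16, shuffle=True, num_workers=2)
```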

Model Initialization:

  • The SRGAN consists of a generator (G) responsible for upscaling LR images and a discriminator (D) distinguishing between real HR images and fake HR images generated by G.
  • Model weights are initialized with Kaiming (He) initialization, and Adam optimizers are set up for both the generator and the discriminator, as sketched below.
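
Roughly, the initialization step looks like the following sketch; the stand-in G and D definitions, learning rate, and betas are assumptions for illustration, not the repository's actual architectures or hyperparameters.

```python
import torch
import torch.nn as nn

# Stand-in networks so the sketch is self-contained; the repository's actual
# generator and discriminator are much deeper.
G = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.PReLU(),
                  nn.Conv2d(64, 3, 3, padding=1))
D = nn.Sequential(nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
                  nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 1))


def kaiming_init(module):
    """Apply Kaiming (He) initialization to conv and linear layers."""
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


G.apply(kaiming_init)
D.apply(kaiming_init)

# One Adam optimizer per network; the hyperparameter values are illustrative.
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.9, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.9, 0.999))
```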

Training Loop:

  • The training loop runs for the number of epochs provided as a hyperparameter.
  • In each epoch, the discriminator loss is computed first, combining the classification error on real HR images with the error on images generated by G.
  • After the discriminator update, the generator loss, which consists of VGG (perceptual) loss, adversarial loss, and pixel loss, is backpropagated to update the generator parameters; a sketch of one training step follows this list.
  • We also experimented with an n:1 update ratio, in which the generator is updated n times for every single discriminator update per epoch. This helps the generator keep up with the discriminator and avoids scenarios where the discriminator always wins. Due to high GPU RAM requirements we could not run this procedure and fell back to single updates, which still gave acceptable generator performance.
  • The discriminator is trained to distinguish between real HR images and fake HR images generated by the generator.
  • The generator is trained to minimize the adversarial loss and generate realistic HR images.
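
A single training step, following the order described above (discriminator first, then generator), might look like the sketch below. It reuses G, D, opt_G, opt_D, and loader from the earlier sketches, and bce is an assumed adversarial criterion; only the adversarial term of the generator loss is shown, with the full combination covered in the next section.

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G, D = G.to(device), D.to(device)

for lr_img, hr_img in loader:
    lr_img, hr_img = lr_img.to(device), hr_img.to(device)

    # --- Discriminator step: error on real HR images + error on generated images ---
    opt_D.zero_grad()
    fake_hr = G(lr_img)
    real_out = D(hr_img)
    fake_out = D(fake_hr.detach())              # detach so this step does not update G
    err_D = bce(real_out, torch.ones_like(real_out)) + \
            bce(fake_out, torch.zeros_like(fake_out))
    err_D.backward()
    opt_D.step()

    # --- Generator step (repeated n times per batch for an n:1 update ratio) ---
    opt_G.zero_grad()
    adv_out = D(fake_hr)
    err_G = bce(adv_out, torch.ones_like(adv_out))  # plus VGG and pixel terms in the full loss
    err_G.backward()
    opt_G.step()
```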

Loss Functions:

Three loss components contribute to the overall generator loss (err_G):

1. Content Loss (criterion_G):

Measures the difference between the generated image and the ground truth in a perceptually meaningful way.

2. Adversarial Loss (criterion_D):

Captures the discriminator's ability to distinguish between real and fake images.

3. Pixel Loss and VGG Loss:

Additional components contributing to the overall generator loss; a sketch combining these terms into err_G follows.
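
The sketch below shows one common way these terms can be combined into err_G; the choice of VGG19 feature layer and the weighting factors are assumptions for illustration, not values taken from the notebook.

```python
import torch
import torch.nn as nn
from torchvision.models import vgg19, VGG19_Weights

# Frozen VGG19 feature extractor used for the perceptual (content/VGG) term.
vgg_features = vgg19(weights=VGG19_Weights.DEFAULT).features[:36].eval()
for p in vgg_features.parameters():
    p.requires_grad = False

mse = nn.MSELoss()
bce = nn.BCEWithLogitsLoss()


def generator_loss(fake_hr, real_hr, d_fake_logits):
    pixel_loss = mse(fake_hr, real_hr)                             # pixel-wise error
    vgg_loss = mse(vgg_features(fake_hr), vgg_features(real_hr))   # perceptual/content error
    adv_loss = bce(d_fake_logits, torch.ones_like(d_fake_logits))  # fool the discriminator
    return pixel_loss + 6e-3 * vgg_loss + 1e-3 * adv_loss          # illustrative weights
```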

Model Checkpoints and Visualization:

  • Training progress is visualized by printing statistics for each batch, including discriminator and generator losses.
  • Model checkpoints are saved periodically, enabling the resumption of training or deployment of a pre-trained model.
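
A checkpointing helper along these lines would do the job; the model/ folder name and the five-epoch interval are assumptions, while the G.pt and D.pt file names match those mentioned in the Deployment section. It can be called as save_checkpoint(G, D, epoch) at the end of each epoch.

```python
import os

import torch


def save_checkpoint(G, D, epoch, every=5, folder="model"):
    """Save generator/discriminator weights every `every` epochs."""
    if epoch % every == 0:
        os.makedirs(folder, exist_ok=True)
        torch.save(G.state_dict(), os.path.join(folder, "G.pt"))
        torch.save(D.state_dict(), os.path.join(folder, "D.pt"))
```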

Epochs and Memory Management

  • The training loop iterates over multiple epochs, refining the model's performance. Memory management techniques, such as clearing variable data to free up GPU memory, are employed to ensure efficient usage during training.
  • We tried to use the 16 GB of GPU RAM available in Google Colab Pro as efficiently as possible, but a GAN with this many layers and parameters in the generator and discriminator still requires a lot of GPU memory.
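
The memory-management pattern mentioned above boils down to something like the following at the end of each batch or epoch; the tensor names refer to the training-step sketch earlier.

```python
import gc

import torch

# Drop references to large per-batch tensors, then release PyTorch's cached GPU memory.
del fake_hr, real_out, fake_out, adv_out
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```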

Deployment

For deploying the trained model, follow these steps:

  1. Load Trained Model:
  • Load the trained generator using the SRGenerator class by uploading the D.pt and G.pt files to the model folder in Google Colab.
  • Use the load_state_dict method to load the saved model parameters.
  • Size of our stored models: D.pt - 60 MB and G.pt - 27 MB.
  2. Inference:
  • Provide a low-resolution image as input to the generator to obtain a high-resolution output; note that careful hyperparameter tuning and powerful hardware are required for excellent high-resolution image quality.
  3. Visualize and Save:
  • Visualize the enhanced image and save it to a desired location. A sketch of these steps follows.
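
Put together, deployment might look like the sketch below, assuming SRGenerator is the generator class from the repository, model/G.pt holds its saved parameters, and input_lr.png / output_sr.png are placeholder file names.

```python
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load the trained generator (only G.pt is needed for inference).
G = SRGenerator().to(device)
G.load_state_dict(torch.load("model/G.pt", map_location=device))
G.eval()

# 2. Inference: low-resolution image in, high-resolution image out.
lr_img = transforms.ToTensor()(Image.open("input_lr.png").convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
    sr_img = G(lr_img).clamp(0, 1)

# 3. Visualize and save the enhanced image.
save_image(sr_img, "output_sr.png")
```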

Setup and Dependencies

  1. Clone the Repository:

     git clone https://github.com/your-username/SRGAN.git

  2. Install Dependencies:

     pip install -r requirements.txt

  3. Prepare Data:
  • Organize HR and LR images in the data/ directory (a quick sanity check is sketched below).
  4. Run Training:
  • Execute the SRGAN.ipynb notebook to train the model.
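
Before training, a quick sanity check like the one below can confirm the data layout; the data/HR and data/LR subfolder names are an assumption about how the data/ directory is organized.

```python
import os

hr_files = sorted(os.listdir("data/HR"))
lr_files = sorted(os.listdir("data/LR"))
print(f"HR: {len(hr_files)} images, LR: {len(lr_files)} images")
assert len(hr_files) == len(lr_files), "HR and LR folders must contain the same number of images"
```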

Additional Notes

  • GPU support is recommended for faster training; the code automatically detects and uses CUDA if available.
  • Hyperparameters and network architecture can be experimented with for potential improvements.
