resnet-implementation

Implementations of ResNet architectures on CIFAR10 with fewer than 5M parameters – Deep Learning Mini Project 1

OUR ARCHITECTURE:

[Figure: architecture diagram]

Here,

  • N: number of residual layers = 3
  • Bi: number of residual blocks in residual layer i = 2
  • Ci: number of channels in residual layer i = 64
  • Fi: convolution kernel size in residual layer i = 3 x 3
  • Ki: skip connection kernel size in residual layer i = 1 x 1
  • P: average pool kernel size = 8 x 8
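As a rough check of the under-5M budget, the sketch below counts trainable parameters from the hyperparameters listed above. The layout is an assumption, not something this README specifies: a bias-free 3 x 3 stem convolution, residual blocks of two F x F convolutions each followed by BatchNorm, a K x K projection shortcut (with BatchNorm) only where a block changes stride or width, strides of 1, 2, 2 for the three residual layers so that a 32 x 32 CIFAR10 image reaches the P = 8 x 8 average pool, and a linear classifier over 10 classes. All function names are hypothetical.

```python
# Assumed CIFAR-style ResNet layout (see lead-in); not taken from main.py.

def conv_params(c_in, c_out, k):
    """Weights of a bias-free k x k convolution."""
    return c_in * c_out * k * k


def bn_params(c):
    """Learnable BatchNorm2d parameters (one scale and one shift per channel)."""
    return 2 * c


def basic_block_params(c_in, c_out, stride, f=3, k_skip=1):
    """Two F x F conv+BN pairs, plus a K x K conv+BN shortcut when the shape changes."""
    total = conv_params(c_in, c_out, f) + bn_params(c_out)
    total += conv_params(c_out, c_out, f) + bn_params(c_out)
    if stride != 1 or c_in != c_out:
        total += conv_params(c_in, c_out, k_skip) + bn_params(c_out)
    return total


def resnet_params(channels, blocks, f=3, k_skip=1, in_channels=3, num_classes=10):
    """Trainable parameter count; layer 1 uses stride 1, later layers downsample by 2."""
    total = conv_params(in_channels, channels[0], f) + bn_params(channels[0])  # stem
    c_in = channels[0]
    for i, (c_out, n_blocks) in enumerate(zip(channels, blocks)):
        for j in range(n_blocks):
            stride = 2 if (i > 0 and j == 0) else 1  # first block of layers 2+ downsamples
            total += basic_block_params(c_in, c_out, stride, f, k_skip)
            c_in = c_out
    # The average pool has no parameters; the linear classifier follows it
    total += c_in * num_classes + num_classes
    return total


def final_feature_size(input_size, num_layers):
    """Spatial size entering the average pool after num_layers - 1 stride-2 stages."""
    return input_size // 2 ** (num_layers - 1)


# Values from the list above: N = 3, Bi = 2, Ci = 64, Fi = 3, Ki = 1
n_params = resnet_params([64, 64, 64], [2, 2, 2], f=3, k_skip=1)
print(n_params, n_params < 5_000_000)  # 454858 True
print(final_feature_size(32, 3))       # 8, matching P = 8 x 8
```

Under these assumptions the listed values give 454,858 parameters, comfortably below 5M. Passing a different channels list, for example one that widens the later layers, shows how quickly the 3 x 3 convolutions dominate the count, since each one scales with the square of the layer width.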

ABOUT THE REPO:

1. File structure

OUTPUTS – This folder contains the ‘.out’ output files of the experiments run with different parameters. To view a file's contents, use the “cat” command.

PLOTS – This folder contains the graphs plotted for each experiment, as .png files that can be viewed directly on GitHub.

SBATCH – This folder contains the ‘.sbatch’ file for each experiment. These files submit the experiments as Slurm jobs; use the command “sbatch filename.sbatch” to run a particular experiment.

best_model.out – The output file generated by our best-performing model.

best_model_acc.png & best_model_loss.png – The train/test accuracy and loss graphs for our best model.

bestmodel.sbatch – The sbatch file for our best model.

main.py – The Python file run by the Slurm job. It contains our training logic and saves the best model weights to the project1_model.pt file.

project1_model.pt – A PyTorch file containing the saved parameters of our best architecture, which can be loaded for testing.

requirements.txt – Lists the Python libraries used in this project.

test.py – Python program that evaluates the model saved in project1_model.pt on the CIFAR10 test set.

utils.py – Helper module imported by ‘main.py’ for utilities such as a progress bar and computing the mean and standard deviation of the dataset.
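Dataset mean and standard deviation are typically used as per-channel normalization constants. As an illustration only (this is not the code in utils.py), the sketch below computes per-channel statistics in plain Python by accumulating sums and sums of squares over every pixel. The function name and the nested-list image format are assumptions.

```python
import math


def channel_mean_std(images):
    """Per-channel mean and population standard deviation over all pixels.

    images: iterable of images; each image is a list of channels and each
    channel is a flat list of pixel values (e.g. already scaled to [0, 1]).
    Hypothetical helper for illustration, not taken from utils.py.
    """
    sums = sq_sums = None
    count = 0  # pixels seen per channel
    for img in images:
        if sums is None:
            sums = [0.0] * len(img)
            sq_sums = [0.0] * len(img)
        for c, channel in enumerate(img):
            sums[c] += sum(channel)
            sq_sums[c] += sum(p * p for p in channel)
        count += len(img[0])
    if count == 0:
        raise ValueError("no pixels to average")
    means = [s / count for s in sums]
    # Var = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding
    stds = [math.sqrt(max(sq / count - m * m, 0.0)) for sq, m in zip(sq_sums, means)]
    return means, stds


# Two tiny 1 x 2 "RGB" images
images = [
    [[0.0, 1.0], [0.5, 0.5], [0.0, 0.0]],
    [[1.0, 0.0], [0.5, 0.5], [1.0, 1.0]],
]
print(channel_mean_std(images))  # ([0.5, 0.5, 0.5], [0.5, 0.0, 0.5])
```

Pooling over all pixels gives the exact dataset statistic. Averaging per-image standard deviations instead gives a slightly different value, so numbers from this sketch may not match utils.py exactly.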


2. How to clone

(Make sure you have Git installed, e.g. Git Bash on Windows.) Open a terminal in the folder you want to clone into and run the following command:
git clone https://github.com/dhyani15/resnet-implementation.git

3. How to test the code

Method 1 (recommended, to only test the trained model on the CIFAR10 test set): Using ‘test.py’ with the saved parameters and weights

(Make sure the files ‘project1_model.pt’ and ‘test.py’ are in the same folder.)

  • Step 1: Create a conda environment.
    For instructions, see the NYU HPC guide "Managing environments": https://sites.google.com/a/nyu.edu/nyu-hpc/documentation/prince/packages/conda-environments

  • Step 2: Install the requirements by running the following command:

    pip install -r requirements.txt

  • Step 3: Run ‘test.py’ to display the accuracy:
    python test.py

    This program uses the saved model ‘project1_model.pt’ and displays the accuracy as a Tensor.

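    If you plan to test our model with your own test script, make sure it includes the following lines:
    import torch
    # project1_model is the model class from this repo; import it from wherever it is defined in your copy
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = project1_model().to(device)
    model_path = './project1_model.pt'
    checkpoint = torch.load(model_path, map_location=device)
    # Wrap in DataParallel before loading so the key names match the saved checkpoint
    model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint, strict=False)
    model.eval()  # put BatchNorm layers in inference mode before evaluating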
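Wrapping the model in torch.nn.DataParallel before load_state_dict matters because DataParallel stores the wrapped network as an attribute named module, so every key in its state dict carries a "module." prefix; presumably project1_model.pt was saved from such a wrapped model. Because strict=False silently skips keys that do not match, it is worth checking the missing_keys and unexpected_keys that load_state_dict returns. If you prefer to load the weights into a bare, unwrapped model, a small helper (hypothetical, not part of this repo) can strip the prefix first:

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel prefix removed from each key.

    Hypothetical helper; keys without the prefix are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Plain dict standing in for a checkpoint; real values would be tensors
demo = {"module.conv1.weight": 1, "module.linear.bias": 2}
print(strip_module_prefix(demo))  # {'conv1.weight': 1, 'linear.bias': 2}
```

With a real checkpoint this would be used as model.load_state_dict(strip_module_prefix(checkpoint)) on an unwrapped project1_model(), assuming the prefix is the only difference between the two sets of keys.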
    --------------------------------------------------------------------------------------------------------

    Method 2 (to train our model from scratch): Retraining the model with Slurm jobs on HPC

  • Make sure you have cloned this repo on your HPC cluster, then repeat Step 1 and Step 2 from Method 1.
  • Step 3: Submit the ‘bestmodel.sbatch’ job with the following Slurm command:
    sbatch bestmodel.sbatch

    This job creates a .out file and two .png files, one for training/test loss and one for accuracy. It trains the model for 200 epochs and prints the losses and accuracies for each epoch.
    (Warning: training and testing combined take approximately 45 to 60 minutes.)

CONTRIBUTORS:

mohitk29, dhyani15, siddhanthiyer-99
