CNN Architectures

This repository contains my own implementations of some popular Convolutional Neural Network (CNN) architectures. All of the models were trained on the CIFAR-10 dataset.
For pre-processing, I normalized the images using the mean and standard deviation of the CIFAR-10 dataset. In addition, all the images in the training set were subject to:

  • Random horizontal flips
  • Random crops
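
The pre-processing steps above can be sketched in plain Python on a single image channel stored as a list of rows. This is only an illustration, not the code from the notebooks: the function names are hypothetical, and the zero padding before the crop (the `padding` parameter) is an assumption, since the crop size and padding used for training are not stated here.

```python
import random

def normalize(pixels, mean, std):
    """Standardize one channel: (x - mean) / std for every pixel."""
    return [[(x - mean) / std for x in row] for row in pixels]

def random_horizontal_flip(pixels, rng, p=0.5):
    """Mirror every row left-to-right with probability p."""
    if rng.random() < p:
        return [row[::-1] for row in pixels]
    return pixels

def random_crop(pixels, rng, size, padding=0):
    """Zero-pad the channel, then cut a size x size window at a random offset."""
    height, width = len(pixels), len(pixels[0])
    padded_width = width + 2 * padding
    padded = [[0.0] * padded_width for _ in range(padding)]
    padded += [[0.0] * padding + list(row) + [0.0] * padding for row in pixels]
    padded += [[0.0] * padded_width for _ in range(padding)]
    top = rng.randint(0, height + 2 * padding - size)   # inclusive bounds
    left = rng.randint(0, width + 2 * padding - size)
    return [row[left:left + size] for row in padded[top:top + size]]
```

Passing an explicit `random.Random` instance as `rng` keeps the augmentation reproducible when a seed is fixed.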

I wrote the DenseNet implementation using the Keras backend, whereas the other two were written in PyTorch (purely as a learning exercise).
Although I saved checkpoints after every 10 epochs, I was unable to upload them along with the notebooks because of GitHub’s maximum file size limit (100 MiB).

DenseNet

My implementation had a total of 995,230 parameters. This model was trained for a total of 120 epochs. The training and testing accuracies are as follows:

  • Training accuracy: 93.32%
  • Testing accuracy: 91.47%

Some hyperparameters that I used:

  • Mini-batch size: 128
  • Filter number: 36
  • Compression factor: 0.5
  • Dropout rate: 0.2
  • Learning rate: 0.01
  • Momentum: 0.7
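
To make the compression factor concrete, here is a rough sketch of how the number of feature maps evolves in a DenseNet: each layer of a dense block appends a fixed number of feature maps, and a transition layer keeps only a fraction of them. Treating the "filter number" of 36 as that per-layer growth is an assumption, and the 72 input channels and 12 layers per block are hypothetical placeholders, not values from my notebooks.

```python
import math

def channels_after_dense_block(in_channels, growth_rate, num_layers):
    """Each layer in a dense block appends growth_rate feature maps."""
    return in_channels + growth_rate * num_layers

def channels_after_transition(in_channels, compression):
    """A transition layer keeps floor(compression * in_channels) feature maps."""
    return int(math.floor(compression * in_channels))

# Hypothetical block: 72 input channels and 12 layers are placeholders.
# Using 36 as the growth rate is an assumption about "filter number".
block_out = channels_after_dense_block(in_channels=72, growth_rate=36, num_layers=12)
transition_out = channels_after_transition(block_out, compression=0.5)
print(block_out, transition_out)
```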

A ReLU non-linearity was used for computing activations. The loss function chosen for this task was the cross-entropy loss, which was optimized using Stochastic Gradient Descent (SGD).
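
For reference, the ReLU activation and the cross-entropy loss for a single example can be written in a few lines of plain Python. This is a minimal sketch of the standard definitions, not the Keras code from the notebook; the function names are hypothetical.

```python
import math

def relu(x):
    """ReLU non-linearity: pass positive values through, clamp the rest to zero."""
    return max(0.0, x)

def cross_entropy(logits, target):
    """Softmax cross-entropy for one example: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]
```

With ten equal logits, as for an untrained classifier on the ten CIFAR-10 classes, the loss equals ln(10).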

ResNet-56

This implementation closely follows the one described in the original paper. In fact, I observed a slightly higher accuracy on the test set than the one quoted by the authors.
The model had a total of 853,018 parameters and was trained for 200 epochs. Its performance was pretty impressive, as it closely tracked the quoted accuracies (top-1).

  • Training accuracy: 99.9265%
  • Testing accuracy: 91.730%

Some hyperparameters that I used:

  • Mini-batch size: 128
  • Learning rate: 0.1
  • Momentum: 0.9
  • Weight decay: 1e-4
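
As an illustration of how these hyperparameters interact, the sketch below performs one SGD update with momentum and weight decay on a flat list of weights. It assumes the formulation in which weight decay is added to the gradient before the momentum buffer is updated (no dampening, no Nesterov momentum); the function name is hypothetical and this is not the training loop from the notebook.

```python
def sgd_step(weights, grads, velocity, lr=0.1, momentum=0.9, weight_decay=1e-4):
    """One SGD update with momentum and L2 weight decay.

    Assumed update rule:
        g <- g + weight_decay * w
        v <- momentum * v + g
        w <- w - lr * v
    """
    new_weights, new_velocity = [], []
    for w, g, v in zip(weights, grads, velocity):
        g = g + weight_decay * w   # weight decay acts as an extra gradient term
        v = momentum * v + g       # accumulate the momentum buffer
        new_velocity.append(v)
        new_weights.append(w - lr * v)
    return new_weights, new_velocity
```

The velocity list starts at zero, so the first step reduces to plain SGD with weight decay.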

ResNeXt

This model did not quite perform “up to the mark” because I realized far too late into the process that it was extremely fat. It had a total of 27,104,586 trainable parameters (oops), which is far more than the “ideal” number detailed by the authors in this paper (the maximum being 20M).
As a result, the model converged extremely slowly: it took around 10 hours on a GPU to complete just 55 epochs. Because I trained it for so few epochs, I was unable to achieve the accuracies that I was looking for.

  • Training accuracy: 93.4%
  • Testing accuracy: 92.54%

Some of the hyperparameters that I used for this model:

  • Expansion factor: 2
  • Cardinality: 4
  • Bottleneck width: 64
  • Learning rate: 0.1
  • Weight decay: 5e-4
  • Momentum: 0.9
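
To show how cardinality affects model size, the sketch below counts the weights of a grouped 3x3 convolution and compares it with an ungrouped one of the same width. It assumes the grouped convolution is cardinality × bottleneck width = 256 channels wide, which is the usual ResNeXt convention but is not confirmed by the notebook; the helper name is hypothetical.

```python
def conv_weight_count(in_channels, out_channels, kernel_size, groups=1):
    """Number of weights in a 2-D convolution (bias excluded)."""
    assert in_channels % groups == 0 and out_channels % groups == 0
    return out_channels * (in_channels // groups) * kernel_size * kernel_size

cardinality = 4
bottleneck_width = 64
# Assumed width of the grouped 3x3 convolution inside one block.
channels = cardinality * bottleneck_width

grouped = conv_weight_count(channels, channels, 3, groups=cardinality)
ungrouped = conv_weight_count(channels, channels, 3, groups=1)
print(grouped, ungrouped)
```

Splitting a convolution into groups divides its weight count by the number of groups, so a low cardinality combined with a wide bottleneck leaves each block comparatively heavy.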

I used the cross-entropy loss function, which was optimized using Stochastic Gradient Descent (SGD). I also used a cosine annealing learning rate scheduler with warm restarts, with a maximum of 200 epochs.
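
The learning rate produced by cosine annealing with warm restarts can be sketched as a plain function of the epoch. This is a minimal illustration of the schedule, not the scheduler object used in the notebook: the function name, the cycle length, the cycle multiplier, and the minimum learning rate of 0 are all assumptions.

```python
import math

def sgdr_lr(epoch, base_lr, first_cycle, cycle_mult=1, min_lr=0.0):
    """Learning rate at a given epoch under cosine annealing with warm restarts.

    Within a cycle of length t_i, the rate decays from base_lr towards min_lr:
        lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * t_cur / t_i))
    At the end of each cycle the rate jumps back to base_lr.
    """
    t_cur, t_i = epoch, first_cycle
    while t_cur >= t_i:          # skip over the cycles that are already finished
        t_cur -= t_i
        t_i *= cycle_mult        # later cycles may be longer
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t_cur / t_i))

# Hypothetical usage: base learning rate 0.1 with 10-epoch cycles.
schedule = [sgdr_lr(epoch, base_lr=0.1, first_cycle=10) for epoch in range(30)]
```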
