dnlcrl / deep-residual-networks-pyfunt

Python implementation of "Deep Residual Learning for Image Recognition" (http://arxiv.org/abs/1512.03385), by MSRA, the winning team of the ILSVRC 2015 and COCO 2015 challenges.

License: MIT License

Deep Residual Learning for Image Recognition

Implementation of "Deep Residual Learning for Image Recognition" (Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun) in PyFunt, a simple Python + NumPy deep learning framework.

Also inspired by this implementation in Lua + Torch.

The network operates on minibatches of shape (N, C, H, W): N images, each with C input channels, height H and width W. As in the reference paper, it has (6*n)+2 layers, arranged as follows:

		                                        (image_dim: 3, 32, 32; F=16)
		                                        (input_dim: N, *image_dim)
		 INPUT
		    |
		    v
		+-------------------+
		|conv[F, *image_dim]|                    (out_shape: N, 16, 32, 32)
		+-------------------+
		    |
		    v
		+-------------------------+
		|n * res_block[F, F, 3, 3]|              (out_shape: N, 16, 32, 32)
		+-------------------------+
		    |
		    v
		+-------------------------+
		|res_block[2*F, F, 3, 3]  |              (out_shape: N, 32, 16, 16)
		+-------------------------+
		    |
		    v
		+---------------------------------+
		|(n-1) * res_block[2*F, 2*F, 3, 3]|      (out_shape: N, 32, 16, 16)
		+---------------------------------+
		    |
		    v
		+-------------------------+
		|res_block[4*F, 2*F, 3, 3]|              (out_shape: N, 64, 8, 8)
		+-------------------------+
		    |
		    v
		+---------------------------------+
		|(n-1) * res_block[4*F, 4*F, 3, 3]|      (out_shape: N, 64, 8, 8)
		+---------------------------------+
		    |
		    v
		+-------------+
		|pool[1, 8, 8]|                          (out_shape: N, 64, 1, 1)
		+-------------+
		    |
		    v
		+-------+
		|softmax|                                (out_shape: N, num_classes)
		+-------+
		    |
		    v
		 OUTPUT
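To make the diagram's arithmetic concrete, the hypothetical helper below (not part of PyFunt) traces the output shape of each stage for a given n and counts the weighted layers: one input convolution, three stages of n residual blocks with two 3x3 convolutions each, and one final classifier layer, which adds up to 6n+2. It assumes, as in the reference paper, that the "softmax" box contains a single fully connected layer counted as one layer; the default 10 classes is also an assumption.

```python
def trace_resnet_shapes(n, image_dim=(3, 32, 32), F=16, num_classes=10):
    """Return ([(stage name, output shape without N), ...], weighted-layer count).

    Hypothetical helper following the diagram: input conv, three stages of
    n residual blocks, global pooling, then a classifier + softmax.
    """
    _, H, W = image_dim
    shapes = [("conv", (F, H, W))]
    layers = 1  # the first 3x3 convolution
    channels = F
    for stage in range(3):
        if stage > 0:
            # the first block of stages 2 and 3 doubles the filters
            # and halves the spatial size (stride 2)
            channels *= 2
            H, W = H // 2, W // 2
        shapes.append((f"stage{stage + 1} ({n} res_blocks)", (channels, H, W)))
        layers += 2 * n  # two 3x3 convolutions per residual block
    shapes.append(("pool", (channels, 1, 1)))   # 8x8 average pooling
    shapes.append(("softmax", (num_classes,)))
    layers += 1  # assumed: one fully connected layer before the softmax
    return shapes, layers


if __name__ == "__main__":
    shapes, depth = trace_resnet_shapes(n=3)
    for name, shape in shapes:
        print(name, shape)
    print("weighted layers:", depth)  # 6*3 + 2 = 20
```

The same function gives 32 layers for n=5 and 44 layers for n=7, matching the network sizes reported in the results below.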

Every convolution layer uses pad=1 and stride=1, except for the layers that increase the number of filters, which use a stride of 2: halving the feature-map size while doubling the number of filters keeps the computational cost per layer roughly constant. Optionally, m affine layers can be placed immediately before the softmax layer by setting the hidden_dims parameter, a list of integers giving the number of neurons in each affine layer.
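The padding and stride claims can be checked with the standard convolution output-size formula, out = (in + 2*pad - kernel) // stride + 1, together with a simple multiply count. Both helper names below are hypothetical and only illustrate the arithmetic; they are not PyFunt functions.

```python
def conv_output_size(size, kernel=3, pad=1, stride=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * pad - kernel) // stride + 1


def conv_mults(in_channels, out_channels, out_h, out_w, kernel=3):
    """Number of multiplications in one convolution layer (bias ignored)."""
    return in_channels * out_channels * kernel * kernel * out_h * out_w


if __name__ == "__main__":
    # pad=1, stride=1 keeps the 32x32 size; stride=2 halves it
    print(conv_output_size(32))            # 32
    print(conv_output_size(32, stride=2))  # 16
    print(conv_output_size(16, stride=2))  # 8

    # 3x3 convolutions inside each stage: doubling the filters while
    # halving H and W leaves the cost unchanged
    stage_costs = [
        conv_mults(16, 16, 32, 32),
        conv_mults(32, 32, 16, 16),
        conv_mults(64, 64, 8, 8),
    ]
    print(stage_costs)  # [2359296, 2359296, 2359296]
```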

Each residual block is composed as below:

          Input
             |
     ,-------+-----.
Downsampling      3x3 convolution+dimensionality reduction
    |               |
    v               v
Zero-padding      3x3 convolution
    |               |
    `-----( Add )---'
             |
          Output
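A minimal NumPy sketch of how the two branches might be combined, assuming the left branch is the parameter-free shortcut from the paper (spatial subsampling, then zero-padding of the missing channels) and the residual branch has already been computed by the two 3x3 convolutions. The function names are hypothetical, and appending all zero channels after the existing ones is one choice among several; this is not the PyFunt code.

```python
import numpy as np


def shortcut(x, out_channels, stride=2):
    """Parameter-free shortcut for an (N, C, H, W) input.

    Subsamples H and W by `stride`, then appends zero-valued channels
    until there are `out_channels` channels.
    """
    y = x[:, :, ::stride, ::stride]       # spatial downsampling
    extra = out_channels - y.shape[1]     # number of channels to add (>= 0)
    return np.pad(y, ((0, 0), (0, extra), (0, 0), (0, 0)), mode="constant")


def residual_add(residual, x, out_channels, stride=1):
    """Block output: residual branch plus the (possibly reshaped) input."""
    if stride == 1 and x.shape[1] == out_channels:
        return residual + x               # plain identity shortcut
    return residual + shortcut(x, out_channels, stride)


if __name__ == "__main__":
    x = np.random.randn(2, 16, 32, 32)
    residual = np.zeros((2, 32, 16, 16))  # stand-in for the conv branch output
    out = residual_add(residual, x, out_channels=32, stride=2)
    print(out.shape)  # (2, 32, 16, 16)
```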

Every layer is followed by batch normalization with momentum 0.1.
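A training-mode sketch of spatial batch normalization with momentum 0.1, in NumPy. It assumes the Torch convention, where momentum is the weight given to the new batch statistic when updating the running averages (other frameworks define momentum the other way round), and it uses the biased batch variance for both normalization and the running estimate as a simplification. The function name is hypothetical.

```python
import numpy as np


def spatial_batchnorm_train(x, gamma, beta, running_mean, running_var,
                            momentum=0.1, eps=1e-5):
    """Normalize an (N, C, H, W) batch per channel over the N, H and W axes.

    Returns (output, new_running_mean, new_running_var).
    Assumed Torch-style update: running = (1 - momentum) * running + momentum * batch.
    """
    mean = x.mean(axis=(0, 2, 3))
    var = x.var(axis=(0, 2, 3))  # biased variance (simplification)
    x_hat = (x - mean[None, :, None, None]) / np.sqrt(var[None, :, None, None] + eps)
    out = gamma[None, :, None, None] * x_hat + beta[None, :, None, None]
    new_running_mean = (1 - momentum) * running_mean + momentum * mean
    new_running_var = (1 - momentum) * running_var + momentum * var
    return out, new_running_mean, new_running_var


if __name__ == "__main__":
    x = np.random.randn(4, 3, 5, 5) * 2.0 + 1.0
    out, rm, rv = spatial_batchnorm_train(
        x, np.ones(3), np.zeros(3), np.zeros(3), np.ones(3))
    print(out.mean(axis=(0, 2, 3)), out.var(axis=(0, 2, 3)))  # close to 0 and 1
```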

Requirements

Once Python is installed, get pip and install all requirements by running:

pip install -r requirements.txt

Usage

To train the network on the CIFAR-10 dataset, use train.py; to list its available options, run:

python train.py --help

To train on the MNIST or SFDDD datasets instead, use the train.py from the mnist or sfddd git branch, respectively.

Experiments Results

All experiment results are available in the ./docs directory. The main results are shown below.

best error on CIFAR-10: 9.59 % (accuracy: 0.9041) with a 20-layer residual network (n=3).

CIFAR-10 Results - iPython notebook

best error on MNIST: 0.36 % (accuracy: 0.9964) with a 32-layer residual network (n=5).

MNIST Results - iPython notebook

best error on SFDDD: 0.25 % (accuracy: 0.9975) on a 1000-sample subset of the training data (~21k images) with a 44-layer residual network (n=7). The images were resized to 64x48; training used randomly cropped 32x32 images, and testing used a 32x32 crop from the center of each image. Unfortunately, I got more than 2% error on Kaggle's test set (~80k images).

SFDDD Results - iPython notebook
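The cropping scheme used for SFDDD can be sketched in NumPy as follows, assuming the images are already resized and stored as an (N, C, H, W) array; the resize step is omitted, and the example shape (N, 3, 48, 64) is only an assumption about how "64x48" maps to H and W. Both function names are hypothetical, not taken from the project's code.

```python
import numpy as np


def random_crop(images, size=32, rng=None):
    """Training-time augmentation: a random size x size window per image."""
    rng = np.random.default_rng() if rng is None else rng
    n, c, h, w = images.shape
    out = np.empty((n, c, size, size), dtype=images.dtype)
    for i in range(n):
        top = rng.integers(0, h - size + 1)   # inclusive range [0, h - size]
        left = rng.integers(0, w - size + 1)
        out[i] = images[i, :, top:top + size, left:left + size]
    return out


def center_crop(images, size=32):
    """Test-time crop: the central size x size window of every image."""
    _, _, h, w = images.shape
    top = (h - size) // 2
    left = (w - size) // 2
    return images[:, :, top:top + size, left:left + size]


if __name__ == "__main__":
    batch = np.random.rand(8, 3, 48, 64).astype(np.float32)
    print(random_crop(batch).shape)  # (8, 3, 32, 32)
    print(center_crop(batch).shape)  # (8, 3, 32, 32)
```

Random crops give the network a slightly different view of each image at every epoch, while the deterministic center crop keeps test-time predictions reproducible.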

Contributors

dnlcrl

deep-residual-networks-pyfunt's Issues

About nnet.utils.vis_utils

Hi, I tried to run the code but I get an error with nnet.utils.vis_utils. I don't think it is a Python package. Is it a user-defined module? How can I get access to it?

Thanks

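Two questions

Question 1: pyfunt shows up in pip list, but when running requirements.txt it cannot be found, and I cannot find it among the Python packages either. Did the installation fail? Also, installing with sudo pip install git+git://github.com/dnlcrl/PyFunt.git fails with an error: SSLError: The read operation timed out
Storing debug log for failure in /home/zxc/.pip/pip.log

Question 2: When running train.py I get No module named pydatset.cifar10, even though pydatset is indeed installed under /usr/local/lib/python2.7/dist-packages.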
