
Bag of Tricks for Image Classification with Convolutional Neural Networks

This repo was inspired by the paper Bag of Tricks for Image Classification with Convolutional Neural Networks.

I will test as many popular training tricks as I can for improving image classification accuracy. Feel free to leave a comment about the tricks you want me to test (please include the referenced paper along with the trick).

hardware

Using 4 Tesla P40 GPUs to run the experiments.

dataset

I will use the CUB_200_2011 dataset instead of ImageNet, just for simplicity. This is a fine-grained image classification dataset containing 200 bird categories, 5K+ training images, and 5K+ test images. The state-of-the-art accuracy with VGG16 is around 73% (please correct me if I am wrong). You could easily change it to a dataset you like: Stanford Dogs, Stanford Cars, or even ImageNet.
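For reference, here is a minimal sketch of a PyTorch Dataset for the standard CUB_200_2011 layout (an images/ folder plus the images.txt, image_class_labels.txt, and train_test_split.txt metadata files). The root path and transform are placeholders; the repo's own loader may differ:

```python
import os
from PIL import Image
from torch.utils.data import Dataset

class CUB200(Dataset):
    def __init__(self, root, train=True, transform=None):
        self.root, self.transform = root, transform
        # each metadata file maps "<image id> <value>" per line
        paths = dict(line.split() for line in open(os.path.join(root, "images.txt")))
        labels = dict(line.split() for line in open(os.path.join(root, "image_class_labels.txt")))
        split = dict(line.split() for line in open(os.path.join(root, "train_test_split.txt")))
        wanted = "1" if train else "0"
        self.samples = [
            (paths[i], int(labels[i]) - 1)  # class ids are 1-indexed in the files
            for i in paths if split[i] == wanted
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(os.path.join(self.root, "images", path)).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
```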

network

I use a VGG16 network to test my tricks, also for simplicity reasons, since VGG16 is easy to implement. I'm considering switching to AlexNet to see how powerful these tricks are.
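For reference, torchvision ships a VGG16 that can be instantiated directly with a 200-way head; this is an illustrative alternative, not the repo's own models/vgg.py:

```python
import torchvision

# randomly initialized VGG16 (no ImageNet weights), 200 output classes
model = torchvision.models.vgg16(num_classes=200)
```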

tricks

Tricks I've tested, some of them from the paper Bag of Tricks for Image Classification with Convolutional Neural Networks (see the sketch below the table):

| trick | referenced paper |
| --- | --- |
| xavier init | Understanding the difficulty of training deep feedforward neural networks |
| warmup training | Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour |
| no bias decay | Highly Scalable Deep Learning Training System with Mixed-Precision: Training ImageNet in Four Minutes |
| label smoothing | Rethinking the inception architecture for computer vision |
| random erasing | Random Erasing Data Augmentation |
| cutout | Improved Regularization of Convolutional Neural Networks with Cutout |
| linear scaling learning rate | Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour |
| cosine learning rate decay | SGDR: Stochastic Gradient Descent with Warm Restarts |

and more to come...
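To make the table concrete, here is a minimal PyTorch sketch of several of these tricks together (xavier init, no bias decay, warmup plus cosine learning rate decay, and label smoothing). The VGG16/200-class setup follows this repo; the hyperparameter values are illustrative assumptions, not the repo's exact settings:

```python
import torch
import torch.nn as nn
import torchvision


def xavier_init(module):
    # xavier init (Glorot & Bengio, 2010) for conv and linear layers
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def split_weight_decay_params(model, weight_decay=5e-4):
    # no bias decay: apply weight decay only to multi-dimensional
    # weights, not to biases (or BatchNorm parameters, if present)
    decay, no_decay = [], []
    for param in model.parameters():
        (no_decay if param.ndim <= 1 else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# randomly initialized VGG16 with a 200-way head, as in this repo
model = torchvision.models.vgg16(num_classes=200)
model.apply(xavier_init)

optimizer = torch.optim.SGD(split_weight_decay_params(model), lr=0.01, momentum=0.9)

# warmup training for the first few epochs, then cosine learning rate
# decay; the epoch counts here are illustrative, not the repo's settings
warmup_epochs, total_epochs = 5, 120
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs),
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs),
    ],
    milestones=[warmup_epochs],
)

# label smoothing is built into CrossEntropyLoss in PyTorch >= 1.10
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
```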

result

baseline (training from scratch, no ImageNet pretrained weights are used):

vgg16: 64.60% on the CUB_200_2011 dataset, lr=0.01, batch size=64

effects of stacking tricks

| trick | acc |
| --- | --- |
| baseline | 64.60% |
| +xavier init and warmup training | 66.07% |
| +no bias decay | 70.14% |
| +label smoothing | 71.20% |
| +random erasing | does not work, drops about 4 points |
| +linear scaling learning rate (batch size 256, lr 0.04) | 71.21% |
| +cutout | does not work, drops about 1 point |
| +cosine learning rate decay | does not work, drops about 1 point |
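For reference, the linear scaling rule behind the batch-size-256 row: the baseline uses lr 0.01 at batch size 64, and scaling linearly to batch size 256 gives exactly the 0.04 in the table:

```python
# linear scaling rule (Goyal et al.): grow lr proportionally to batch size
base_lr, base_batch = 0.01, 64      # the baseline setting above
lr = base_lr * 256 / base_batch     # = 0.04, matching the table
```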


Issues

resnet-C structure cannot decrease computational cost

The paper says 'replace a 7x7 convolution with three 3x3 convolutions', and I am confused about it.
For example:
input: 224x224, channels=3
7x7 conv, pad=3, stride=2, channels=64 -> output dimension: 112x112
Its FLOPs are: 7x7x3x64x112x112

If we replace the 7x7 conv with three 3x3 convs, there are two options:

  1. 'the first and second convolutions have their output channel of 32 and a stride of 2, with the last convolution uses a 64-output channel'. If this is the case, the output dimension of the three 3x3 convs is 64x64, which does not match the 7x7 conv output dimension. (From Figure 2b, three 3x3 convs + maxpool should be similar to 7x7 conv + maxpool.)

  2. the first 3x3 conv has a stride of 2 and the second and third have a stride of 1. This gives the same output dimension as the 7x7 conv, but the FLOPs are 3x3x3x32x112x112 + 3x3x32x32x112x112 + 3x3x32x64x112x112, which is larger than the original 7x7 conv.

Can you give me some tutorials about it?

Update:
From Table 5, the FLOPs of ResNet-50-B are 0.3 G larger than the original ResNet-50, so it may use setting 2 above. Is the author's description wrong then?
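A quick sanity check of the arithmetic in this issue (a generic multiply-accumulate count, ignoring bias terms):

```python
def conv_flops(k, c_in, c_out, h_out, w_out):
    # multiply-accumulates of a k x k convolution, ignoring bias
    return k * k * c_in * c_out * h_out * w_out

# original stem: 7x7, 3 -> 64 channels, stride 2 on a 224x224 input
print(conv_flops(7, 3, 64, 112, 112))     # 118,013,952

# setting 2: one 3x3 stride-2 conv, then two 3x3 stride-1 convs
total = (conv_flops(3, 3, 32, 112, 112)
         + conv_flops(3, 32, 32, 112, 112)
         + conv_flops(3, 32, 64, 112, 112))
print(total)                              # 357,654,528, indeed larger
```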

About learning rate

I wonder what value of the base learning rate you use when the batch size is 256? I have tried 0.1 (as Figure 3 of the paper suggests), but got a bad result.
I found that in your code the default value of the learning rate is 0.04; does it work?

Forgot to convert BGR to RGB using cv2?

ToCVImage just casts an image to uint8 but does not convert BGR to RGB, even though we usually use RGB and the normalization constants are for RGB data.
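For reference, the conversion in question is a single cv2 call (a generic snippet, not code from this repo; the file name is a placeholder):

```python
import cv2

image = cv2.imread("bird.jpg")                   # cv2 loads images as BGR
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)   # convert before RGB normalization
```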


Errors when I run train.py

File "train.py", line 165, in
predicts = net(images)
File "/home//.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home//.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 159, in forward
return self.module(*inputs[0], **kwargs[0])
File "/home//.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home//LJY/Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks-master/models/vgg.py", line 53, in forward
x = self.classifier(x)
File "/home//.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home//.local/lib/python3.6/site-packages/torch/nn/modules/container.py", line 117, in forward
input = module(input)
File "/home//.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home//.local/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/home//.local/lib/python3.6/site-packages/torch/nn/functional.py", line 1690, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0
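A hedged note, not an official fix: `mat1 dim 1 must match mat2 dim 0` raised inside `F.linear` usually means the flattened feature vector entering the classifier does not have the size the first `nn.Linear` layer expects, e.g. because the input images are a different resolution than the model was sized for. One generic way to make the flattened size resolution-independent is adaptive pooling before the classifier:

```python
import torch.nn as nn

# force a fixed 7x7 feature map regardless of input resolution, so the
# flattened vector is always 512 * 7 * 7 = 25088 for a VGG16 backbone
pool = nn.AdaptiveAvgPool2d((7, 7))
```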

question

Do you use "xavier init" in the code? I could not find it. Thank you!

Do the tricks also apply to ResNets?

Hi,

Thanks for your implementation and repo.

I tested on ResNet for CIFAR-100, and it seems label smoothing and no bias decay do not improve the result.

BTW, I disabled all image augmentation for both train and test, using only normalization, to purely test the above-mentioned tricks.

If you have also tested on ResNet, say ResNet-56, and got different results, please let me know.

Many thanks.
