
chenyaofo / pytorch-cifar-models


Pretrained models on CIFAR10/100 in PyTorch

License: BSD 3-Clause "New" or "Revised" License

Topics: cifar10, cifar100, classification, deep-learning, pretrained-models, pytorch-cifar-models, notebook, pytorch

pytorch-cifar-models's Introduction

PyTorch CIFAR Models

Introduction

The goal of this project is to provide some neural network examples and a simple training codebase for beginners.

Get Started with Google Colab

Train Models: Open the notebook to train the models from scratch on CIFAR10/100. Training takes several hours, depending on the complexity of the model and the allocated GPU type.

Test Models: Open the notebook to measure the validation accuracy on CIFAR10/100 with the pretrained models. This only takes a few seconds.

Use Models with PyTorch Hub

You can simply use the pretrained models in your project with the torch.hub API. It will automatically load the model code and the pretrained weights from GitHub (if you cannot access GitHub directly, please check this issue for a solution).

import torch
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
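
Once loaded, the model is a regular torch.nn.Module. A minimal inference sketch follows, reusing the model from the snippet above; the random tensor is only a placeholder for a properly normalized 32x32 CIFAR image:

import torch

model.eval()  # model loaded via torch.hub above
# Placeholder input: replace with a real CIFAR-10 image normalized with the
# statistics used for training (see the normalization discussion in the issues below).
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    logits = model(x)  # shape (1, 10) for CIFAR-10 models
print(logits.argmax(dim=1))  # predicted class index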

To list all available model entries, you can run:

import torch
from pprint import pprint
pprint(torch.hub.list("chenyaofo/pytorch-cifar-models", force_reload=True))
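
Each entry name encodes the dataset and the architecture, for example cifar10_resnet20 or cifar100_mobilenetv2_x1_0, matching the models listed in the Model Zoo below.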

Model Zoo

CIFAR-10

| Model | Top-1 Acc.(%) | Top-5 Acc.(%) | #Params.(M) | #MAdds(M) | Download |
|---|---|---|---|---|---|
| resnet20 | 92.60 | 99.81 | 0.27 | 40.81 | model / log |
| resnet32 | 93.53 | 99.77 | 0.47 | 69.12 | model / log |
| resnet44 | 94.01 | 99.77 | 0.66 | 97.44 | model / log |
| resnet56 | 94.37 | 99.83 | 0.86 | 125.75 | model / log |
| vgg11_bn | 92.79 | 99.72 | 9.76 | 153.29 | model / log |
| vgg13_bn | 94.00 | 99.77 | 9.94 | 228.79 | model / log |
| vgg16_bn | 94.16 | 99.71 | 15.25 | 313.73 | model / log |
| vgg19_bn | 93.91 | 99.64 | 20.57 | 398.66 | model / log |
| mobilenetv2_x0_5 | 92.88 | 99.86 | 0.70 | 27.97 | model / log |
| mobilenetv2_x0_75 | 93.72 | 99.79 | 1.37 | 59.31 | model / log |
| mobilenetv2_x1_0 | 93.79 | 99.73 | 2.24 | 87.98 | model / log |
| mobilenetv2_x1_4 | 94.22 | 99.80 | 4.33 | 170.07 | model / log |
| shufflenetv2_x0_5 | 90.13 | 99.70 | 0.35 | 10.90 | model / log |
| shufflenetv2_x1_0 | 92.98 | 99.73 | 1.26 | 45.00 | model / log |
| shufflenetv2_x1_5 | 93.55 | 99.77 | 2.49 | 94.26 | model / log |
| shufflenetv2_x2_0 | 93.81 | 99.79 | 5.37 | 187.81 | model / log |
| repvgg_a0 | 94.39 | 99.82 | 7.84 | 489.08 | model / log |
| repvgg_a1 | 94.89 | 99.83 | 12.82 | 851.33 | model / log |
| repvgg_a2 | 94.98 | 99.82 | 26.82 | 1850.10 | model / log |

CIFAR-100

| Model | Top-1 Acc.(%) | Top-5 Acc.(%) | #Params.(M) | #MAdds(M) | Download |
|---|---|---|---|---|---|
| resnet20 | 68.83 | 91.01 | 0.28 | 40.82 | model / log |
| resnet32 | 70.16 | 90.89 | 0.47 | 69.13 | model / log |
| resnet44 | 71.63 | 91.58 | 0.67 | 97.44 | model / log |
| resnet56 | 72.63 | 91.94 | 0.86 | 125.75 | model / log |
| vgg11_bn | 70.78 | 88.87 | 9.80 | 153.34 | model / log |
| vgg13_bn | 74.63 | 91.09 | 9.99 | 228.84 | model / log |
| vgg16_bn | 74.00 | 90.56 | 15.30 | 313.77 | model / log |
| vgg19_bn | 73.87 | 90.13 | 20.61 | 398.71 | model / log |
| mobilenetv2_x0_5 | 70.88 | 91.72 | 0.82 | 28.08 | model / log |
| mobilenetv2_x0_75 | 73.61 | 92.61 | 1.48 | 59.43 | model / log |
| mobilenetv2_x1_0 | 74.20 | 92.82 | 2.35 | 88.09 | model / log |
| mobilenetv2_x1_4 | 75.98 | 93.44 | 4.50 | 170.23 | model / log |
| shufflenetv2_x0_5 | 67.82 | 89.93 | 0.44 | 10.99 | model / log |
| shufflenetv2_x1_0 | 72.39 | 91.46 | 1.36 | 45.09 | model / log |
| shufflenetv2_x1_5 | 73.91 | 92.13 | 2.58 | 94.35 | model / log |
| shufflenetv2_x2_0 | 75.35 | 92.62 | 5.55 | 188.00 | model / log |
| repvgg_a0 | 75.22 | 92.93 | 7.96 | 489.19 | model / log |
| repvgg_a1 | 76.12 | 92.71 | 12.94 | 851.44 | model / log |
| repvgg_a2 | 77.18 | 93.51 | 26.94 | 1850.22 | model / log |

pytorch-cifar-models's People

Contributors: chenyaofo
pytorch-cifar-models's Issues

How to split the training set and the validation set?

Thanks for your pretrained models on CIFAR10. I want to run some tests with a pretrained model; however, the inference accuracy (99.624%) seems too high (using torchvision.datasets.CIFAR10). I think your train/val split approach differs from torchvision's. Could you tell me how your dataset is divided?

accuracy score

I downloaded some of the networks and the accuracy is different from what you reported.
For example, after downloading the ResNet I checked the accuracy and it was only 0.8 (on train).
Could you please check whether this problem happens only to me?
(I downloaded the model from the README and then loaded it into Python.)

Error running on Colab

Hi, I ran

!python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 -M model.name=cifar10_resnet20

on Colab, and it gives me the following error:

Traceback (most recent call last):
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/content/image-classification-codebase/entry/run.py", line 6, in <module>
    main(get_args())
  File "/content/image-classification-codebase/codebase/main.py", line 234, in main
    main_worker(local_rank, ngpus_per_node, args, args.conf)
  File "/content/image-classification-codebase/codebase/main.py", line 203, in main_worker
    prepare_for_training(conf, args.output_dir, local_rank)
  File "/content/image-classification-codebase/codebase/main.py", line 142, in prepare_for_training
    basic_bs = optimizer_config.pop("basic_bs")
  File "/usr/local/lib/python3.7/dist-packages/pyhocon/config_tree.py", line 274, in pop
    value = self.get(key, UndefinedKey)
  File "/usr/local/lib/python3.7/dist-packages/pyhocon/config_tree.py", line 236, in get
    return self._get(ConfigTree.parse_key(key), 0, default)
  File "/usr/local/lib/python3.7/dist-packages/pyhocon/config_tree.py", line 177, in _get
    u"No configuration setting found for key {key}".format(key='.'.join(key_path[:key_index + 1])))
pyhocon.exceptions.ConfigMissingException: 'No configuration setting found for key basic_bs'

Does anyone know how to solve this?

Incorrect value for pretrained in the readme

model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar_resnet20", pretrained=true)
should be
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar_resnet20", pretrained="cifar10")
or
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar_resnet20", pretrained="cifar100")

Question about pretrained models

Hi, I was wondering how many epochs were the pre-trained models trained for? And what other hyperparameters (learning rate, optimizer, scheduler, etc.) were used? I'm particularly interested in ResNet-20.
Thanks!

Data preprocessing

I cannot find in the training logs how the data is preprocessed for the training and testing datasets. Does this mean that only normalization has been applied to both?

Cannot load model

There is an Error: http.client.RemoteDisconnected: Remote end closed connection without response
Code:
net = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)

Environment:
PyTorch 1.11
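
One possible workaround, sketched below and not an official fix from the maintainer, is to instantiate the architecture without weights (pretrained=False) and load a checkpoint downloaded manually from the GitHub releases page (the local file name here is hypothetical):

import torch

# Workaround sketch: build the architecture only, then load a manually
# downloaded checkpoint. "cifar10_resnet20.pt" is a hypothetical local path.
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=False)
state_dict = torch.load("cifar10_resnet20.pt", map_location="cpu")
model.load_state_dict(state_dict)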

`mobilenetv2_x0_75` trained on CIFAR-100 raised HTTPError

All models worked (both for CIFAR-10 and CIFAR-100) except mobilenetv2_x0_75 trained on CIFAR-100.

Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/archive/master.zip" to C:\Users\xxx/.cache\torch\hub\master.zip
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_75-9ab3e178.pt" to C:\Users\xxx/.cache\torch\hub\checkpoints\cifar100_mobilenetv2_x0_75-9ab3e178.p

...
...
 File "C:\ProgramData\Anaconda3\envs\PyTorchEight\lib\urllib\request.py", line 649, in http_error_default
    raise HTTPError(req.full_url, code, msg, hdrs, fp)

HTTPError: Not Found

I thought it might be the .pt filename, but it looks correct to me.

Update:
When I try to download the .pt model directly (by clicking on the link to the model), it does not download and I get the message Failed - No file.
Direct download works for all the other models.

Weight decay for CIFAR-10

Hello, I noticed that in the paper the weight decay used when fine-tuning resnet56 trained on CIFAR-10 is 0.05. Is this a typo?

requirements.txt not found

pip install throws the following error. requirements.txt is missing.

FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'

Mean and std used for normalization

I couldn't find in the code/documentation the mean/std values used for normalization if I want to test on new images. Were they the same ones used in the image-classification-codebase code? i.e.,

mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]

Side note: in the get_vit_val_transforms(mean, std, img_size) and get_vit_train_transforms(mean, std, img_size) functions there seems to be a small bug; the mean/std values passed into the function end up unused and hard-coded ones are used instead (unless this is intended behaviour, but the mean/std parameters threw me off).

CIFAR-10 normalize params

Thanks for the pre-trained models!

I'm testing the pre-trained models on CIFAR-10 using torchvision.datasets.CIFAR10() and found that you are using different normalization params from the ones on pytorch.org:

  • From your repo: mean: [0.4914, 0.4822, 0.4465], std: [0.2023, 0.1994, 0.2010].
  • From pytorch.org: transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).

I can reproduce the results using the same normalization params as yours. It took me some time to locate the issue, so I assume that clarifying the normalization params somewhere in the project README would help other people.
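
For reference, a test-time transform using the statistics quoted above might look like the following sketch (values taken from this issue; double-check them against the training code before relying on them):

from torchvision import transforms

# CIFAR-10 test transform with the normalization statistics quoted above,
# not the (0.5, 0.5, 0.5) values from the pytorch.org tutorial.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010]),
])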

Model accuracy is different from the report in README.md when I evaluate the models

Hi, I tried to evaluate your models' accuracy in two ways. In your report, the cifar100_mobilenetv2_x1_0 model accuracy is 74.20, but I tried

  1. evaluating with your start_on_colab notebook,
  2. and evaluating with my own code:
import torch
from torchvision import datasets, transforms
# data will be downloaded in data_directory.
data_directory = './data'
batchsize=256
device = 'cuda'
normalize = transforms.Normalize(mean=[0.507, 0.4865, 0.4409],
                                 std=[0.2673, 0.2564, 0.2761])

train_dataset = datasets.CIFAR100(root=data_directory,
                                  train=True,
                                  transform=transforms.Compose([
                                                                  transforms.RandomCrop(32, padding=4),
                                                                  transforms.RandomHorizontalFlip(),
                                                                  transforms.ToTensor(),
                                                                  normalize]),
                                  download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batchsize,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=2)

test_dataset = datasets.CIFAR100(root=data_directory,
                                 train=False,
                                 transform=transforms.Compose([
                                                                  transforms.ToTensor(),
                                                                  normalize]),
                                 download=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batchsize,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=2)


def validate(model,loader):
    global device
    model.eval()
    correct = 0.
    total = 0.
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            pred = model(images)

        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels).sum().item()

    val_acc = (correct / total)*100
    model.train()
    return val_acc
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_mobilenetv2_x1_0", pretrained=True).to(device)
validate(model, test_loader)

and I got

  1. 74.30 (+0.1)
  2. 74.29 (+0.09).

Can you check your models?
