huggingface / pytorch-image-models

PyTorch image models, scripts, pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNet-V3/V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more

Home Page: https://huggingface.co/docs/timm

License: Apache License 2.0

Python 85.08% Shell 0.01% MDX 14.92%
pytorch resnet dual-path-networks pretrained-models pretrained-weights distributed-training mobile-deep-learning mobilenet-v2 mobilenetv3 efficientnet

pytorch-image-models's Introduction

PyTorch Image Models

What's New

❗Updates after Oct 10, 2022 are available in version >= 0.9❗

  • Many changes since the last 0.6.x stable releases. They were previewed in 0.8.x dev releases but not everyone transitioned.
  • timm.models.layers moved to timm.layers:
    • from timm.models.layers import name will still work via deprecation mapping (but please transition to timm.layers).
    • import timm.models.layers.module or from timm.models.layers.module import name needs to be changed now.
  • Builder, helper, and other non-model modules in timm.models now have a _ prefix, e.g. timm.models.helpers -> timm.models._helpers. Temporary deprecation mapping files exist but will be removed.
  • All models now support architecture.pretrained_tag naming (e.g. resnet50.rsb_a1); see the short example after this list.
    • The pretrained_tag is the specific weight variant (different head) for the architecture.
    • Using only the architecture name defaults to the first weights in the default_cfgs for that architecture.
    • When pretrained tags were added, many model names that existed only to differentiate weight variants were renamed to use the tag (e.g. vit_base_patch16_224_in21k -> vit_base_patch16_224.augreg_in21k). There are deprecation mappings for these.
  • A number of models had their checkpoints remapped to match architecture changes needed to better support features_only=True. Any model module that was remapped has a checkpoint_filter_fn method, which can be passed to timm.models.load_checkpoint(..., filter_fn=timm.models.swin_transformer_v2.checkpoint_filter_fn) to remap your existing checkpoint.
  • The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for timm weights. Model cards include links to papers, original sources, and licenses.
  • The previous 0.6.x releases can still be cloned from the 0.6.x branch or installed via pip by pinning the version.
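A minimal sketch of the two conventions above (DropPath here just stands in for any layer import from timm.layers; resnet50.rsb_a1 is the tagged name from the example above):

import timm
from timm.layers import DropPath  # new location, was timm.models.layers

# architecture.pretrained_tag naming selects a specific weight variant
model = timm.create_model('resnet50.rsb_a1', pretrained=True)
# the plain architecture name falls back to the first entry in default_cfgs
model_default = timm.create_model('resnet50', pretrained=True)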

April 11, 2024

  • Prepping for a long-overdue 1.0 release; things have been stable for a while now.
  • A significant, long-missing feature: features_only=True support for ViT models with flat hidden states or non-standard module layouts (so far covering 'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*')
  • The above is achieved through a new forward_intermediates() API that can be used with a feature-wrapping module or directly.
import timm
import torch
model = timm.create_model('vit_base_patch16_224')
input = torch.randn(2, 3, 224, 224)  # example batch matching the shapes below
final_feat, intermediates = model.forward_intermediates(input)
output = model.forward_head(final_feat)  # pooling + classifier head

print(final_feat.shape)
torch.Size([2, 197, 768])

for f in intermediates:
    print(f.shape)
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])

print(output.shape)
torch.Size([2, 1000])
model = timm.create_model('eva02_base_patch16_clip_224', pretrained=True, img_size=512, features_only=True, out_indices=(-3, -2,))
output = model(torch.randn(2, 3, 512, 512))

for o in output:    
    print(o.shape)   
torch.Size([2, 768, 32, 32])
torch.Size([2, 768, 32, 32])
  • TinyCLIP vision tower weights added, thanks Thien Tran

Feb 19, 2024

  • Next-ViT models added. Adapted from https://github.com/bytedance/Next-ViT
  • HGNet and PP-HGNetV2 models added. Adapted from https://github.com/PaddlePaddle/PaddleClas by SeeFun
  • Removed setup.py, moved to pyproject.toml based build supported by PDM
  • Add updated model EMA impl using _for_each for less overhead
  • Support device args in train script for non-GPU devices
  • Other misc fixes and small additions
  • Min supported Python version increased to 3.8
  • Release 0.9.16

Jan 8, 2024

Datasets & transform refactoring

  • HuggingFace streaming (iterable) dataset support (--dataset hfids:org/dataset)
  • Webdataset wrapper tweaks for improved split info fetching, can auto fetch splits from supported HF hub webdataset
  • Tested HF datasets and webdataset wrapper streaming from HF hub with recent timm ImageNet uploads to https://huggingface.co/timm
  • Make input & target column/field keys consistent across datasets and pass via args
  • Full monochrome support when using e.g. --input-size 1 224 224 or --in-chans 1; sets PIL image conversion appropriately in the dataset
  • Improved several alternate crop & resize transforms (ResizeKeepRatio, RandomCropOrPad, etc) for use in PixParse document AI project
  • Add SimCLR style color jitter prob along with grayscale and gaussian blur options to augmentations and args
  • Allow train without validation set (--val-split '') in train script
  • Add --bce-sum (sum over class dim) and --bce-pos-weight (positive weighting) args for training, as they're common BCE loss tweaks I was often hard-coding (an example command combining several of these options follows below)
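As a rough illustration only (exact train.py arguments and defaults may differ by version, and --bce-loss is my assumption for the flag that enables BCE loss in the first place), several of the options above could be combined along these lines: python train.py --dataset hfids:org/dataset --model resnet50 --val-split '' --bce-loss --bce-sum --bce-pos-weight 1.0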

Nov 23, 2023

  • Added EfficientViT-Large models, thanks SeeFun
  • Fix Python 3.7 compat, will be dropping support for it soon
  • Other misc fixes
  • Release 0.9.12

Nov 20, 2023

Nov 3, 2023

Oct 20, 2023

  • SigLIP image tower weights supported in vision_transformer.py.
    • Great potential for fine-tune and downstream feature use.
  • Experimental 'register' support in vit models as per Vision Transformers Need Registers
  • Updated RepViT with new weight release. Thanks wangao
  • Add patch resizing support (on pretrained weight load) to Swin models
  • 0.9.8 release pending

Sep 1, 2023

  • TinyViT added by SeeFun
  • Fix EfficientViT (MIT) to use torch.autocast so it works back to PT 1.10
  • 0.9.7 release

Aug 28, 2023

  • Add dynamic img size support to models in vision_transformer.py, vision_transformer_hybrid.py, deit.py, and eva.py w/o breaking backward compat.
    • Add dynamic_img_size=True to args at model creation time to allow changing the grid size (interpolate abs and/or ROPE pos embed each forward pass).
    • Add dynamic_img_pad=True to allow image sizes that aren't divisible by patch size (pad bottom right to patch size each forward pass).
    • Enabling either dynamic mode will break FX tracing unless the PatchEmbed module is added as a leaf.
    • Existing method of resizing position embedding by passing different img_size (interpolate pretrained embed weights once) on creation still works.
    • Existing method of changing patch_size (resize pretrained patch_embed weights once) on creation still works.
    • Example validation cmd python validate.py /imagenet --model vit_base_patch16_224 --amp --amp-dtype bfloat16 --img-size 255 --crop-pct 1.0 --model-kwargs dynamic_img_size=True dynamic_img_pad=True (a Python sketch of the same kwargs follows below)
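The same behaviour at the Python level, as a minimal sketch (dynamic_img_size / dynamic_img_pad are the creation-time kwargs named above; a 255x255 input exercises both the interpolation and the padding):

import timm
import torch

model = timm.create_model('vit_base_patch16_224', pretrained=True,
                          dynamic_img_size=True, dynamic_img_pad=True)
out = model(torch.randn(1, 3, 255, 255))  # grid interpolated / padded each forward pass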

Aug 25, 2023

Aug 11, 2023

  • Swin, MaxViT, CoAtNet, and BEiT models support resizing of image/window size on creation with adaptation of pretrained weights
  • Example validation cmd to test w/ non-square resize python validate.py /imagenet --model swin_base_patch4_window7_224.ms_in22k_ft_in1k --amp --amp-dtype bfloat16 --input-size 3 256 320 --model-kwargs window_size=8,10 img_size=256,320
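A rough Python equivalent of the resize-on-creation example above (a sketch only; passing img_size / window_size tuples through create_model is my reading of the --model-kwargs example, and exact kwarg handling may vary by version):

import timm

model = timm.create_model('swin_base_patch4_window7_224.ms_in22k_ft_in1k', pretrained=True,
                          img_size=(256, 320), window_size=(8, 10))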

Aug 3, 2023

  • Add GluonCV weights for HRNet w18_small and w18_small_v2. Converted by SeeFun
  • Fix selecsls* model naming regression
  • Patch and position embedding resize for ViT/EVA now works with bfloat16/float16 weights on load (or with activations for on-the-fly resize)
  • v0.9.5 release prep

July 27, 2023

  • Added timm trained seresnextaa201d_32x8d.sw_in12k_ft_in1k_384 weights (and .sw_in12k pretrain) with 87.3% top-1 on ImageNet-1k, best ImageNet ResNet family model I'm aware of.
  • RepViT model and weights (https://arxiv.org/abs/2307.09283) added by wangao
  • I-JEPA ViT feature weights (no classifier) added by SeeFun
  • SAM-ViT (segment anything) feature weights (no classifier) added by SeeFun
  • Add support for alternative feature extraction methods and negative indices to EfficientNet
  • Add NAdamW optimizer
  • Misc fixes

May 11, 2023

  • timm 0.9 released, transitioning from the 0.8.x dev releases

May 10, 2023

  • Hugging Face Hub downloading is now default, 1132 models on https://huggingface.co/timm, 1163 weights in timm
  • DINOv2 vit feature backbone weights added thanks to Leng Yue
  • FB MAE vit feature backbone weights added
  • OpenCLIP DataComp-XL L/14 feat backbone weights added
  • MetaFormer (poolformer-v2, caformer, convformer, updated poolformer (v1)) w/ weights added by Fredo Guan
  • Experimental get_intermediate_layers function on vit/deit models for grabbing hidden states (inspired by DINO impl). This is WIP and may change significantly... feedback welcome.
  • Model creation throws error if pretrained=True and no weights exist (instead of continuing with random initialization)
  • Fix regression with inception / nasnet TF sourced weights with 1001 classes in original classifiers
  • bitsandbytes (https://github.com/TimDettmers/bitsandbytes) optimizers added to the factory, use the bnb prefix, e.g. bnbadam8bit
  • Misc cleanup and fixes
  • Final testing before switching to a 0.9 and bringing timm out of pre-release state

April 27, 2023

  • 97% of timm models uploaded to HF Hub and almost all updated to support multi-weight pretrained configs
  • Minor cleanup and refactoring of another batch of models as multi-weight added. More fused_attn (F.sdpa) and features_only support, and torchscript fixes.

April 21, 2023

  • Gradient accumulation support added to train script and tested (--grad-accum-steps), thanks Taeksang Kim
  • More weights on HF Hub (cspnet, cait, volo, xcit, tresnet, hardcorenas, densenet, dpn, vovnet, xception_aligned)
  • Added --head-init-scale and --head-init-bias to train.py to scale the classifier head and set a fixed bias for fine-tuning
  • Remove all InplaceABN (inplace_abn) use, replaced use in tresnet with standard BatchNorm (modified weights accordingly).

April 12, 2023

  • Add ONNX export script, validate script, and helpers that I've had kicking around for a long time. Tweak 'same' padding for better export w/ recent ONNX + PyTorch.
  • Refactor dropout args for vit and vit-like models; separate drop_rate into drop_rate (classifier dropout), proj_drop_rate (block mlp / out projections), pos_drop_rate (position embedding drop), and attn_drop_rate (attention dropout). Also add patch dropout (FLIP) to vit and eva models (see the sketch after this list).
  • Add fused F.scaled_dot_product_attention support to more vit models, an env var (TIMM_FUSED_ATTN) to control it, and a config interface to enable/disable
  • Add EVA-CLIP backbones w/ image tower weights, all the way up to 4B param 'enormous' model, and 336x336 OpenAI ViT mode that was missed.
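A sketch of the refactored dropout arguments at model-creation time; the first four kwargs are the ones named above, while patch_drop_rate is my assumption for the patch-dropout kwarg:

import timm

model = timm.create_model(
    'vit_base_patch16_224',
    drop_rate=0.1,         # classifier dropout
    proj_drop_rate=0.1,    # block mlp / output projection dropout
    pos_drop_rate=0.1,     # position embedding dropout
    attn_drop_rate=0.0,    # attention dropout
    patch_drop_rate=0.25,  # FLIP-style patch dropout (assumed kwarg name)
)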

April 5, 2023

  • ALL ResNet models pushed to Hugging Face Hub with multi-weight support
  • New ImageNet-12k + ImageNet-1k fine-tunes available for a few anti-aliased ResNet models
    • resnetaa50d.sw_in12k_ft_in1k - 81.7 @ 224, 82.6 @ 288
    • resnetaa101d.sw_in12k_ft_in1k - 83.5 @ 224, 84.1 @ 288
    • seresnextaa101d_32x8d.sw_in12k_ft_in1k - 86.0 @ 224, 86.5 @ 288
    • seresnextaa101d_32x8d.sw_in12k_ft_in1k_288 - 86.5 @ 288, 86.7 @ 320

March 31, 2023

  • Add first ConvNext-XXLarge CLIP -> IN-1k fine-tune and IN-12k intermediate fine-tunes for convnext-base/large CLIP models.
model | top1 | top5 | img_size | param_count | gmacs | macts
convnext_xxlarge.clip_laion2b_soup_ft_in1k | 88.612 | 98.704 | 256 | 846.47 | 198.09 | 124.45
convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384 | 88.312 | 98.578 | 384 | 200.13 | 101.11 | 126.74
convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320 | 87.968 | 98.47 | 320 | 200.13 | 70.21 | 88.02
convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384 | 87.138 | 98.212 | 384 | 88.59 | 45.21 | 84.49
convnext_base.clip_laion2b_augreg_ft_in12k_in1k | 86.344 | 97.97 | 256 | 88.59 | 20.09 | 37.55
  • Add EVA-02 MIM pretrained and fine-tuned weights, push to HF hub and update model cards for all EVA models. First model over 90% top-1 (99% top-5)! Check out the original code & weights at https://github.com/baaivision/EVA for more details on their work blending MIM, CLIP w/ many model, dataset, and train recipe tweaks.
model | top1 | top5 | param_count | img_size
eva02_large_patch14_448.mim_m38m_ft_in22k_in1k | 90.054 | 99.042 | 305.08 | 448
eva02_large_patch14_448.mim_in22k_ft_in22k_in1k | 89.946 | 99.01 | 305.08 | 448
eva_giant_patch14_560.m30m_ft_in22k_in1k | 89.792 | 98.992 | 1014.45 | 560
eva02_large_patch14_448.mim_in22k_ft_in1k | 89.626 | 98.954 | 305.08 | 448
eva02_large_patch14_448.mim_m38m_ft_in1k | 89.57 | 98.918 | 305.08 | 448
eva_giant_patch14_336.m30m_ft_in22k_in1k | 89.56 | 98.956 | 1013.01 | 336
eva_giant_patch14_336.clip_ft_in1k | 89.466 | 98.82 | 1013.01 | 336
eva_large_patch14_336.in22k_ft_in22k_in1k | 89.214 | 98.854 | 304.53 | 336
eva_giant_patch14_224.clip_ft_in1k | 88.882 | 98.678 | 1012.56 | 224
eva02_base_patch14_448.mim_in22k_ft_in22k_in1k | 88.692 | 98.722 | 87.12 | 448
eva_large_patch14_336.in22k_ft_in1k | 88.652 | 98.722 | 304.53 | 336
eva_large_patch14_196.in22k_ft_in22k_in1k | 88.592 | 98.656 | 304.14 | 196
eva02_base_patch14_448.mim_in22k_ft_in1k | 88.23 | 98.564 | 87.12 | 448
eva_large_patch14_196.in22k_ft_in1k | 87.934 | 98.504 | 304.14 | 196
eva02_small_patch14_336.mim_in22k_ft_in1k | 85.74 | 97.614 | 22.13 | 336
eva02_tiny_patch14_336.mim_in22k_ft_in1k | 80.658 | 95.524 | 5.76 | 336
  • Multi-weight and HF hub for DeiT and MLP-Mixer based models

March 22, 2023

  • More weights pushed to HF hub along with multi-weight support, including: regnet.py, rexnet.py, byobnet.py, resnetv2.py, swin_transformer.py, swin_transformer_v2.py, swin_transformer_v2_cr.py
  • Swin Transformer models support feature extraction (NCHW feat maps for swinv2_cr_*, and NHWC for all others) and spatial embedding outputs.
  • FocalNet (from https://github.com/microsoft/FocalNet) models and weights added with significant refactoring, feature extraction, no fixed resolution / sizing constraint
  • RegNet weights increased with HF hub push, SWAG, SEER, and torchvision v2 weights. SEER is pretty poor w.r.t. performance for its model size, but possibly useful.
  • More ImageNet-12k pretrained and 1k fine-tuned timm weights:
    • rexnetr_200.sw_in12k_ft_in1k - 82.6 @ 224, 83.2 @ 288
    • rexnetr_300.sw_in12k_ft_in1k - 84.0 @ 224, 84.5 @ 288
    • regnety_120.sw_in12k_ft_in1k - 85.0 @ 224, 85.4 @ 288
    • regnety_160.lion_in12k_ft_in1k - 85.6 @ 224, 86.0 @ 288
    • regnety_160.sw_in12k_ft_in1k - 85.6 @ 224, 86.0 @ 288 (same accuracy as the SWAG PT + 1k FT but at much lower res, and well ahead of the SEER FT)
  • Model name deprecation + remapping functionality added (a milestone for bringing 0.8.x out of pre-release). Mappings being added...
  • Minor bug fixes and improvements.

Feb 26, 2023

  • Add ConvNeXt-XXLarge CLIP pretrained image tower weights for fine-tune & features (fine-tuning TBD) -- see model card
  • Update convnext_xxlarge default LayerNorm eps to 1e-5 (for CLIP weights, improved stability)
  • 0.8.15dev0

Feb 20, 2023

  • Add 320x320 convnext_large_mlp.clip_laion2b_ft_320 and convnext_large_mlp.clip_laion2b_ft_soup_320 CLIP image tower weights for features & fine-tune
  • 0.8.13dev0 pypi release for latest changes w/ move to huggingface org

Feb 16, 2023

  • safetensor checkpoint support added
  • Add ideas from 'Scaling Vision Transformers to 22 Billion Parameters' (https://arxiv.org/abs/2302.05442) -- qk norm, RmsNorm, parallel block
  • Add F.scaled_dot_product_attention support (PyTorch 2.0 only) to vit_*, vit_relpos*, coatnet / maxxvit (to start)
  • Lion optimizer (w/ multi-tensor option) added (https://arxiv.org/abs/2302.06675)
  • gradient checkpointing works with features_only=True

Introduction

PyTorch Image Models (timm) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aims to pull together a wide variety of SOTA models with the ability to reproduce ImageNet training results.

The work of many others is present here. I've tried to make sure all source material is acknowledged via links to github, arxiv papers, etc in the README, documentation, and code docstrings. Please let me know if I missed anything.

Features

Models

All model architecture families include variants with pretrained weights. There are specific model variants without any weights; that is NOT a bug. Help training new or better weights is always appreciated.

Optimizers

Included optimizers available via create_optimizer / create_optimizer_v2 factory methods:
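A minimal usage sketch of the create_optimizer_v2 factory (the optimizer name and hyper-parameters here are illustrative):

import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet50')
optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)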

Augmentations

Regularization

Other

Several (less common) features that I often utilize in my projects are included. Many of these additions are the reason why I maintain my own set of models, instead of using others' via PIP:

Results

Model validation results can be found in the results tables

Getting Started (Documentation)

The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.

Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide by Chris Hughes is an extensive blog post covering many aspects of timm in detail.

timmdocs is an alternate set of documentation for timm. A big thanks to Aman Arora for his efforts creating timmdocs.

paperswithcode is a good resource for browsing the models within timm.

Train, Validation, Inference Scripts

The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See documentation.

Awesome PyTorch Resources

One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and components here are listed below.

Object Detection, Instance and Semantic Segmentation

Computer Vision / Image Augmentation

Knowledge Distillation

Metric Learning

Training / Frameworks

Licenses

Code

The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue.

Pretrained Weights

So far all of the pretrained weights available here are pretrained on ImageNet with a select few that have some additional pretraining (see extra note below). ImageNet was released for non-commercial research purposes only (https://image-net.org/download). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.

Pretrained on more than ImageNet

Several weights included or references here were pretrained with proprietary datasets that I do not have access to. These include the Facebook WSL, SSL, SWSL ResNe(Xt) and the Google Noisy Student EfficientNet models. The Facebook models have an explicit non-commercial license (CC-BY-NC 4.0, https://github.com/facebookresearch/semi-supervised-ImageNet1K-models, https://github.com/facebookresearch/WSL-Images). The Google models do not appear to have any restriction beyond the Apache 2.0 license (and ImageNet concerns). In either case, you should contact Facebook or Google with any questions.

Citing

BibTeX

@misc{rw2019timm,
  author = {Ross Wightman},
  title = {PyTorch Image Models},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  doi = {10.5281/zenodo.4414861},
  howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
}

Latest DOI


pytorch-image-models's People

Contributors

a-r-r-o-w, alexander-soare, amaarora, benjaminbossan, bryant1410, chris-ha458, christophreich1996, contrastive, developer0hye, fffffgggg54, gcucurull, hankyul2, joao-alex-cunha, kaczmarj, kecsap, kozistr, kushajveersingh, laurent2916, lorenzbaraldi, mehtadushy, michalwols, morizin, mrt23, nateraw, okojoalg, rwightman, seefun, separius, yassineyousfi, yehuitang


pytorch-image-models's Issues

AdvProp implementation

https://arxiv.org/abs/1911.09665

In the paper, they propose calculating two losses: one for the forward pass with "clean" BN params, and another for the forward pass with adversarial BN params. Then they combine these two losses, and backprop through both BN paths at the same time (joint optimization).

Does the following look correct to you:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=2)
        self.bnC = nn.BatchNorm2d(32)   # BN params for "clean" examples
        self.bnA = nn.BatchNorm2d(32)   # BN params for adversarial examples
        self.relu = nn.ReLU()
        self.linear = nn.Linear(32*14*14, 10)

    def forward(self, x, clean=True):
        x = self.conv1(x)
        if clean:
            x = self.bnC(x)
        else:
            x = self.bnA(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)  # flatten before the classifier
        x = self.linear(x)
        return x

model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for i in range(1000):
    batchC, targetC = get_clean_batch()
    batchA, targetA = get_adv_batch()

    outputC = model(batchC, clean=True)
    outputA = model(batchA, clean=False)

    lossC = loss_fn(outputC, targetC)
    lossA = loss_fn(outputA, targetA)
    loss = lossC + lossA

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

If so, how would you propagate the clean argument to all the blocks, especially the ones that use nn.Sequential lists?

Is there some existing AdvProp code to look at?

Inference Error for EfficientNet-B0

Hi, I have trained efficientnet_b0 and the checkpoint has been produced. When I ran the inference.py using this command
python inference.py /home/usr/images/ --model efficientnet_b0 --checkpoint ../efficientnet_b0-224/checkpoint-13.pth.tar

I got the following error, any idea what's the reason behind this?

Traceback (most recent call last):
  File "inference.py", line 124, in <module>
    main()
  File "inference.py", line 68, in main
    checkpoint_path=args.checkpoint)
  File "/ssd/pytorch/pytorch-image-models/timm/models/factory.py", line 42, in create_model
    load_checkpoint(model, checkpoint_path)
  File "/ssd/pytorch/pytorch-image-models/timm/models/helpers.py", line 22, in load_checkpoint
    model.load_state_dict(new_state_dict)
  File "/home/ivan/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for GenEfficientNet:
	size mismatch for classifier.weight: copying a param with shape torch.Size([30, 1280]) from checkpoint, the shape in current model is torch.Size([1000, 1280]).
	size mismatch for classifier.bias: copying a param with shape torch.Size([30]) from checkpoint, the shape in current model is torch.Size([1000]).
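The shape mismatch in the error (30 vs 1000 classes) suggests the checkpoint was trained with 30 classes while the model is being created with the default 1000. A minimal sketch of loading such a checkpoint directly (num_classes and checkpoint_path usage as elsewhere in this thread; whether inference.py exposes a matching --num-classes flag is an assumption):

import timm

# create the model with the same class count the checkpoint was trained with
model = timm.create_model('efficientnet_b0', num_classes=30,
                          checkpoint_path='../efficientnet_b0-224/checkpoint-13.pth.tar')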

Could you please show the training hyperparameters of "mobilenetv3_100"?

Hey, I have tried to train mobilenetv3_large these days, but I only got Prec@1 73. I trained it for 150 epochs, using warmup and cosine decay. Pictures were preprocessed as in MoGA (https://github.com/xiaomi-automl/MoGA/blob/master/dataloader.py). I have seen your training hyperparameters for EfficientNet-B2 and other models in the README, but I have never seen your training hyperparameters for mobilenetv3_100. I want to know whether there are problems in my code or whether the training hyperparameters are not good enough. Could you please show me your training hyperparameters for "mobilenetv3_100"? I want to reproduce your experiment and plot the training and testing curves to find the reasons.
By the way, I think it would be convenient for experiments if you could add your training scripts to the repo. Thanks a lot!

Cutmix

clovaai/CutMix-PyTorch: Official Pytorch implementation of CutMix regularizer GitHub: https://github.com/clovaai/CutMix-PyTorch

Hi, I saw that you've been implementing mixup as an additional feature.
However, if a model trained with mixup is used as the backbone of an object detector, it seems the performance of the detector degrades.

Could you please consider cutmix in addition to mixup?

thanks!

About Training Epochs

Hi, I trained ResNet50 for 100 epochs and only got 76.98%, lower than your result of 78.470%.

Do you train all the models listed in the README for 100 epochs, or for other values?

Thank you very much.

Reproducibility since last commit

Hi.

I have been working with your repository and got several training results that more or less match the checkpoints you provided. When I pulled a few days ago, I tried to reproduce some training I did before and got very bad results. For example, training ImageNet with the same net and same parameters went from 78% accuracy to 50% accuracy. The commit that was working for me was 4748c6d, and here is the command for a training run that now fails to produce decent results (and that did produce very good results before):

python -u -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 ./train.py /data/imagenet/ -b=190 --model-ema --num-gpu=8 -j=16 --amp --model=gluon_resnet50_v1d --sync-bn --mixup=0.2 --sched=cosine

Any idea why?

Thanks

very slow training

Trying to train MNv3 and experiencing very slow training with one Tesla GPU. Any hints on how to improve performance?

param # issue

Hi, thanks for the great work!
When I finished efficientnet_b0 training, I found the model size is not the 5.29M in your list, but 41M. What is the reason?

Inference on single image (ImageNet)

Hello sir, I am testing EfficientNet-b0 (mine trained from scratch and your pretrained) on random images as well as images within the ImageNet dataset and getting very poor results.

import math
import torch
from PIL import Image
from torchvision import transforms
from timm.models import create_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

model = create_model(
        "efficientnet_b0",
        pretrained=False,
        num_classes=1000,
        in_chans=3,
        checkpoint_path='./output/quant/b0/model_best.pth')

checkpoint = torch.load("./output/quant/b0/model_best.pth", map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
#out = model(input)

image = Image.open("dog.JPEG")

img_size=224
interpolation='bilinear'
mean=IMAGENET_DEFAULT_MEAN
std=IMAGENET_DEFAULT_STD
crop_pct = 0.875
scale_size = int(math.floor(img_size / crop_pct))

transform = transforms.Compose([
        transforms.Resize(scale_size, interpolation=3),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
)

image_tensor = transform(image).float()
image_tensor = image_tensor.unsqueeze(0)

y = model(image_tensor)

calc_confidence = torch.nn.functional.softmax

conf = calc_confidence(y)
top5_val, top5_idx = conf.topk(5)
top5_idx = top5_idx.cpu().numpy()
for idx, conf_ in zip(top5_idx[0], top5_val[0]):
  print('{}: {:2f}%'.format(labels.cls_idx[idx], conf_ * 100))

where 'labels' is imported from:
https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a

pole: 0.160048%
hook, claw: 0.159218%
Norfolk terrier: 0.155365%
bucket, pail: 0.155316%
sunglasses, dark glasses, shades: 0.154541%

which is the result no matter which image I feed the network

I tried it on pre-trained mobilenetv2 and it had accurate results

my model @ 200 epochs is around 75.31% top1

Any help would be much appreciated, your repo is very nice!

create model error

When I call m = timm.create_model("efficientnet_b0", pretrained=True)
There is an error: RuntimeError: unexpected EOF, expected 72233 more bytes. The file might be corrupted.

I don't know how to fix it.
Please help me.

default training hyper-parameters

Hi,
Impressive work!
The train script contains a large combination of various hyper-parameter options.
However, there are different types of models, and many models are contained even within the efficientnet part alone. I wonder whether you trained the models with the default hyper-parameters. If not, do you plan to release model-specific hyper-parameters?
Thanks!

EfficientNet B1-B7 hyperparameters

First of all thanks for the fantastic code!

I am wondering if anyone has successfully reproduced (or come close to) the results for EfficientNet B1-B7? I am able to reproduce B0 with jiefengpeng's settings:
./distributed_train.sh 8 ../ImageNet/ --model efficientnet_b0 -b 256 --sched step --epochs 500 --decay-epochs 3 --decay-rate 0.963 --opt rmsproptf --opt-eps .001 -j 8 --warmup-epochs 5 --weight-decay 1e-5 --drop 0.2 --color-jitter .06 --model-ema --lr .128

The same settings (with adjusted drop rate) for B1 gave only 78.11% (with EMA enabled), compared to the 78.8% reported in the paper.

bug in mixup implementation

I think there is a bug in the mixup implementation:

input.mul_(lam).add_(1 - lam, input.flip(0))

I put lam=0 and got a black picture.
Although it is nicer to write it as a one-liner, the order of operations is wrong: input is scaled in place by lam before input.flip(0) is evaluated, so the flipped input effectively gets scaled by lam*(1-lam)... :)

It should be changed to a two-liner:

input_flipped = (1 - lam) * input.flip(0)
input.mul_(lam).add_(input_flipped)

Questions about self-trained results

Hi, thanks for making this awesome training code in pytorch.
I was trying to reproduce some results which can be reproduced using tensorflow, but I'm having a hard time when I use pytorch.

I noticed that this repository tried to minimize the difference between pytorch and tensorflow. There is a custom RMSPropTF, and even the data transformation looks similar to the official mnasnet tensorflow code, etc.
But I also noticed that you probably used a slightly different configuration to produce your results. For example, you used bilinear interpolation instead when you trained spnas_100.

I have two questions..!

  1. You tried to minimize the difference between pytorch and tensorflow, am I right?
  2. Why the different configuration? Can you let us know the configuration you used to produce the results?

SEResNet34 accuracy worse than ResNet34 ?

Hi Ross

Thanks for the great work in putting together this repository with trained weights and training techniques!

I have a quick question about the reported accuracy for SEResNet(34) vs ResNet(34). SEResNet in this case shows worse accuracy than ResNet. Is it because they were trained at different points in time with different hyperparameters and techniques, or do SEResNets, when pushed to their accuracy limits, perform worse than ResNets in general?

What is the value range of magnitude in auto-augment when MAX_LEVEL is set to 10?

Dear @rwightman, I have read the code for auto-augmentation and random-augmentation, and I noticed that MAX_LEVEL is set to 10, the same as Google's implementation. Also, in the Google implementation they say an optimal magnitude is often in [5, 30]. But in your implementation you clip the input magnitude to be at most MAX_LEVEL (magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range).

Could you give me some hints about why MAX_LEVEL is set to 10 while the recommended input magnitude range is [5, 30]? Thanks very much!

Jan 3, 2019?

Description of bug:
In README.md, the last item title which follows "Dec 30, 2019" reads as "Jan 3, 2019". I think the author meant "Jan 3, 2020".

Recreation of bug: trivial (read README.md).
Proposed solution: after review, change 2019 to 2020.

Trouble reproducing efficientnet_b0 training

I'm trying to reproduce B0, training from scratch. My setup is pytorch 1.1 with V100*8. My command is shown below without any functional change to your training script. The args are borrowed from your comment in other issues:

./distributed_train.sh 8 /tmp/ImageNet/ \
    --model efficientnet_b0 \
    --lr 0.0175 -b 32 \
    --drop 0.2 \
    --img-size 224 \
    --sched step --epochs 400 \
    --decay-epochs 2 --decay-rate 0.975 \
    --opt rmsproptf \
    -j 8 \
    --warmup-epochs 5 --warmup-lr 1e-6 \
    --weight-decay 1e-5 \
    --opt-eps .001 \
    --model-ema

It takes about 12 min per epoch (slow?). I'm currently at epoch 100, which achieves 73.85% top1 prec. It is still ~3 points below your reported number, and the learning curve doesn't seem to rise fast enough.

  • Is it normal to hit only 73.85% at epoch 100? How many epochs are necessary to get to 75% and 76%?
  • I'm using batch size 32 per GPU. Do you recommend larger batch size per GPU? I remember reading somewhere that batchnorm will degrade after you exceed a certain batch size (like 128 something). BN isn't always the larger the better.
  • I don't have good intuition on how to set model-ema-decay for EfficientNet. I'm just using the default 0.9998 for my 32x8 batch. Maybe this is causing the degradation?

Thank you so much!

Would you like to show the training details?

Hello! Recently I have trained MobileNetV3, but I cannot reach the accuracy reported in the original paper (mbv3-small, multiplier 1.0: ≈62% vs 67% in the original paper). Would you share your training scripts? Thank you very much!!!

“Self-training with Noisy Student improves ImageNet classification” Can it work?

【Google】Self-training with Noisy Student improves ImageNet classification
Qizhe Xie, Eduard Hovy, Minh-Thang Luong, Quoc V. Le
(Submitted on 11 Nov 2019)

https://arxiv.org/abs/1911.04252

We present a simple self-training method that achieves 87.4% top-1 accuracy on ImageNet, which is 1.0% better than the state-of-the-art model that requires 3.5B weakly labeled Instagram images. On robustness test sets, it improves ImageNet-A top-1 accuracy from 16.6% to 74.2%, reduces ImageNet-C mean corruption error from 45.7 to 31.2, and reduces ImageNet-P mean flip rate from 27.8 to 16.1.

Have you ever noticed this? Can it work? Thanks!

License

Where are you collecting the info about licenses of different models in this repo?

MobilenetV3

Thanks for your excellent work!
I follow ./distributed_train.sh 4 / --model mobilenetv3_100 -b 256 --sched step --epochs 500 --decay-epochs 3 --decay-rate 0.963 --opt rmsproptf --opt-eps .001 -j 8 --warmup-epochs 5 --weight-decay 1e-5 --drop 0.2 --color-jitter .06 --model-ema --lr .064, but the accuracy is only 74.7 and cannot reach 75.6. Is there a problem with the training strategy?

epoch,train_loss,eval_loss,eval_prec1,eval_prec5
0,6.807157278060913,6.90778008392334,0.102,0.5020000006103515
1,5.050369849571815,6.908198606567383,0.098,0.496
2,4.088481426239014,6.912384178161621,0.096,0.498
3,3.7333187139951267,6.933573886260986,0.1,0.502
4,3.6070617253963766,7.005352065582275,0.102,0.496
5,3.5379599057711086,7.2024485604858395,0.108,0.496
6,3.373521878169133,7.632997455596924,0.1,0.504
7,3.3250165627552914,7.853865611114502,0.1,0.502
8,3.2218695420485277,7.600986016540527,0.1,0.524
9,3.1613639776523295,7.4364893447875975,0.088,0.532
10,3.1045892605414758,7.2747615184021,0.104,0.53
11,3.0920635095009437,7.175860616149903,0.124,0.622
12,3.0443460574516883,7.003116539306641,0.362,1.274
13,3.0157195513065043,6.5055738218688965,1.7200000007629395,5.724000004730224
14,3.000295345599835,5.391144183197022,9.15200000213623,22.20800001098633
15,2.9705558189978967,4.057146546173096,23.916000005493164,45.93000009277344
16,2.9603280195823083,3.081855848464966,38.044000006103516,63.08399995605469
17,2.9342497220406165,2.479943684425354,47.724000078125,72.68999994384765
18,2.9035698725627017,2.116020347824097,53.991999978027344,78.07
19,2.897769176042997,1.889836706047058,58.117999931640625,81.31999996582032
20,2.900837540626526,1.7450967617034912,60.86000015136719,83.2879999609375
21,2.875648883672861,1.6505073560333252,62.64800001464844,84.47399998291016
22,2.8701863013781033,1.579269640159607,64.0640000390625,85.51599998046875
23,2.850243330001831,1.5308125699043273,65.10799995605468,86.19999995361329
24,2.8110759808466983,1.4939452957725525,65.85400005859375,86.60800013183594
25,2.8074400975153995,1.4652649970245362,66.44000008300782,86.92799997558593
26,2.8163699095065775,1.4398712505722047,66.9500000024414,87.27800015625
27,2.7757075749910793,1.4162097897148131,67.30999995117187,87.65000018066407
28,2.7732831789897037,1.3985680765724182,67.80599997558593,87.88800010253907
29,2.7760123564646793,1.382454013786316,68.16,88.07200007568359
30,2.7418163189521203,1.3684436791229249,68.40399991943359,88.30399986816406
31,2.7352854013442993,1.3570151011276246,68.6379999194336,88.45799999755859
32,2.7576055985230665,1.3498287722396851,68.74400004882813,88.55600007324219
33,2.738512919499324,1.3405399858856202,69.06200002197265,88.69600004638671
34,2.7094647884368896,1.3309602670669556,69.20999994628906,88.87200001953126
35,2.714123716721168,1.3231018839836122,69.33800010253906,88.94600001953125
36,2.6893110000170193,1.3152600993728638,69.55600002441406,89.01599996826172
37,2.6938162308472853,1.307082928905487,69.72399997314453,89.13199989013673
38,2.7122582563987145,1.302183748550415,69.874,89.21800009765624
39,2.670985139333285,1.2956605754852295,70.03000010253906,89.33200004638672
40,2.689751249093276,1.2902504335403442,70.20200002441406,89.41599999511719
41,2.6894300167377176,1.2844672778511048,70.28200002441406,89.47999996826172
42,2.6576323967713575,1.2795114562416077,70.38200010253907,89.54399994140626
43,2.6550380816826453,1.2755136382484435,70.46800004882813,89.62000004394531
44,2.6638874274033766,1.2723018176651002,70.52999994628907,89.64799999267578
45,2.6185947198134203,1.269956514415741,70.63400002441406,89.67000007080078
46,2.6414913947765646,1.2651101098251343,70.77999996826172,89.75200009521484
47,2.6292767708118143,1.262466173362732,70.83600001953126,89.77800006835938
48,2.603688643528865,1.25764315202713,70.98600009765624,89.89000006835937
49,2.611600023049575,1.2532284469985961,71.07400001953125,89.95400006835938
50,2.6265424214876614,1.246946288356781,71.17600009765626,89.96800014648437
51,2.613306659918565,1.241887101173401,71.21999999511719,90.06800009521484
52,2.600176187661978,1.2374098636627198,71.23799999511719,90.05600009521484
53,2.612116914529067,1.2330633629989625,71.32600004882812,90.10200001708985
54,2.5721828570732703,1.2287508207130433,71.54600004882812,90.19800006835938
55,2.588179890926068,1.2252994434738158,71.49799997070312,90.18800011962891
56,2.5698808156527004,1.2236238661193848,71.60599991699219,90.1960001196289
57,2.5557694160021267,1.2217922224617004,71.61599994384765,90.31800001708984
58,2.5503780199931216,1.219379303073883,71.62000002197266,90.31399999023438
59,2.544643603838407,1.2171901927566529,71.64799996826171,90.30800006835938
60,2.545564761528602,1.21428450920105,71.75800001953125,90.35800006835937
61,2.5454268088707557,1.2108582588386536,71.83600004394532,90.44400011962891
62,2.535468871776874,1.2076356092453002,71.91200004394531,90.5280000415039
63,2.530643133016733,1.205882862701416,71.88999999267578,90.49400011962891
64,2.5182992770121646,1.2035164887809753,71.90000004394531,90.5880001196289
65,2.5161099892396193,1.2005548810005189,71.98999991210937,90.6680001196289
66,2.517294113452618,1.1971218403816224,72.06199991210937,90.69800017089844
67,2.525810195849492,1.194104630947113,72.10400001708985,90.7160001953125
68,2.4934457632211537,1.1917588474082947,72.19200007080079,90.72400009277344
69,2.4995968341827393,1.1903308367538452,72.24800007080079,90.79600009277344
70,2.4921426681371837,1.1873383280754088,72.34000004394531,90.76000001464844
71,2.4875874060850878,1.1848803549957276,72.31199993896485,90.81400009277344
72,2.469612662608807,1.1816198341560364,72.43400006835938,90.8440000415039
73,2.4733768609853892,1.1786696234703065,72.50200001708984,90.84800004150391
74,2.4827969257648173,1.1769603626060485,72.53800004394532,90.86600001464844
75,2.4532883075567393,1.1754435988807679,72.55600004394532,90.88600001464843
76,2.4508722837154684,1.1732804940223693,72.66400004394531,90.89800009277344
77,2.487632760634789,1.1717199251365662,72.78400004394531,90.89800006591797
78,2.4353009278957662,1.1698606922531127,72.83999993896484,90.89800001464843
79,2.449423203101525,1.1679184624290466,72.8640000415039,90.95400009277344
80,2.4576810139876146,1.165514492969513,72.93800004150391,91.02600014404297
81,2.4310616713303785,1.1630925346946717,72.92399993896484,91.05400004150391
82,2.4323965402749868,1.1604746030235291,73.0019999633789,91.01999996337891
83,2.427893170943627,1.1596575518226624,73.03199996337891,91.05600001464843
84,2.418949090517484,1.1574514235687257,73.10199999023438,91.1259999633789
85,2.448477396598229,1.157387068786621,73.0720001171875,91.14600009277343
86,2.43564812036661,1.156338175239563,73.10600001464844,91.1500000415039
87,2.3889316137020407,1.1543081414413452,73.01200006591797,91.14400014404296
88,2.3968455699773936,1.152047341156006,73.03000001464844,91.17399999023438
89,2.4199149792010965,1.1505577443695068,73.07999998779297,91.17199999023437
90,2.3994678350595326,1.148949045715332,73.09199998779297,91.1900001196289
91,2.3858356659228983,1.1472585159683228,73.17200006591797,91.17400017089844
92,2.391085487145644,1.1460050506210326,73.2660000390625,91.24800011962891
93,2.3828944242917576,1.1451176188087464,73.25599993652344,91.2320001196289
94,2.378520507078904,1.1441639073371888,73.29399991210937,91.22199996582032
95,2.3641012631929836,1.142308811416626,73.27399999023437,91.29000006835938
96,2.3812367549309363,1.1402403129577636,73.41999996337891,91.30400009277344
97,2.3666531581145067,1.1383883423042298,73.46799993652344,91.32200009277344
98,2.3729229982082662,1.1365342340087892,73.48400001464844,91.36400006835937
99,2.3416601419448853,1.1349524842834473,73.45599993652344,91.3520000415039
100,2.350464866711543,1.1353474979782106,73.49199993652344,91.33600004150391
101,2.3492721227499156,1.1336183923912049,73.42400006591797,91.39599996337891
102,2.346577369249784,1.1327397372817993,73.48600006591796,91.33600004150391
103,2.354083015368535,1.131389742641449,73.5359998852539,91.4120000415039
104,2.3353832960128784,1.130198454246521,73.63400001464844,91.4300000415039
105,2.337353046123798,1.1284184492111207,73.7200000390625,91.44399999023437
106,2.319014833523677,1.126461137008667,73.69999998779296,91.45400001708984
107,2.331433351223285,1.1259556946182252,73.7100000390625,91.54600001708984
108,2.323273943020747,1.1243394508743285,73.70999998779297,91.59199999023437
109,2.3238758398936343,1.1235428905677796,73.67399993652344,91.63800011962891
110,2.319566644155062,1.1221704779243469,73.74800001464844,91.59600004150391
111,2.309805604127737,1.1204168052482606,73.87600001464844,91.6220001171875
112,2.312829329417302,1.1194741576385498,73.90000001464844,91.6840001953125
113,2.308381887582632,1.1184360822486878,73.92400006591797,91.6540001953125
114,2.303919452887315,1.1174426330947875,73.98800001464843,91.6280001953125
115,2.2909882435431848,1.117208812561035,73.91000009277344,91.62200014404297
116,2.2770427740537205,1.1161277367782594,73.80999993896485,91.62000009277344
117,2.312555588208712,1.1152151550292968,73.85800006835937,91.63400017089843
118,2.2784791818031898,1.1140099111938477,73.9120000415039,91.66400009277343
119,2.266396265763503,1.1137916134643555,73.97600006835937,91.67599999023437
120,2.275199422469506,1.1129420568466186,74.00199999023438,91.72599999023437
121,2.2749333748450646,1.1128695734405518,73.94800001464844,91.68200011962891
122,2.2812217198885403,1.111792557373047,73.92600001708985,91.67800009277343
123,2.277574823452876,1.1100896430587768,73.99800006835937,91.6320000415039
124,2.2698056606146007,1.1084838445281981,74.00399988769531,91.71800011962891
125,2.261097632921659,1.107766972808838,73.94400004150391,91.75400011962891
126,2.2644664507645826,1.1068881179618835,73.95600004150391,91.76600014648437
127,2.2546519132760854,1.1059809099388123,74.02399993652344,91.75400006835937
128,2.2532090498850894,1.105235412940979,74.08400009033203,91.70200009521484
129,2.24976790868319,1.1046096514320374,74.10199998779296,91.71600012207031
130,2.2354619686420145,1.1038959642791748,74.1540000390625,91.75400014648437
131,2.2523352274527917,1.1033907950592041,74.17800009033203,91.77599991455078
132,2.239263653755188,1.103264320602417,74.2220000366211,91.77199999267579
133,2.22566556930542,1.102758490562439,74.20400001220703,91.75800004394532
134,2.232908698228689,1.1017799267196655,74.22200001220703,91.80600012207032
135,2.227965419109051,1.1017435648155212,74.1979999609375,91.88000004394532
136,2.242378445772024,1.1008160496711732,74.29199995849609,91.93800001708985
137,2.2392307886710534,1.0993753656768799,74.30600006347656,91.89399996582031
138,2.218605169883141,1.098401421394348,74.2200000390625,91.91599996582032
139,2.2219124023730936,1.0975180311393737,74.2359999609375,91.95200014648438
140,2.1995630172582774,1.0969582042503356,74.2279999609375,91.85600004394531
141,2.216558438081008,1.0965148325538636,74.24400006591797,91.86800006835938
142,2.201518801542429,1.0956385377311706,74.30400001464844,91.89000004150391
143,2.2095043384111843,1.0955083542251587,74.23999998779297,91.88399999023437
144,2.2081411435053897,1.0950895781898498,74.29399993652343,91.90599999023438
145,2.1888111371260424,1.0943855343437194,74.3400000390625,91.84200006835937
146,2.204651190684392,1.0940196996879579,74.29200001220703,91.82800011962891
147,2.185206422438988,1.0931081489562988,74.29200009033202,91.91400011962891
148,2.19277509359213,1.0935040406417846,74.35600001220703,91.9220001196289
149,2.179109261586116,1.0928431137084962,74.3300000390625,91.95200001708984
150,2.1944129191912136,1.0925009962081909,74.39800001220704,91.90999993896484
151,2.181222383792584,1.0922499877166747,74.40800003662109,91.91800001708984
152,2.179825791945824,1.0909117463302613,74.51600006347657,91.96999993896485
153,2.1657151534007144,1.0900434015655518,74.4980000366211,91.97399996582031
154,2.1426210586841288,1.089547190361023,74.4479999609375,91.98399993896484
155,2.16977203809298,1.089533913898468,74.46800014160156,92.02000009521484
156,2.151352671476511,1.089706090145111,74.46999998779297,92.03600009521485
157,2.1578336495619554,1.0895013568878174,74.4580000390625,92.04400001708984
158,2.1505370690272403,1.0896922289466857,74.4759999609375,92.01400001708984
159,2.150887581018301,1.0885500703620912,74.48600006347657,92.02400001708985
160,2.1470747544215274,1.0881524993133544,74.55000006347656,92.01600001708984
161,2.157308495961703,1.0868901877403259,74.58600009033204,91.9600001196289
162,2.1394928143574643,1.086728733253479,74.5359999609375,91.98000011962891
163,2.1365932409579935,1.0865377537155152,74.5459999609375,91.98999993896484
164,2.1314157247543335,1.0864816753578186,74.52200006347657,91.96400004150391
165,2.123403237416194,1.087101491317749,74.54399990966797,91.9900000415039
166,2.144001612296471,1.0865108477401733,74.5160000390625,92.0560001196289
167,2.1436057274158182,1.085620917701721,74.4820000390625,92.09200001708984
168,2.1267095437416663,1.0850480150794983,74.54200001220703,92.06600001708985
169,2.1371524975850034,1.084764951171875,74.56600001464844,92.04000001708984
170,2.138287452551035,1.0852635061454774,74.49799993896484,92.03600006835937
171,2.116690617341262,1.0854216492843627,74.48800006591797,92.00600006835937
172,2.125142271702106,1.085345158100128,74.4180001171875,91.99199993896484
173,2.118704007222102,1.0846420208358765,74.39799998779297,92.00199999023438
174,2.1178802160116343,1.0842044367408752,74.4540000390625,91.99799993896484
175,2.120506635079017,1.084026732978821,74.39200001464843,92.00599993896485
176,2.115511967585637,1.0842051233100891,74.44399990966797,92.00399999023438
177,2.1164698692468495,1.0847807112503052,74.4500000390625,91.97800009277344
178,2.103126571728633,1.0854538097000122,74.44000006591797,91.96800009277344
179,2.1043120439235983,1.0855156500625611,74.4320001171875,91.96400009277343
180,2.0988125067490797,1.0858051753425597,74.41800001464844,91.99800009277344
181,2.0941573473123403,1.0854854698944092,74.42000004150391,92.0160000415039
182,2.108525046935448,1.085402806892395,74.48400006591797,92.03600004150391
183,2.0855122896341176,1.0855355533981323,74.57599998779297,92.01599999023438
184,2.093753851377047,1.085414316482544,74.47999998779297,92.03599999023437
185,2.1076699678714457,1.085764733543396,74.46999998535156,92.02799999023438
186,2.0897318766667294,1.0858629461479188,74.4759999609375,92.04000009277344
187,2.088271892987765,1.0859300466156006,74.5220000390625,92.05200009277344
188,2.094233356989347,1.085611225719452,74.5320000390625,92.05000001464843
189,2.0826735496520996,1.0861849008369446,74.5179999609375,92.04000001464844
190,2.0652174812096815,1.0862134606552123,74.59600009033203,92.04200001464844
191,2.0876417343433085,1.0857337873649597,74.61399993408203,92.04199996337891
192,2.0855307762439432,1.08521177028656,74.5959999584961,92.08200001464844
193,2.076764560662783,1.0847806893730163,74.70800006103515,92.08200001464844
194,2.0756751665702233,1.0848990515518189,74.72200000732421,92.08000001464843
195,2.063154078446902,1.0851755967712402,74.67199998291015,92.07600001464844
196,2.060938651745136,1.085455788116455,74.68000006103516,92.06000001464844
197,2.076886846468999,1.0855616837120057,74.71399995849609,92.05800006591797
198,2.0626262105428257,1.085580121154785,74.7220000366211,92.03399991210938
199,2.0431891954862156,1.0855276605415345,74.6879999609375,92.01200001464844
200,2.0542377279354977,1.0852036755752563,74.64999998535156,92.02999996337891
201,2.0652991166481605,1.0852199787330628,74.6859999609375,92.04200006835937
202,2.068277808336111,1.0855673571968079,74.6999999609375,92.06999999023438
203,2.0548804585750284,1.0862246788978576,74.70600001220703,92.06599999023437
204,2.068242173928481,1.0868803758621215,74.7239999609375,92.00200006835938
205,2.034245403913351,1.086896453151703,74.6780000366211,92.05199999023438
206,2.0448381121342,1.0873380460357667,74.63800006347657,92.06600006835937
207,2.0357320721332846,1.0874692331695557,74.6379999609375,92.07999988769531
208,2.043750529105847,1.087085955581665,74.56000001220703,92.07200006835937
209,2.0323717777545633,1.086833318939209,74.60400003662109,92.06999988769532
210,2.0327078975163975,1.0864537406730652,74.67200000976563,92.05399988769531
211,2.0303796850717983,1.0863107000541687,74.58400003662109,92.07999999023437
212,2.0312073551691494,1.0868104565048218,74.55000011230469,92.02599993896484
213,2.030185791162344,1.0869530696868897,74.5580000366211,92.09800001708984
214,2.018157550921807,1.0868383379173279,74.57800001220703,92.13600001708984
215,2.0411546459564796,1.0867106021308899,74.5460001171875,92.12200001708985
216,2.024625686498789,1.0869791983032226,74.5640000390625,92.14600004150391
217,2.026153807456677,1.08763531791687,74.5459999609375,92.08600006835937
218,2.0231489768395057,1.0877682425689696,74.5980000390625,92.09399991455078
219,2.012903488599337,1.0878160471725464,74.56999998779297,92.04999991455078
220,2.0080551963586073,1.087732875518799,74.5380001171875,92.05800006835938
221,1.9911203659497774,1.087887629623413,74.5380000390625,92.06800006835938
222,2.0156690570024343,1.0878703461456298,74.5400001171875,92.04399999023437
223,2.0063899572079,1.0878751870918273,74.53599998779296,92.06799999023437
224,2.0017235783430247,1.0880411310195923,74.53000009277343,92.05999999023437
225,1.9885269036659827,1.088345354652405,74.54199999023437,92.05199999023438
226,1.9946725872846751,1.0889458667182923,74.50199993652343,92.05000017089844
227,1.987199292733119,1.0895495727729798,74.55000001464843,92.02199999023438
228,1.9774265885353088,1.090521994857788,74.55000001464843,92.0440000415039
229,2.0050827310635495,1.0905523964309691,74.51000001464844,91.98399999023438
230,1.978372940650353,1.090744952659607,74.55599993652343,91.98400001708984
231,1.9689390842731183,1.0916994673919678,74.57799985839844,92.01199996582031
232,1.9917400616865892,1.092270773563385,74.55200001464844,91.99799996582031
233,1.9931343564620385,1.0927853950691224,74.56199993652343,91.98599996582031
234,1.9902563049243047,1.0936042615127564,74.50999998779297,92.00399988769531
235,1.97390040067526,1.0940121513748169,74.49400001220702,92.00800006835938
236,1.9864924916854272,1.094450903892517,74.52599998535156,91.99600006835938
237,1.9699827891129713,1.0940869623374938,74.50999998535156,91.97600001708985
238,1.9734247831197886,1.0940573473167419,74.49999998535156,92.00199999023438
239,1.9648521680098314,1.0940183190727233,74.5419998828125,91.95999993896484
240,1.9831855618036711,1.093936259803772,74.56000008789063,91.95000001708985
241,1.9733638213231013,1.09378001745224,74.5460000366211,91.9639999658203
242,1.9813134945355928,1.0940869575881957,74.48400000976562,91.95599996582031
243,1.9581326429660504,1.094584370994568,74.51199998535157,91.99399996582031
244,1.9586645456460805,1.0950595213127137,74.48600006347657,92.02199993896484
245,1.9671411101634686,1.09523440826416,74.49800001220703,91.98999993896484
246,1.974704453578362,1.0956531619262695,74.46600006347656,91.95599988769531
247,1.9608415044271028,1.0956848133659363,74.4780000366211,91.96999988769531
248,1.9716106332265413,1.0959427178192138,74.55599998535156,91.95399988769532
249,1.956282982459435,1.0967141045379638,74.5380000366211,91.94399988769531
250,1.9584733889653132,1.0975952058410645,74.53399998535156,91.9159998876953
251,1.960865162886106,1.098129497718811,74.5759999584961,91.95999988769532
252,1.946923851966858,1.0987713743782044,74.52800006347657,91.93600006835938

How to export to ONNX?

Hi, how to export the trained checkpoint / model to ONNX? For example I have trained the efficientnet successfully and I want to convert this model to ONNX.

Is the code below the right way to do it? In particular, I'm confused about how to load the model definition; usually I only load models from torchvision.models.

import torch
import torch.onnx
import torchvision
import timm

device = 'cuda'
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=30)
model.load_state_dict(torch.load('model_best.pth.tar', map_location=device))
dummy_input = torch.randn(1, 3, 224, 224)
model.eval()
torch.onnx.export(model, dummy_input, "efficientnet.onnx")

With the above code I got the following error:
Unexpected key(s) in state_dict: "epoch", "arch", "state_dict", "optimizer", "args", "version", "metric".

If I add strict=False to model.load_state_dict the error goes away, but I'm not sure whether this is the right way to convert to ONNX. Any help would be appreciated, thanks.
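Based on the keys listed in the error, model_best.pth.tar is a full training checkpoint with the weights nested under 'state_dict' rather than a bare state dict, so one possible sketch (not necessarily the only approach) is to unwrap it before export:

import torch
import timm

model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=30)
checkpoint = torch.load('model_best.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])  # weights live under 'state_dict'
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "efficientnet.onnx")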

No provisioning to run on a CPU

I know nobody would want to run image models on a CPU, but I wanted to edit the architecture locally before deploying it to my cloud, and it appears to fail:

seresnet = timm.create_model(model_name="seresnet34", pretrained=True, num_classes=4)

Fails with:

C:\Users\makul\AppData\Local\Continuum\anaconda3\python.exe C:/Users/makul/PycharmProjects/kaggle/edge/work/seresnet/seresnet.py
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth" to C:\Users\makul/.cache\torch\checkpoints\seresnet34-a4004e63.pth
Traceback (most recent call last):
  File "C:/Users/makul/PycharmProjects/kaggle/edge/work/seresnet/seresnet.py", line 34, in <module>
    model = create_resnet_with_bottleneck()
  File "C:/Users/makul/PycharmProjects/kaggle/edge/work/seresnet/seresnet.py", line 15, in create_resnet_with_bottleneck
    seresnet = timm.create_model(model_name="seresnet34", pretrained=True, num_classes=4)
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\timm\models\factory.py", line 37, in create_model
    model = create_fn(**margs, **kwargs)
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\timm\models\senet.py", line 423, in seresnet34
    load_pretrained(model, default_cfg, num_classes, in_chans)
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\timm\models\helpers.py", line 65, in load_pretrained
    state_dict = model_zoo.load_url(default_cfg['url'], progress=False)
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\hub.py", line 463, in load_state_dict_from_url
    return torch.load(cached_file, map_location=map_location)
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\serialization.py", line 386, in load
    return _load(f, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\serialization.py", line 573, in _load
    result = unpickler.load()
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\serialization.py", line 536, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\serialization.py", line 119, in default_restore_location
    result = fn(storage, location)
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\serialization.py", line 95, in _cuda_deserialize
    device = validate_cuda_device(location)
  File "C:\Users\makul\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\serialization.py", line 79, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

Process finished with exit code 1

Maybe a parameter to pass in the device location would fix it
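
One possible workaround, sketched from the error message rather than any option the repo documented at the time: build the model without pretrained weights, fetch the checkpoint with an explicit CPU map_location, and load it manually. The 'last_linear' key names below are a guess for this architecture; adjust them to whatever a size-mismatch error reports.

import timm
import torch.hub

model = timm.create_model('seresnet34', pretrained=False, num_classes=4)
url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'
state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', progress=False)
# drop the 1000-class ImageNet head, since this model was built with 4 classes
state_dict.pop('last_linear.weight', None)
state_dict.pop('last_linear.bias', None)
model.load_state_dict(state_dict, strict=False)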

Why is the params size of the b0 model 20M rather than 5.3M?

Hi, @rwightman ~
Thank you for sharing your code.
But I have a question. The paper says the parameter count of the efficientnet-b0 model is 5.29M, but the size of the pretrained model you offer is 20.3 MB,
and when I test it with torchsummary it also shows Params size (MB): 20.17. Why?


# torchsummary code:
import timm
from torchsummary import summary

device = 'cuda'
modelBody = timm.create_model('efficientnet_b0', pretrained=True).to(device)
summary(modelBody, input_size=(3, 224, 224))


#result
Total params: 5,288,548
Trainable params: 5,288,548
Non-trainable params: 0
Input size (MB): 0.57
Forward/backward pass size (MB): 124.93
Params size (MB): 20.17
Estimated Total Size (MB): 145.68
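
For what it's worth, a back-of-the-envelope check of my own (not something stated in the repo): the two numbers describe the same model, they are just different units, a parameter count versus a size in megabytes.

params = 5_288_548           # parameter count reported by torchsummary
print(params * 4 / 1024**2)  # float32 = 4 bytes each -> ~20.17 MiB, matching "Params size (MB): 20.17"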


can you share hyper-parameters for mixnet?

Thanks for your impressive work!
Can you share your hyperparameters for mixnet-m (77.256% top-1 reported here)?
Google said they used the same tricks as MnasNet, but I cannot reproduce that high a top-1 accuracy following their settings.

No improvement in accuracy !!

result:
epoch,train_loss,eval_loss,eval_prec1,eval_prec5

0,1.236780494476995,0.7422786220911346,76.72929120409906,93.21093082835183

1,0.5891604418683256,0.32295734967867806,89.58155423367155,99.23142613151153

2,0.49250254085940176,0.369856527328084,87.53202391118703,98.59094790777114

3,0.5022595248670659,0.30262660618830095,90.13663535439795,98.932536293766

4,0.5062575871618386,0.42040570814521977,86.5072587532024,98.6336464560205

5,0.5522724993717976,0.47850830970693103,82.83518360375747,94.91887275832622

6,0.5353751410276462,0.47170425576088676,82.87788215200683,95.21776259607174

7,0.5296994199610164,0.4946806381061279,80.7002561912895,95.68744662681469

8,0.5263250684126829,0.5145736786669691,82.40819812126388,94.53458582408199

9,0.523469217465474,0.419988414003847,84.58582408198122,96.199829205807

10,0.5286529584572865,0.43731096309286016,84.03074295473954,96.24252775405637

11,0.5198398996367414,0.47687894446945517,83.04867635151955,95.90093936806149

12,0.5214749731823929,0.49508275170544086,81.55422715627668,95.60204953031597

13,0.5200576116386642,0.5131631942575656,80.4867634500427,94.53458582408199

14,0.5269585337139603,0.4538466817771747,83.26216908625106,96.79760888129803

15,0.51254571291626,0.4498334223075263,82.87788215200683,97.22459436379162

16,0.5205180874747088,0.43391486263041,84.67122117847993,97.18189581554228

17,0.5241914974573331,0.44768022860374135,82.57899231426131,96.75491033304867

18,0.5190987390075993,0.3754906276675191,86.76345004269855,98.37745516652434

19,0.52161620198152,0.5598128839601929,77.36976943435474,94.36379163759983

20,0.5262750121263358,0.4470690688822994,83.94534585824081,97.43808710503843

21,0.5272856320580865,0.4441590788245303,82.96327924850556,96.6268146883006

22,0.5244932282684196,0.4157507783589355,85.05550811272417,97.95046968403074

23,0.5278004272371276,0.4044609047427939,85.95217762596072,98.33475661827498

24,0.5286687795423035,0.42453585003910665,85.39709649871904,97.65157984628523

25,0.5267966113284103,0.38768093757503164,86.46456021798359,97.18189581554228

26,0.5280775483602133,0.4554431226063141,84.75661827497865,97.43808710503843

27,0.516062701639966,0.37677112987136147,86.63535439795046,97.65157984628523

28,0.5188340039844187,0.4685350774344372,82.19470538653236,95.90093936806149

29,0.5255055600761348,0.43660790125189636,85.52519214346712,96.92570452604612

30,0.46025000082121953,0.3342584821973775,88.55678906917164,98.88983774551666

31,0.42857470713619494,0.3014380550035852,90.35012809564475,99.23142613151153

32,0.4174317510209532,0.3104757423791796,89.8804440649018,98.932536293766

33,0.41206431961976564,0.32924012689637283,89.36806148590948,98.29205807002562

34,0.4107723833913477,0.2913964033432235,90.35012809564475,99.18872758326216

35,0.4065763853554033,0.3079596189392315,89.83774551665243,98.88983774551666

36,0.3994874527821174,0.3052167390194029,90.13663535439795,99.06063193851409

37,0.40214641817614566,0.30504225219268255,90.13663535439795,98.97523484201537

38,0.3973450698913672,0.2961104305049072,90.43552519214347,98.932536293766

39,0.39850940541324453,0.3157652917768994,89.15456874466268,98.37745516652434

40,0.38718086768928756,0.30122218622701175,89.75234842015371,98.71904355251921

41,0.3897520958358406,0.3065028481700805,89.79504696840307,99.06063193851409

42,0.38706244617445856,0.32017539155269464,89.53885567890691,98.33475661827498

43,0.3883484112910735,0.31547509009745467,89.4534585824082,98.80444064901793

44,0.38504144256441003,0.30608306833432336,89.49615713065755,98.8471391972673

45,0.38138950963815055,0.32615280828345644,88.85567890691716,98.33475661827498

46,0.3795166269581542,0.303617345068976,89.4534585824082,98.67634500426985

47,0.3816058428369017,0.2970176398766845,90.26473099914602,99.23142613151153

48,0.3786537225684549,0.28595576501498354,90.39282664389411,99.10333048676345

49,0.3790646289403622,0.3038425553465989,89.92314261315116,98.67634500426985

50,0.3804647622964321,0.29870414616815255,89.79504696840307,98.76174210076857

51,0.37879164590794817,0.2824969822008013,90.43552519214347,98.6336464560205

52,0.38108547449621377,0.2989366342568886,90.00853970964987,98.8471391972673

53,0.37669088766615616,0.3307387924965481,88.89837745516652,98.37745516652434

54,0.3788691380849251,0.30334060391638645,90.30742954739539,98.4201537147737

55,0.3772211962021314,0.30208189246570016,90.09393680614859,99.10333048676345

56,0.3792888462288767,0.3121899732417983,89.75234842015371,99.01793339026473

57,0.3719124208148728,0.3037089667228121,89.41076003415884,98.33475661827498

58,0.37400464902066777,0.3048031748957251,89.92314261315116,99.27412467976089

59,0.37563188559988625,0.3008926833966265,89.49615713065755,99.31682322801025

60,0.3702061493936767,0.296091295829652,90.39282664389411,99.23142613151153

61,0.36369912502093193,0.30105791837740514,90.39282664389411,99.18872758326216

62,0.3641057296950593,0.30164378056864166,90.17933390264731,98.88983774551666

63,0.3619639981760938,0.3019152488861707,89.58155422715627,98.97523484201537

64,0.3580131956654736,0.3011659981792651,89.4534585824082,98.80444064901793

65,0.3562619970395015,0.29155933362591685,90.05123825789923,98.97523484201537

66,0.3579206373447027,0.29480031828128,90.43552519214347,98.932536293766

67,0.35348330667385686,0.3068652048159827,89.70964987190436,98.54824935952178

68,0.356836290186287,0.2988827384238463,90.26473099914602,98.80444064901793

69,0.35784130547291193,0.30570485558675764,89.58155422715627,99.01793339026473

70,0.35317992988305213,0.30819127883008157,89.62425277540564,98.88983774551666

71,0.3552428407292081,0.2963386371948086,89.79504696840307,99.10333048676345

72,0.3556893689255429,0.28708859375279544,90.43552519214347,99.06063193851409

73,0.3554823259003142,0.29524144497210913,90.56362083689154,98.8471391972673

74,0.357293913507054,0.31387581278583515,89.49615713065755,98.59094790777114

75,0.35231152444313735,0.29306472874827916,90.52092228864218,98.6336464560205

76,0.3549654332491068,0.29314906997598206,90.56362083689154,98.80444064901793

77,0.3519476456519885,0.302172562334906,90.05123825789923,98.88983774551666

78,0.3533640999569852,0.29685002620738354,90.09393680614859,98.97523484201537

79,0.34736158908941805,0.299040100743229,90.09393680614859,98.932536293766

80,0.35200409822993806,0.28706298853142664,90.43552519214347,99.23142613151153

81,0.35231136908897986,0.3089747596929768,89.75234842015371,98.71904355251921

82,0.35180156401589385,0.30773886597304134,90.09393680614859,98.37745516652434

83,0.3532103667147139,0.300529144233511,90.13663535439795,98.8471391972673

84,0.34947473923874717,0.29519182510425857,90.13663535439795,99.06063193851409

85,0.3498597285941116,0.29852520767997004,90.43552519214347,98.67634500426985

86,0.3483942790163888,0.3120676435905578,89.53885567890691,98.4201537147737

87,0.3477894481685426,0.2945919472638304,90.05123825789923,98.932536293766

88,0.34972514459210585,0.3094837686427122,89.62425277540564,98.46285226302305

89,0.34978164026879854,0.30467455521111403,90.05123825789923,98.80444064901793

90,0.3490699131774087,0.30509472122999665,89.62425277540564,98.59094790777114

91,0.34640579913925923,0.30245392181672776,90.22203245089666,98.67634500426985

92,0.3502828779638323,0.3059284418350504,89.79504696840307,98.71904355251921

93,0.34789298828850446,0.3028878459721218,90.22203245089666,98.67634500426985

94,0.3426535091848455,0.30454785675192164,90.05123825789923,98.67634500426985

95,0.3456347489458883,0.2977035122038359,90.22203245089666,98.67634500426985

96,0.34746757136960316,0.30205830114072046,90.47822374039282,98.71904355251921

97,0.3415108760452678,0.29538114084508255,90.52092228864218,99.01793339026473

98,0.3463334205058905,0.30909938002553994,89.62425277540564,98.67634500426985

99,0.3493201870439399,0.31277492522149897,89.62425277540564,98.50555081127241

100,0.34843685874062724,0.3044554495931079,90.13663535439795,98.76174210076857

101,0.34774928485226425,0.3055654924067038,90.00853970964987,98.59094790777114

102,0.34850498577977856,0.29750282768506864,90.05123825789923,98.97523484201537

103,0.3478707997717409,0.29895274335380095,90.00853970964987,98.8471391972673

104,0.34697830931753176,0.3081327899781483,90.05123825789923,98.50555081127241

105,0.3482728877892861,0.29918750206900907,89.8804440649018,98.71904355251921

106,0.349467678508188,0.30765957653802445,89.92314261315116,98.6336464560205

107,0.3469078950392894,0.3077476313234292,89.96584116140052,98.4201537147737

108,0.3455473939068297,0.2998926101600788,90.00853970964987,98.6336464560205

109,0.3467843368522122,0.28927737134043524,90.43552519214347,98.932536293766

110,0.3482506826137885,0.30868624103033676,89.92314261315116,98.4201537147737

111,0.3472767389228201,0.3121946662295384,89.83774551665243,98.46285226302305

112,0.3478902287462838,0.3017470318979936,90.17933390264731,98.59094790777114

113,0.34667579292232154,0.29794618499446995,90.09393680614859,98.88983774551666

114,0.3480160056016384,0.30356935433324433,90.09393680614859,98.59094790777114

115,0.34320466332455984,0.3097499821694563,89.53885567890691,98.50555081127241

116,0.34662670671430407,0.29823985239641976,90.05123825789923,98.88983774551666

117,0.34547111611080983,0.30261922589511975,89.83774551665243,98.80444064901793

118,0.34876360625792774,0.30093205015848684,90.17933390264731,98.59094790777114

119,0.3450652132941108,0.3104037162549571,89.666951323655,98.54824935952178

120,0.34529478585618173,0.2984814216230895,90.30742954739539,98.59094790777114

121,0.34804036770111474,0.29763796744211224,90.26473099914602,98.8471391972673

122,0.34398792340205264,0.30257368388082306,89.79504696840307,98.67634500426985

123,0.3457088812803611,0.2971319091577086,90.17933390264731,98.67634500426985

124,0.3433658093723476,0.3044278937321153,90.13663535439795,98.37745516652434

125,0.3437653479922531,0.29761536593090054,90.26473099914602,98.54824935952178

126,0.3443275806995539,0.29704202331099927,90.05123825789923,98.67634500426985

127,0.341993944079448,0.30272179439243957,90.22203245089666,98.76174210076857

128,0.35011408637222063,0.3035676547546452,90.05123825789923,98.54824935952178

129,0.3462423269565289,0.3076208256674973,89.70964987190436,98.59094790777114

130,0.34525177134917334,0.31376688567290034,89.70964987190436,98.4201537147737

131,0.3470541987918381,0.3001238174870114,90.30742954739539,98.76174210076857

132,0.35105727365893175,0.30577701456800266,89.666951323655,98.76174210076857

133,0.34343119199968813,0.3013189581153639,89.92314261315116,98.71904355251921

134,0.34964901292935396,0.30014059984027375,90.13663535439795,98.67634500426985

135,0.3466490490314288,0.30473095387296734,89.8804440649018,98.50555081127241

136,0.34078803001305996,0.3015655089872806,90.30742954739539,98.71904355251921

137,0.3438590900256084,0.3039952846455533,89.8804440649018,98.59094790777114

138,0.3436705451235812,0.30031605075685314,90.22203245089666,98.71904355251921

139,0.34387522948603344,0.29580841515051004,90.26473099914602,98.76174210076857

140,0.3500936969732627,0.30117615942233245,90.30742954739539,98.80444064901793

141,0.3440368336235356,0.3023039072077566,90.00853970964987,98.67634500426985

142,0.3458517837982911,0.3063893628606523,89.666951323655,98.71904355251921

143,0.34306892755194607,0.3037028131413215,90.05123825789923,98.67634500426985

144,0.35013412891799567,0.29651374896681276,90.26473099914602,98.80444064901793

145,0.3439970703970673,0.30488468936933194,89.92314261315116,98.4201537147737

146,0.3462661385281473,0.2997296505967946,90.22203245089666,98.76174210076857

147,0.3492571253807117,0.3168051920761215,89.666951323655,98.46285226302305

148,0.35103419137816144,0.2990837080307133,89.92314261315116,98.50555081127241

149,0.3441524163015887,0.29870963920572696,90.09393680614859,98.59094790777114

config:
aa: null
amp: false
batch_size: 8
bn_eps: null
bn_momentum: null
bn_tf: false
color_jitter: 0.4
cooldown_epochs: 10
data: F:\dataset\test_data\person_clothes_image
decay_epochs: 30
decay_rate: 0.1
drop: 0.0
drop_connect: 0.0
epochs: 150
eval_metric: prec1
gp: avg
img_size: 224
initial_checkpoint: ''
interpolation: ''
local_rank: 0
log_interval: 50
lr: 0.1
mean: null
min_lr: 1.0e-05
mixup: 0.0
mixup_off_epoch: 0
model: mixnet_xl
model_ema: false
model_ema_decay: 0.9998
model_ema_force_cpu: false
momentum: 0.9
no_prefetcher: false
no_resume_opt: false
num_classes: 3
num_gpu: 1
opt: sgd
opt_eps: 1.0e-08
output: ''
pretrained: true
recount: 1
recovery_interval: 0
remode: pixel
reprob: 0.0
resume: ''
save_images: false
sched: step
seed: 42
smoothing: 0.1
start_epoch: null
std: null
sync_bn: false
tta: 0
warmup_epochs: 5
warmup_lr: 0.0001
weight_decay: 0.0001
workers: 4

about drop_connect_rate

first, thank you for the great work
in another implementation,i found:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate)

The drop_connect_rate is changed for each block, but I didn't find similar code in your implementation.
So what's going on?

about your results

I'm training efficientnet-b0 based on your code.
I wonder whether your results in the Self-trained Weights table are EMA validation numbers or not.

Thank you :)

Finetune Parameters

Hi rwightman,
First, many thanks for your excellent work!
I want to use your pretrained models for another image classification task; would you please give some advice on finetuning parameters?
Best,
hungsing

pretrained efficientNet

Thank you very much for your excellent work. I am very interested in this project. Could you please upload a pretrained model?

Train my own searched network

Thanks for your excellent work!
If I want to train a network architecture that I found through search and adopt the same training strategy as MnasNet, should the network definition be added to the gen_efficientnet.py file? Is there any difference in training between models defined in gen_efficientnet.py and in other files (such as resnet.py)? Are the BN layer parameters different?
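
For anyone with the same question, here is a minimal sketch of registering a custom architecture so that timm.create_model can find it by name; the class and entrypoint names below are made up, and the registry import path differs between timm versions (older releases use timm.models.registry, newer ones timm.models._registry):

import torch.nn as nn
from timm.models.registry import register_model

class MySearchedNet(nn.Module):
    # toy placeholder architecture, not a real searched network
    def __init__(self, num_classes=1000, in_chans=3):
        super().__init__()
        self.stem = nn.Conv2d(in_chans, 32, 3, stride=2, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.pool(x).flatten(1)
        return self.classifier(x)

@register_model
def my_searched_net(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    return MySearchedNet(num_classes=num_classes, in_chans=in_chans)

Once the module defining it has been imported, timm.create_model('my_searched_net') should work like the built-in names.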

Unknown model mixnet_xl

In [2]: model = timm.create_model("mixnet_xl", False)

RuntimeError Traceback (most recent call last)
in
----> 1 model = timm.create_model("mixnet_xl", False)

~/.pyenv/versions/3.7.3/lib/python3.7/site-packages/timm/models/factory.py in create_model(model_name, pretrained, num_classes, in_chans, checkpoint_path, **kwargs)
     37         model = create_fn(**margs, **kwargs)
     38     else:
---> 39         raise RuntimeError('Unknown model (%s)' % model_name)
     40
     41     if checkpoint_path:

RuntimeError: Unknown model (mixnet_xl)
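
A quick check worth doing (my suggestion, not from the thread): list what the installed timm release actually registers; mixnet_xl was added at some point and may simply not exist in an older version.

import timm
print(timm.__version__)
print(timm.list_models('mixnet*'))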

Some questions about validate.py; I test MixNet using validate.py

I use python validate.py /home/glp/Documents/project/pytorch-image-models master/data/Imagenet/ILSVRC2012_img_val --model mixnet_l --pretrained to test MixNet, but I find that MixNet's precision is far too low.
I don't know how to use it correctly.

Model mixnet_l created, param count: 7329252
Data processing configuration for current model + dataset:
input_size: (3, 224, 224)
interpolation: bicubic
mean: (0.485, 0.456, 0.406)
std: (0.229, 0.224, 0.225)
crop_pct: 0.875
THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/THC/THCGeneral.cpp line=663 error=11 : invalid argument
Test: [ 0/391] Time: 1.459s (1.459s, 87.71/s) Loss: 9.4615 (9.4615) Prec@1: 0.000 ( 0.000) Prec@5: 0.000 ( 0.000)
Test: [ 10/391] Time: 0.379s (0.478s, 267.70/s) Loss: 9.5713 (9.6075) Prec@1: 0.000 ( 0.000) Prec@5: 0.000 ( 0.497)
Test: [ 20/391] Time: 0.379s (0.431s, 296.85/s) Loss: 9.7116 (9.6221) Prec@1: 0.000 ( 0.037) Prec@5: 0.000 ( 0.298)
Test: [ 30/391] Time: 0.380s (0.415s, 308.77/s) Loss: 9.5678 (9.6172) Prec@1: 0.000 ( 0.025) Prec@5: 0.000 ( 0.328)
Test: [ 40/391] Time: 0.380s (0.406s, 315.22/s) Loss: 9.7633 (9.6187) Prec@1: 0.000 ( 0.019) Prec@5: 0.000 ( 0.362)
Test: [ 50/391] Time: 0.378s (0.401s, 319.35/s) Loss: 9.7678 (9.6187) Prec@1: 0.000 ( 0.015) Prec@5: 0.000 ( 0.337)
Test: [ 60/391] Time: 0.379s (0.397s, 322.18/s) Loss: 9.6133 (9.6177) Prec@1: 0.000 ( 0.038) Prec@5: 0.000 ( 0.320)
Test: [ 70/391] Time: 0.379s (0.395s, 324.19/s) Loss: 9.5668 (9.6153) Prec@1: 0.000 ( 0.033) Prec@5: 1.562 ( 0.319)
Test: [ 80/391] Time: 0.378s (0.393s, 325.75/s) Loss: 9.5332 (9.6139) Prec@1: 0.000 ( 0.077) Prec@5: 0.000 ( 0.338)
Test: [ 90/391] Time: 0.381s (0.392s, 326.91/s) Loss: 9.4446 (9.6138) Prec@1: 0.000 ( 0.077) Prec@5: 0.000 ( 0.335)
Test: [ 100/391] Time: 0.379s (0.390s, 327.82/s) Loss: 9.7374 (9.6134) Prec@1: 0.000 ( 0.085) Prec@5: 0.000 ( 0.356)
Test: [ 110/391] Time: 0.380s (0.390s, 328.57/s) Loss: 9.5693 (9.6104) Prec@1: 0.000 ( 0.084) Prec@5: 0.000 ( 0.359)
Test: [ 120/391] Time: 0.380s (0.389s, 329.20/s) Loss: 9.5217 (9.6107) Prec@1: 0.000 ( 0.077) Prec@5: 0.781 ( 0.349)
Test: [ 130/391] Time: 0.379s (0.388s, 329.75/s) Loss: 9.4445 (9.6107) Prec@1: 0.000 ( 0.078) Prec@5: 0.000 ( 0.340)
Test: [ 140/391] Time: 0.380s (0.388s, 330.20/s) Loss: 9.4536 (9.6082) Prec@1: 0.781 ( 0.094) Prec@5: 0.781 ( 0.360)
Test: [ 150/391] Time: 0.381s (0.387s, 330.61/s) Loss: 9.6358 (9.6083) Prec@1: 0.000 ( 0.093) Prec@5: 2.344 ( 0.367)
Test: [ 160/391] Time: 0.379s (0.387s, 330.97/s) Loss: 9.5947 (9.6086) Prec@1: 0.000 ( 0.087) Prec@5: 0.781 ( 0.369)
Test: [ 170/391] Time: 0.380s (0.386s, 331.27/s) Loss: 9.5871 (9.6093) Prec@1: 0.000 ( 0.087) Prec@5: 0.000 ( 0.365)
Test: [ 180/391] Time: 0.379s (0.386s, 331.56/s) Loss: 9.6898 (9.6061) Prec@1: 0.000 ( 0.082) Prec@5: 0.000 ( 0.358)
Test: [ 190/391] Time: 0.379s (0.386s, 331.83/s) Loss: 9.5595 (9.6042) Prec@1: 0.781 ( 0.090) Prec@5: 0.781 ( 0.352)
Test: [ 200/391] Time: 0.379s (0.385s, 332.04/s) Loss: 9.5919 (9.6055) Prec@1: 0.000 ( 0.093) Prec@5: 0.000 ( 0.354)
Test: [ 210/391] Time: 0.429s (0.386s, 331.52/s) Loss: 9.7134 (9.6037) Prec@1: 0.000 ( 0.096) Prec@5: 0.781 ( 0.352)
Test: [ 220/391] Time: 0.380s (0.387s, 331.06/s) Loss: 9.7624 (9.6055) Prec@1: 0.000 ( 0.095) Prec@5: 0.000 ( 0.354)
Test: [ 230/391] Time: 0.379s (0.386s, 331.30/s) Loss: 9.5760 (9.6066) Prec@1: 0.000 ( 0.091) Prec@5: 0.000 ( 0.345)
Test: [ 240/391] Time: 0.379s (0.386s, 331.53/s) Loss: 9.6780 (9.6072) Prec@1: 0.781 ( 0.094) Prec@5: 0.781 ( 0.350)
Test: [ 250/391] Time: 0.379s (0.386s, 331.64/s) Loss: 9.4611 (9.6067) Prec@1: 0.000 ( 0.096) Prec@5: 0.000 ( 0.352)
Test: [ 260/391] Time: 0.379s (0.386s, 331.84/s) Loss: 9.6792 (9.6064) Prec@1: 0.000 ( 0.093) Prec@5: 0.781 ( 0.356)
Test: [ 270/391] Time: 0.379s (0.386s, 332.02/s) Loss: 9.6394 (9.6080) Prec@1: 0.781 ( 0.092) Prec@5: 0.781 ( 0.363)
Test: [ 280/391] Time: 0.412s (0.386s, 331.67/s) Loss: 9.5803 (9.6065) Prec@1: 0.000 ( 0.089) Prec@5: 0.781 ( 0.356)
Test: [ 290/391] Time: 0.379s (0.386s, 331.34/s) Loss: 9.5775 (9.6080) Prec@1: 0.000 ( 0.086) Prec@5: 0.000 ( 0.344)
Test: [ 300/391] Time: 0.379s (0.387s, 330.92/s) Loss: 9.6315 (9.6082) Prec@1: 0.000 ( 0.093) Prec@5: 1.562 ( 0.361)
Test: [ 310/391] Time: 0.379s (0.387s, 331.11/s) Loss: 9.5072 (9.6100) Prec@1: 0.000 ( 0.090) Prec@5: 0.000 ( 0.352)
Test: [ 320/391] Time: 0.379s (0.386s, 331.30/s) Loss: 9.5217 (9.6100) Prec@1: 0.000 ( 0.095) Prec@5: 0.781 ( 0.358)
Test: [ 330/391] Time: 0.391s (0.387s, 330.96/s) Loss: 9.4642 (9.6092) Prec@1: 0.000 ( 0.094) Prec@5: 0.000 ( 0.359)
Test: [ 340/391] Time: 0.379s (0.387s, 330.57/s) Loss: 9.7280 (9.6106) Prec@1: 0.000 ( 0.092) Prec@5: 0.000 ( 0.351)
Test: [ 350/391] Time: 0.379s (0.387s, 330.74/s) Loss: 9.7150 (9.6103) Prec@1: 0.000 ( 0.089) Prec@5: 0.000 ( 0.354)
Test: [ 360/391] Time: 0.380s (0.387s, 330.91/s) Loss: 9.6241 (9.6089) Prec@1: 0.000 ( 0.091) Prec@5: 0.000 ( 0.359)
Test: [ 370/391] Time: 0.380s (0.387s, 331.08/s) Loss: 9.5063 (9.6062) Prec@1: 0.000 ( 0.093) Prec@5: 1.562 ( 0.375)
Test: [ 380/391] Time: 0.379s (0.386s, 331.22/s) Loss: 9.6496 (9.6070) Prec@1: 0.000 ( 0.092) Prec@5: 0.781 ( 0.369)
Test: [ 390/391] Time: 0.234s (0.386s, 207.32/s) Loss: 9.6711 (9.6065) Prec@1: 0.000 ( 0.092) Prec@5: 0.000 ( 0.368)

 * Prec@1 0.092 (99.908) Prec@5 0.368 (99.632)

Inference error

Hi, I trained on my own dataset using EfficientNet-B4, with the command below:
python inference.py data/custom/test/ --model efficientnet_b4 --img-size 300 -b 32 -j 4 --num-classes 2 --checkpoint output/model_best_b4.pth.tar
The error was:

Traceback (most recent call last):
  File "inference.py", line 122, in <module>
    main()
  File "inference.py", line 114, in main
    filenames = loader.dataset.filenames()
AttributeError: 'PrefetchLoader' object has no attribute 'dataset'

How to fix it? Thanks a lot!
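
A hedged workaround sketch rather than a proper fix: if the PrefetchLoader wrapper keeps the underlying torch DataLoader as a .loader attribute (an assumption about the timm version in use), the failing line in inference.py could reach the dataset through it:

# inference.py, around line 114; the .loader attribute name is an assumption
filenames = loader.loader.dataset.filenames()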

Parameters used to get 94.714 top5 with efficientnet_b2?

Hey! I'm a researcher at OpenAI looking into trends in compute used by models. I'm excited to find this repo, since it's the only one with EfficientNet that claims to approximately reproduce the original performance.

I've got two runs going on machines with 8 P100's

./distributed_train.sh 8 /tmp/imagenet-extracted/ --model efficientnet_b0 --lr 0.035 -b 64 --drop 0.2 --img-size 224 --sched step --epochs 550 --decay-epochs 2 --decay-rate 0.975 --opt rmsproptf -j 8 --warmup-epochs 5 --warmup-lr 1e-6 --weight-decay 1e-5 --opt-eps .001 --model-ema

./distributed_train.sh 8 /tmp/imagenet-extracted/ --model efficientnet_b2 --lr 0.0175 -b 32 --drop 0.2 --img-size 224 --sched step --epochs 550 --decay-epochs 2 --decay-rate 0.975 --opt rmsproptf -j 8 --warmup-epochs 5 --warmup-lr 1e-6 --weight-decay 1e-5 --opt-eps .001 --model-ema

The only change I made from the parameters recommended here was scaling the learning rate you used (.27) based on the difference in batch size.

I'd be very interested in what specific learning rate plus other hyperparameters you used in your efficientnet-b2 run referenced in the ReadMe, and what has worked out best in b0 runs, since the learning rate above was given for a model family rather than b0 specifically.
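
For reference, my own quick arithmetic (assuming linear learning-rate scaling with global batch size, which is not a rule stated in the repo): the two commands above keep the same learning rate per sample.

print(0.035 / (8 * 64))   # b0: 8 GPUs x batch 64 -> ~6.8e-05 per sample
print(0.0175 / (8 * 32))  # b2: 8 GPUs x batch 32 -> ~6.8e-05 per sample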

Do you have better-performing reproductions? Models for resnext_wsl and EfficientNet AdvProp (b0-b8), thank you!

First of all, thank you for your excellent and complete code!

Do you have models for ResNeXt_WSL? Thank you!

'resnext101_32x32d':
'resnext101_32x48d':

ResNeXt home page:
https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/

GitHub project:
https://github.com/facebookresearch/WSL-Images/blob/master/hubconf.py

