
pytorch-pretrained-vit's Introduction

ViT PyTorch

Quickstart

Install with pip install pytorch_pretrained_vit and load a pretrained ViT with:

from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

Or find a Google Colab example here.

Overview

This repository contains an op-for-op PyTorch reimplementation of the Vision Transformer (ViT) architecture from Google, along with pretrained models and examples.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects.

At the moment, you can easily:

  • Load pretrained ViT models
  • Evaluate on ImageNet or your own data
  • Finetune ViT on your own dataset

Upcoming features:

  • Train ViT from scratch on ImageNet (1K)
  • Export to ONNX for efficient inference

Table of contents

  1. About ViT
  2. About ViT-PyTorch
  3. Installation
  4. Usage
  5. Contributing

About ViT

Vision Transformers (ViT) are a straightforward application of the transformer architecture to image classification. Even in computer vision, it seems, attention is all you need.

The ViT architecture works as follows: (1) it splits an image into fixed-size patches, linearly embeds each patch, and treats the result as a 1-dimensional sequence, (2) it prepends a learnable classification token to the sequence and adds position embeddings, (3) it passes this sequence through a transformer encoder (like BERT), and (4) it passes the first token of the encoder output through a small MLP to obtain the classification logits. ViT is pretrained on a large-scale dataset (ImageNet-21k) with a large amount of compute.
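The four steps above can be sketched in a few lines of PyTorch. This is a hypothetical, deliberately small model (`TinyViT` is not part of this package): a strided convolution performs the patch split and linear embedding in one operation, `nn.TransformerEncoder` stands in for the encoder, and a single linear layer stands in for the small MLP head.

```python
import torch
from torch import nn


class TinyViT(nn.Module):
    """Illustrative ViT forward pass (hypothetical class, not the library's ViT).

    patchify + linear embed -> prepend class token -> add position
    embeddings -> transformer encoder -> classify the first (class) token.
    """

    def __init__(self, image_size=32, patch_size=8, in_channels=3,
                 dim=64, num_heads=4, num_layers=2, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size cuts non-overlapping
        # patches and applies the same linear projection to each one.
        self.patch_embedding = nn.Conv2d(in_channels, dim,
                                         kernel_size=patch_size, stride=patch_size)
        self.class_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        nn.init.normal_(self.pos_embedding, std=0.02)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads,
                                           dim_feedforward=4 * dim,
                                           batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)  # stand-in for the small MLP

    def forward(self, x):
        x = self.patch_embedding(x)               # (b, d, gh, gw)
        x = x.flatten(2).transpose(1, 2)          # (b, gh*gw, d): 1-D sequence
        cls = self.class_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)            # (b, gh*gw+1, d)
        x = x + self.pos_embedding                # learned 1-D position embeddings
        x = self.encoder(x)                       # (b, gh*gw+1, d)
        return self.head(self.norm(x[:, 0]))      # logits from the class token
```

With the defaults, a 32x32 image becomes a 4x4 grid of 8x8 patches, i.e. 16 patch tokens plus one class token.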

About ViT-PyTorch

ViT-PyTorch is a PyTorch reimplementation of ViT. It mirrors the original JAX implementation, which makes it easy to load JAX-pretrained weights.

At the same time, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.

Installation

Install with pip:

pip install pytorch_pretrained_vit

Or from source:

git clone https://github.com/lukemelas/ViT-PyTorch
cd ViT-PyTorch
pip install -e .

Usage

Loading pretrained models

Loading a pretrained model is easy:

from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

Details about the models are below:

Name              Pretrained on   Finetuned on   Available?
B_16              ImageNet-21k    -              yes
B_32              ImageNet-21k    -              yes
L_16              ImageNet-21k    -              no
L_32              ImageNet-21k    -              yes
B_16_imagenet1k   ImageNet-21k    ImageNet-1k    yes
B_32_imagenet1k   ImageNet-21k    ImageNet-1k    yes
L_16_imagenet1k   ImageNet-21k    ImageNet-1k    yes
L_32_imagenet1k   ImageNet-21k    ImageNet-1k    yes
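To relate the names in the table to the architecture: the letter is the model size (B = Base, L = Large) and the number is the patch size in pixels. The sketch below is a hypothetical helper, not part of the package; its hyperparameters are those listed for ViT-Base and ViT-Large in Table 1 of the ViT paper (Dosovitskiy et al.). It also computes the token sequence length, which depends on the input image size.

```python
# Hyperparameters of ViT-Base and ViT-Large from Table 1 of the ViT paper.
# Illustrative only: this dict is not read from pytorch_pretrained_vit.
VIT_VARIANTS = {
    "B": dict(num_layers=12, dim=768, ff_dim=3072, num_heads=12),
    "L": dict(num_layers=24, dim=1024, ff_dim=4096, num_heads=16),
}


def describe(name: str, image_size: int) -> dict:
    """Hyperparameters and sequence length for a name such as 'B_16_imagenet1k'."""
    variant, patch = name.split("_")[:2]
    patch = int(patch)
    if image_size % patch:
        raise ValueError(f"image_size {image_size} is not divisible by patch size {patch}")
    grid = image_size // patch  # patches per side
    return dict(VIT_VARIANTS[variant], patch_size=patch, grid=grid,
                seq_len=grid * grid + 1)  # +1 for the class token
```

For example, at 384x384 a B_16 model sees a 24x24 grid, i.e. 576 patch tokens plus the class token, 577 tokens in total. A different input size changes this length, which is why the positional embedding has to be resized when `image_size` differs from the pretrained size.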

Custom ViT

Loading custom configurations is just as easy:

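from pytorch_pretrained_vit import ViT
# A smaller custom configuration. It is NOT equivalent to ViT('B_16'),
# which uses dim=768, ff_dim=3072, num_heads=12, num_layers=12.
# Assumption: hyperparameters are constructor keyword arguments
# (ViT.from_config is not available in the released package).
model = ViT(dim=512, ff_dim=2048, num_heads=8, num_layers=6,
            image_size=224, num_classes=1000)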

Example: Classification

Below is a simple, complete example. It may also be found as a Jupyter notebook in examples/simple or as a Colab Notebook.

from PIL import Image
import torch
from torchvision import transforms

# Load ViT
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
model.eval()

# Load image
# NOTE: Assumes an image `img.jpg` exists in the current directory
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])(Image.open('img.jpg').convert('RGB')).unsqueeze(0)  # force 3 channels
print(img.shape)  # torch.Size([1, 3, 384, 384])

# Classify
with torch.no_grad():
    outputs = model(img)
print(outputs.shape)  # torch.Size([1, 1000])
print(outputs.argmax(dim=-1).item())  # index of the predicted ImageNet class
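To turn `outputs` into readable predictions, the logits can be converted to probabilities and the top classes picked. A minimal sketch; `top_k_predictions` is a hypothetical helper, and `labels` (a list of class names indexed by class id) is an assumed input, since the example above does not load a label file.

```python
import torch


def top_k_predictions(logits: torch.Tensor, labels: list, k: int = 5):
    """Return (label, probability) pairs for the k most likely classes.

    `logits` has shape (1, num_classes), like `outputs` in the example above.
    """
    probs = torch.softmax(logits, dim=-1).squeeze(0)  # logits -> probabilities
    values, indices = torch.topk(probs, k)            # sorted, highest first
    return [(labels[i], p) for i, p in zip(indices.tolist(), values.tolist())]
```

Usage: `for name, p in top_k_predictions(outputs, labels): print(f"{p:.3f} {name}")`.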

ImageNet

See examples/imagenet for details about evaluating on ImageNet.

Credit

Other great repositories with this model include:

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!

pytorch-pretrained-vit's People

Contributors

lukemelas


pytorch-pretrained-vit's Issues

cannot download pre-trained model

Hi!

When I try to download the pretrained model, I get the error:
'NoneType' object has no attribute 'group'

This is what I do:

from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

I'm using PyTorch 1.1. Does the version matter? It works on Google Colab but not on my machine.

B_16 return zeros

For some reason, the output of the B_16 model is all zeros. I tested B_32 and L_32, and they seem to work properly.

TypeError: 'tuple' object is not callable

Thanks for your great work! However, I get an error when I run the training code:
  File "main.py", line 427, in <module>
    main()
  File "main.py", line 111, in main
    main_worker(args.gpu, ngpus_per_node, args)
  File "main.py", line 134, in main_worker
    model = ViT(args.arch, pretrained=args.pretrained,image_size=args.image_size)
  File "/userhome/Transformer/PyTorch-Pretrained-ViT/pytorch_pretrained_vit/model.py", line 135, in __init__
    resize_positional_embedding=(image_size != pretrained_image_size),
  File "/userhome/Transformer/PyTorch-Pretrained-ViT/pytorch_pretrained_vit/utils.py", line 62, in load_pretrained_weights
    has_class_token=hasattr(model, 'class_token'))
  File "/userhome/Transformer/PyTorch-Pretrained-ViT/pytorch_pretrained_vit/utils.py", line 101, in resize_positional_embedding_
    posemb_grid = zoom(posemb_grid, zoom, order=1)
TypeError: 'tuple' object is not callable

This error appears when I try to finetune the pretrained model (B_16, 224x224) on ImageNet-1K (384x384). Thank you again for this work. Looking forward to your reply!
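The last frame of the traceback points to a name clash: at that point `zoom` is a tuple (presumably the zoom factors), which shadows the function `scipy.ndimage.zoom`, so `zoom(posemb_grid, zoom, order=1)` tries to call a tuple. Below is a standalone sketch of that resize step with the factors renamed. It is a reconstruction of the step the traceback names, not the repository's code, and it assumes a NumPy array of shape `(1, n_tokens, dim)` whose patch tokens form a square grid.

```python
import numpy as np
from scipy.ndimage import zoom


def resize_positional_embedding(posemb: np.ndarray, new_grid: int,
                                has_class_token: bool = True) -> np.ndarray:
    """Bilinearly resize a (1, n_tokens, dim) positional embedding to a new grid.

    Grid tokens are reshaped to (gs_old, gs_old, dim), resized with
    scipy.ndimage.zoom, and the class token (if any) is prepended again.
    """
    if has_class_token:
        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(np.sqrt(len(posemb_grid)))
    posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
    # Keep the factors in a differently named variable so the function
    # `zoom` is not shadowed (the cause of the TypeError above).
    zoom_factors = (new_grid / gs_old, new_grid / gs_old, 1)
    posemb_grid = zoom(posemb_grid, zoom_factors, order=1)  # bilinear
    posemb_grid = posemb_grid.reshape(1, new_grid * new_grid, -1)
    return np.concatenate([posemb_tok, posemb_grid], axis=1)
```

For B_16 going from 224x224 to 384x384 inputs, the grid grows from 14x14 to 24x24, so a `(1, 197, 768)` embedding becomes `(1, 577, 768)`.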

Load from Google's saved weights.

Hi, I was wondering if there is a way to load the weights from Google's saved checkpoint directly, instead of having to download them.

I see that the __init__ of ViT contains:

            load_pretrained_weights(
                self, name, 
                load_first_conv=(in_channels == pretrained_num_channels),
                load_fc=(num_classes == pretrained_num_classes),
                load_repr_layer=load_repr_layer,
                resize_positional_embedding=(image_size != pretrained_image_size),
            )

So a weights_path cannot be passed to ViT. Could this be amended?

Thank you for your help.

Questions about your code compared with the original code

Hi, I noticed the following.
Your code:

x = self.positional_embedding(x)  # b,gh*gw+1,d 
x = self.transformer(x)  # b,gh*gw+1,d

Vision Transformer (from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py):

x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)

Actually, there are two differences:

  1. you don't use dropout after positional_embedding
  2. in the original, the positional embedding is not applied to the classification token

Could you please tell me the reasons for these changes?
Looking forward to your reply, thanks very much.
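For comparison, the vit-pytorch variant quoted above (add the position embedding, then apply dropout) can be written as a small module. This is a hypothetical sketch, not this repository's PositionalEmbedding1D. Note that the quoted slice `self.pos_embedding[:, :(n + 1)]` covers n + 1 positions, that is, the n patch tokens plus the prepended class token.

```python
import torch
from torch import nn


class PositionalEmbeddingWithDropout(nn.Module):
    """Hypothetical variant: add learned position embeddings, then dropout.

    The embedding covers all gh*gw+1 tokens, so the class token (index 0)
    also receives a position embedding.
    """

    def __init__(self, seq_len: int, dim: int, dropout: float = 0.1):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, gh*gw+1, dim), with the class token already prepended
        return self.dropout(x + self.pos_embedding)
```

In `eval()` mode the dropout is a no-op, so the two variants differ only during training.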

Cannot load custom config

Hey! First of all, thanks for your contribution! I have looked at multiple ViT implementations and yours seems like the most straightforward, well-organized and simple to use.

I'd like to use your from_config method to initialize the model, but I get the error below. I looked everywhere and couldn't find a from_config method, so maybe that is the problem?

from pytorch_pretrained_vit import ViT
# The following is equivalent to ViT('B_16')
config = dict(hidden_size=512, num_heads=8, num_layers=6)
model = ViT.from_config(config)

AttributeError: type object 'ViT' has no attribute 'from_config'

Also, I'm guessing that if you change anything in the config, the model would have to be retrained from scratch, since the pretrained weights wouldn't fit the model anymore, is that right?

Another thing: you mention that this is equivalent to ViT('B_16'), but for B_16 shouldn't num_heads=12 and num_layers=12? And what is hidden_size=512 for? I cannot find any part of the code that refers to it.

Thanks in advance.

how to extract features from an image?

Hi. I tried to remove the last layer because I only want the features, but I get the error:

The size of tensor a (24) must match the size of tensor b (768) at non-singleton dimension 3

modules = list(model.children())[:-1]
new_model = nn.Sequential(*modules)

Thanks
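`nn.Sequential(*model.children())` only chains the submodules; it drops the tensor reshaping that the model's own `forward()` performs between them. For B_16_imagenet1k, the patch-embedding convolution outputs a `(b, 768, 24, 24)` grid that `forward()` normally flattens into a sequence, which is consistent with the "24 vs. 768 at dimension 3" mismatch above. A sketch that keeps the original `forward()` and only neutralizes the classifier, assuming (as `init_weights` in model.py suggests) that the classifier is an attribute named `fc` applied last; the helper name is hypothetical:

```python
import torch
from torch import nn


def as_feature_extractor(model: nn.Module) -> nn.Module:
    """Make model(x) return pre-logit features instead of class logits.

    Assumes the classifier is stored in `model.fc` and applied last in
    forward(); replacing it with an identity keeps every other step,
    including the reshapes between submodules.
    """
    model.fc = nn.Identity()
    return model
```

Usage (assumption to verify against your installed version): `model = as_feature_extractor(ViT('B_16_imagenet1k', pretrained=True))`; if `forward()` takes the class token before applying `fc`, `model(img)` should then return a `(1, 768)` feature tensor.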

Evaluation Performance

I can't find the evaluation performance in the README. Do you have the results written somewhere?

TypeError with ProgressMeter class

While running main.py in examples/imagenet, I get the following error:

Exception has occurred: TypeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
unsupported format string passed to Tensor.__format__
  File "/home/.........../.local/lib/python3.9/site-packages/torch/_tensor.py", line 660, in __format__
    return object.__format__(self, format_spec)
  File "/home/.........../VisionTransformer/main.py", line 380, in __str__
    return fmtstr.format(**self.__dict__)
  File "/home/.........../VisionTransformer/main.py", line 391, in <listcomp>
    entries += [str(meter) for meter in self.meters]
  File "/home/.........../VisionTransformer/main.py", line 391, in print
    entries += [str(meter) for meter in self.meters]
  File "/home/.........../VisionTransformer/main.py", line 344, in validate
    progress.print(i)
  File "/home/.........../VisionTransformer/main.py", line 236, in main_worker
    res = validate(val_loader, model, criterion, args)
  File "/home/.........../VisionTransformer/main.py", line 111, in main
    main_worker(args.gpu, ngpus_per_node, args)
  File "/home/.........../VisionTransformer/main.py", line 425, in <module>
    main()
  File "/tools/conda/anaconda3/envs/torch_new/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/tools/conda/anaconda3/envs/torch_new/lib/python3.9/runpy.py", line 197, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
TypeError: unsupported format string passed to Tensor.__format__

I googled a bit, and it looks like the error happens because the format call receives a value it does not expect. I am not proficient enough in Python to resolve this. Kindly suggest a fix, thanks!
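In recent PyTorch versions, `Tensor.__format__` forwards a format spec such as `:6.2f` to the underlying Python number only for 0-dimensional tensors; otherwise it falls through to `object.__format__`, which is the torch/_tensor.py line 660 shown in the trace and which rejects the spec. A likely cause is therefore that one of the meters is fed a tensor with one or more dimensions instead of a plain number. Below is a hedged sketch of a meter that normalizes its input with `float()`; the class is modeled on the `__str__` in the trace and on the common PyTorch ImageNet example, not copied from this repo.

```python
class AverageMeter:
    """Tracks the latest value and running average of a metric.

    Hypothetical variant: update() converts its input with float(), so the
    format spec in __str__ always receives a plain Python number.
    """

    def __init__(self, name: str, fmt: str = ":f"):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n: int = 1):
        val = float(val)  # works for Python numbers and one-element tensors
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
```

Alternatively, convert at the call site, for example `top1.update(acc1[0].item(), images.size(0))`, where the variable names are assumed from the usual ImageNet training loop.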

Multi head uses just one set of Q, K, V?

In transformer.py, in class MultiHeadedSelfAttention(), we have the variable declarations:

  self.proj_q = nn.Linear(dim, dim)
  self.proj_k = nn.Linear(dim, dim)
  self.proj_v = nn.Linear(dim, dim)

but aren't Q, K, and V supposed to be independent trainable matrices for each head? E.g. if num_heads = 12, shouldn't it be something like:

set = []
for i in range(12):
    set.append([nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)])

Regards!
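Standard multi-head attention does not need a separate `nn.Linear(dim, dim)` per head: one `nn.Linear(dim, dim)` whose output is split into `num_heads` chunks of size `dim // num_heads` is equivalent to `num_heads` independent (dim -> head_dim) projections, each with its own rows of the weight matrix. A sketch of that equivalence, assuming the attention module reshapes its projections into heads this way (the helper names are hypothetical):

```python
import torch
from torch import nn


def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """(batch, seq, dim) -> (batch, num_heads, seq, dim // num_heads)."""
    b, s, d = x.shape
    return x.view(b, s, num_heads, d // num_heads).transpose(1, 2)


def head_slice_projection(proj: nn.Linear, x: torch.Tensor,
                          head: int, num_heads: int) -> torch.Tensor:
    """Project x using only the weight rows that belong to one head.

    Each head owns a disjoint block of output rows of `proj`, i.e. its own
    (dim -> head_dim) linear map with independent trainable parameters.
    """
    head_dim = proj.out_features // num_heads
    rows = slice(head * head_dim, (head + 1) * head_dim)
    return x @ proj.weight[rows].T + proj.bias[rows]
```

With dim = 768 and 12 heads, head h uses rows 64*h to 64*h + 63 of `proj_q.weight`, so the single layer already holds 12 independent per-head query matrices.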

Cannot load representation layer

I was trying to load the whole model with pretrained=True and the representation layer enabled, but I get an error.
Looking at the keys of the state_dict in your config.py file, I noticed that none of the models has such weights. The keys jump from the transformer weights to the norm layer weights and then to the fc weights, skipping pre_logits.weight. In your Jupyter notebook it seems to work correctly. Any idea what could be wrong?

How to train ViT from scratch on Imagenet1k?

I am trying to train 'ViT-B-16', but it seems very hard to train ViT from scratch on ImageNet-1k.
The parameters used in the experiment are as follows.

  • batch size: 1024
  • optimizer: AdamW
  • weight decay: 0.3
  • learning rate: 0.001
  • cosine warmup: 40k
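Reading "cosine warmup: 40k" as 40k steps of linear warmup followed by cosine decay (an interpretation, since the issue does not state the total number of steps), the learning-rate schedule can be sketched as a pure function of the step. `total_steps` below is a placeholder, not a value from the issue.

```python
import math


def warmup_cosine_lr(step: int, base_lr: float = 1e-3,
                     warmup_steps: int = 40_000, total_steps: int = 300_000) -> float:
    """Linear warmup from 0 to base_lr, then cosine decay to 0 at total_steps.

    total_steps is a placeholder; set it to the length of your own run.
    """
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```

To use it with PyTorch, create `optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.3)` and `scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: warmup_cosine_lr(s) / 1e-3)`, calling `scheduler.step()` once per iteration (LambdaLR multiplies the initial lr by the returned factor).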

How to do "Fine-Tuning" or "Feature-Extraction" in the model B_16 (or even L_16)?

Hi there, do you know how I can use one of the two techniques above to do image classification on the Stanford Dogs Dataset?
I've already tried the "B_16_imagenet1k" model, but the accuracy obtained on 4,160 images isn't very good.

I saw that B_16 and L_16 differ in their model hyperparameters, and therefore in the structure of the network.
I haven't looked into this closely: can you explain the difference? Do you know where I can read about it?
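A common pattern for a new dataset is to build the model with a head sized for that dataset, then either fine-tune all parameters or freeze the backbone and train only the head (feature extraction). The `__init__` snippet quoted in an earlier issue skips loading the pretrained classifier when `num_classes` differs (`load_fc=(num_classes == pretrained_num_classes)`), so something like `ViT('B_16_imagenet1k', pretrained=True, num_classes=120)` should give a freshly initialized 120-way `fc` layer (Stanford Dogs has 120 breeds); treat that call as an assumption to check against your installed version. The helper below is a hypothetical sketch of the feature-extraction variant.

```python
from torch import nn


def freeze_backbone(model: nn.Module, head_name: str = "fc") -> list:
    """Feature extraction: freeze every parameter except the classifier head.

    Returns the trainable parameters to pass to the optimizer. For full
    fine-tuning, skip this step and optimize model.parameters() instead.
    """
    trainable = []
    for name, param in model.named_parameters():
        # Only parameters of the top-level head module (e.g. 'fc.weight') train.
        param.requires_grad = name.startswith(head_name + ".")
        if param.requires_grad:
            trainable.append(param)
    return trainable
```

Usage: `optimizer = torch.optim.SGD(freeze_backbone(model), lr=0.01, momentum=0.9)`, where the learning rate and momentum are placeholders rather than tuned values.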

About torch.no_grad()

Hi,
Thanks for this implementation.
I saw that the initialization of the nn.Linear() parameters is wrapped in no_grad() in model.py, line 139.

    @torch.no_grad()
    def init_weights(self):
        def _init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)  # _trunc_normal(m.weight, std=0.02)  # from .initialization import _trunc_normal
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)  # nn.init.constant(m.bias, 0)
        self.apply(_init)
        nn.init.constant_(self.fc.weight, 0)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.normal_(self.positional_embedding.pos_embedding, std=0.02)  # _trunc_normal(self.positional_embedding.pos_embedding, std=0.02)
        nn.init.constant_(self.class_token, 0)

Does this mean this project only supports evaluation?
Shouldn't these parameters be trainable if I want to train ViT on my own dataset?
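`@torch.no_grad()` only disables gradient recording while the decorated method runs; it does not change `requires_grad` on any parameter. The small hypothetical module below uses the same pattern as `init_weights` and shows that its parameters remain trainable afterwards.

```python
import torch
from torch import nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.init_weights()

    @torch.no_grad()
    def init_weights(self):
        # In-place init under no_grad: autograd does not record these writes,
        # but the parameters keep requires_grad=True afterwards.
        nn.init.constant_(self.fc.weight, 0)
        nn.init.constant_(self.fc.bias, 0)

    def forward(self, x):
        return self.fc(x)
```

After construction, `loss.backward()` populates `.grad` on every parameter as usual, so the model can be trained normally.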

Multilabel Image classification

@lukemelas Thanks for sharing the code base. I have a few questions:
  1. Can we train CrossViT for a multilabel classification problem? If so, what is the procedure?
  2. I have a custom dataset of 10.5k samples with 25 class labels, where each instance's label is a vector of 0s and 1s.
  3. Can we remove the pretrained classifier head and add our own custom classifier?

Thanks in advance
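For a multilabel setup like this one, the usual change is to give each of the 25 labels its own sigmoid output and train with binary cross-entropy on the 0/1 label vectors, instead of softmax cross-entropy over mutually exclusive classes. A minimal sketch, assuming a ViT-style model whose classifier is an `nn.Linear` attribute named `fc` (as in `init_weights` above); the helper names are hypothetical.

```python
import torch
from torch import nn

NUM_LABELS = 25  # label-vector length from the question above

# One independent sigmoid per label: BCE-with-logits on multi-hot targets.
criterion = nn.BCEWithLogitsLoss()


def multilabel_head(in_features: int, num_labels: int = NUM_LABELS) -> nn.Linear:
    """New classifier head producing one logit per label."""
    return nn.Linear(in_features, num_labels)


def predict_labels(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Turn logits into a 0/1 vector per example by thresholding sigmoid scores."""
    return (torch.sigmoid(logits) > threshold).long()
```

Usage, with `images` and `targets` as hypothetical batch tensors: `model.fc = multilabel_head(model.fc.in_features)`, then `loss = criterion(model(images), targets.float())` for training and `predict_labels(model(images))` for inference.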

self.pos_embedding error when training on a different dataset

I am using a pretrained ViT model and training it on a different task, but I get an error.

model.py:

class PositionalEmbedding1D(nn.Module):
    """Adds (optionally learned) positional embeddings to the inputs."""

    def __init__(self, seq_len, dim):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim))
    
    def forward(self, x):
        """Input has shape `(batch_size, seq_len, emb_dim)`"""
        return x + self.pos_embedding

Traceback

    result = self.forward(*input, **kwargs)
  File "/media/khawar/HDD_Khawar/n/Pretrained_ViT/pytorch_pretrained_vit/model.py", line 24, in forward
    return x + self.pos_embedding

best performing model

Hello and many thanks for your code.
Which is the best-performing model? According to the paper (page 12), ViT-B/16 seems to perform best. So do fewer layers work better?
