tiny-transformers's Introduction

Locality Guidance for Improving Vision Transformers on Tiny Datasets (ECCV 2022)

[arXiv paper] [ECCV paper]

Description

This is a PyTorch implementation of the paper "Locality Guidance for Improving Vision Transformers on Tiny Datasets", supporting different Transformer models (including DeiT, T2T-ViT, PiT, PVT, PVTv2, ConViT, CvT) and different classification datasets (including CIFAR-100, Oxford Flowers, Tiny ImageNet, Chaoyang).

Abstract

While the Vision Transformer (VT) architecture is becoming trendy in computer vision, pure VT models perform poorly on tiny datasets. To address this issue, this paper proposes the locality guidance for improving the performance of VTs on tiny datasets. We first analyze that the local information, which is of great importance for understanding images, is hard to be learned with limited data due to the high flexibility and intrinsic globality of the self-attention mechanism in VTs. To facilitate local information, we realize the locality guidance for VTs by imitating the features of an already trained convolutional neural network (CNN), inspired by the built-in local-to-global hierarchy of CNN. Under our dual-task learning paradigm, the locality guidance provided by a lightweight CNN trained on low-resolution images is adequate to accelerate the convergence and improve the performance of VTs to a large extent. Therefore, our locality guidance approach is very simple and efficient, and can serve as a basic performance enhancement method for VTs on tiny datasets. Extensive experiments demonstrate that our method can significantly improve VTs when training from scratch on tiny datasets and is compatible with different kinds of VTs and datasets. For example, our proposed method can boost the performance of various VTs on tiny datasets (e.g., 13.07% for DeiT, 8.98% for T2T and 7.85% for PVT), and enhance even stronger baseline PVTv2 by 1.86% to 79.30%, showing the potential of VTs on tiny datasets.

Usage

Dependencies

The base environment we used for experiments is:

  • python = 3.8.12
  • pytorch = 1.8.0
  • cudatoolkit = 10.1

Other dependencies can be installed by:

pip install -r requirements.txt
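
Before launching a long run, it can help to confirm which versions are actually installed. The snippet below is a minimal sketch, not part of the repository: the helper name `report_environment` is hypothetical, and it only reports versions so they can be compared by eye against the base environment listed above.

```python
import importlib
import sys


def report_environment():
    """Return the installed Python, PyTorch and CUDA versions (None if PyTorch is absent)."""
    info = {"python": ".".join(str(part) for part in sys.version_info[:3])}
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        # PyTorch is not installed in this interpreter.
        info["pytorch"] = None
        info["cuda"] = None
    else:
        info["pytorch"] = torch.__version__
        # CUDA version PyTorch was built against (None for CPU-only builds).
        info["cuda"] = torch.version.cuda
    return info


if __name__ == "__main__":
    for name, version in report_environment().items():
        print(f"{name}: {version}")
```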

Data Preparation

Step 1: Download the datasets (CIFAR-100, Oxford Flowers, Tiny ImageNet, Chaoyang) from their official websites.

Step 2: Move or link the datasets into the data/ directory. The expected layout of data/ is as follows:

data
├── cifar-100-python
│   ├── meta
│   ├── test
│   └── train
├── flowers
│   ├── jpg
│   ├── imagelabels.mat
│   └── setid.mat
├── tiny-imagenet-200
│   ├── train
│   │   ├── n01443537
│   │   └── ...
│   └── val
│       ├── images
│       └── val_annotations.txt
└── chaoyang
    ├── test
    ├── train
    ├── test.json
    └── train.json
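
A quick check of the layout can catch a misplaced dataset before training starts. The following is a minimal sketch rather than a script shipped with the repository: the names `EXPECTED` and `missing_entries` are hypothetical, and the expected paths are copied from the layout shown above.

```python
from pathlib import Path

# Expected entries per dataset, taken from the data/ layout above.
EXPECTED = {
    "cifar-100-python": ["meta", "test", "train"],
    "flowers": ["jpg", "imagelabels.mat", "setid.mat"],
    "tiny-imagenet-200": ["train", "val/images", "val/val_annotations.txt"],
    "chaoyang": ["test", "train", "test.json", "train.json"],
}


def missing_entries(data_root, datasets=None):
    """Return the expected paths (relative to data_root) that do not exist."""
    root = Path(data_root)
    names = list(EXPECTED) if datasets is None else datasets
    missing = []
    for name in names:
        for rel in EXPECTED[name]:
            path = root / name / rel
            if not path.exists():
                missing.append(path.relative_to(root).as_posix())
    return missing


if __name__ == "__main__":
    # Check only the datasets you intend to use, e.g. CIFAR-100.
    for entry in missing_entries("data", ["cifar-100-python"]):
        print("missing:", entry)
```

Passing a list of dataset names restricts the check to those datasets, so the others do not need to be downloaded.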

Train from Scratch

For example, you can train DeiT-Tiny from scratch using:

python run_net.py --mode train --cfg configs/deit/deit-ti_c100_base.yaml

Configurations for other models and datasets are provided under configs/.

Train with Locality Guidance

Step 1: Train the CNN guidance model (e.g., ResNet-56). This step is quick and only needs to be run once per dataset.

python run_net.py --mode train --cfg configs/resnet/r-56_c100.yaml

Step 2: train the target VT.

python run_net.py --mode train --cfg configs/deit/deit-ti_c100_ours.yaml
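
Conceptually, the dual-task paradigm described in the abstract trains the VT on classification while it also imitates the features of the trained CNN. The sketch below illustrates that idea in plain Python and is not the repository's implementation: the mean-squared-error distance, the weight `beta`, the flat feature vectors, and all function names are assumptions made for illustration.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one sample, computed stably via log-sum-exp."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[label]


def feature_mse(vt_feature, cnn_feature):
    """Mean squared error between a VT feature and the CNN feature it imitates."""
    if len(vt_feature) != len(cnn_feature):
        raise ValueError("features must have the same length")
    return sum((s - t) ** 2 for s, t in zip(vt_feature, cnn_feature)) / len(vt_feature)


def dual_task_loss(logits, label, vt_feature, cnn_feature, beta=1.0):
    """Classification loss plus a weighted feature-imitation term (beta is assumed)."""
    return cross_entropy(logits, label) + beta * feature_mse(vt_feature, cnn_feature)


if __name__ == "__main__":
    loss = dual_task_loss([2.0, 0.5, -1.0], 0, [0.1, 0.4], [0.0, 0.5], beta=0.5)
    print(f"loss = {loss:.4f}")
```

In a real model the VT and CNN features generally differ in shape, so some alignment step would be needed before the two can be compared; that step is omitted here.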

As mentioned in the supplementary materials, the locality guidance can be executed offline using pre-computed features. To run in this setting, use:

# Pre-compute features
python precompute_feature.py --cfg configs/resnet/r-56_c100.yaml --ckpt work_dirs/r-56_c100/model.pyth
# Train the model
python run_net.py --mode train --cfg configs/deit/deit-ti_c100_ours_offline.yaml

Multi-GPU & Mixed Precision Support

Multi-GPU and mixed-precision training are each enabled by appending one configuration override to the command, for example:

# Train DeiT from scratch with 2 GPUs
python run_net.py --mode train --cfg configs/deit/deit-ti_c100_base.yaml NUM_GPUS 2

# Train DeiT from scratch with 2 GPUs using mixed precision
python run_net.py --mode train --cfg configs/deit/deit-ti_c100_base.yaml NUM_GPUS 2 TRAIN.MIXED_PRECISION True
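
The trailing `KEY VALUE` pairs in these commands override entries of the loaded configuration, with dots separating nested sections. The sketch below shows how such pairs could be merged into a nested dictionary. It is an illustration only, assuming the override convention is the yacs-style one inherited from pycls; `apply_overrides` and the default values in the example are hypothetical and do not reflect the repository's parser.

```python
import ast


def apply_overrides(cfg, opts):
    """Merge a flat [KEY, VALUE, KEY, VALUE, ...] list into a nested dict in place."""
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            node = node[part]  # KeyError if the section does not exist
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        try:
            # Interpret literals such as 2 or True.
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            # Anything else (e.g. a file path) stays a string.
            value = raw
        node[leaf] = value
    return cfg


if __name__ == "__main__":
    # Hypothetical defaults, for illustration only.
    cfg = {"NUM_GPUS": 1, "TRAIN": {"MIXED_PRECISION": False}}
    print(apply_overrides(cfg, ["NUM_GPUS", "2", "TRAIN.MIXED_PRECISION", "True"]))
```

Rejecting unknown keys means a mistyped override fails immediately instead of being silently ignored.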

Test

python run_net.py --mode test --cfg configs/deit/deit-ti_c100_base.yaml TEST.WEIGHTS /path/to/model.pyth

Results

Model          Top-1 Acc. (Base)        Top-1 Acc. (Ours)
DeiT-Tiny      65.08 (weights | log)    78.15 (weights | log)
T2T-ViT-7      69.37 (weights | log)    78.35 (weights | log)
PiT-Tiny       73.58 (weights | log)    78.48 (weights | log)
PVT-Tiny       69.22 (weights | log)    77.07 (weights | log)
PVTv2-B0       77.44 (weights | log)    79.30 (weights | log)
ConViT-Tiny    75.32 (weights | log)    78.95 (weights | log)

Pre-trained models and training logs are provided for each entry; the logs can be viewed with TensorBoard.
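
The improvements quoted in the abstract follow directly from the accuracies listed in the results. As a small worked check, the snippet below recomputes the per-model gain in percentage points; the numbers are copied from the results above, and the names `RESULTS` and `accuracy_gains` are introduced here only for illustration.

```python
# (base, ours) top-1 accuracies copied from the results above.
RESULTS = {
    "DeiT-Tiny": (65.08, 78.15),
    "T2T-ViT-7": (69.37, 78.35),
    "PiT-Tiny": (73.58, 78.48),
    "PVT-Tiny": (69.22, 77.07),
    "PVTv2-B0": (77.44, 79.30),
    "ConViT-Tiny": (75.32, 78.95),
}


def accuracy_gains(results):
    """Return the gain of 'ours' over 'base' per model, rounded to two decimals."""
    return {model: round(ours - base, 2) for model, (base, ours) in results.items()}


if __name__ == "__main__":
    for model, gain in accuracy_gains(RESULTS).items():
        print(f"{model}: +{gain:.2f}")
```

The gains for DeiT-Tiny, T2T-ViT-7, PVT-Tiny and PVTv2-B0 come out to 13.07, 8.98, 7.85 and 1.86, matching the figures in the abstract.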

Acknowledgement

This repository is built upon pycls and the official implementations of DeiT, T2T-ViT, PiT, PVTv1/v2, ConViT and CvT. We would like to thank the authors of these open-source repositories.

Citing

@article{li2022locality,
  title={Locality Guidance for Improving Vision Transformers on Tiny Datasets},
  author={Li, Kehan and Yu, Runyi and Wang, Zhennan and Yuan, Li and Song, Guoli and Chen, Jie},
  journal={arXiv preprint arXiv:2207.10026},
  year={2022}
}
