
Self-Supervised Learning For Few-shot Image Classification

This repository contains a PyTorch implementation of the algorithm presented in the paper Self-Supervised Learning For Few-shot Image Classification (https://arxiv.org/abs/1911.06045).

Few-Shot Learning & Self-Supervised Learning

Few-shot image classification aims to classify unseen classes robustly when only a few labeled samples are available for each class. Recent works benefit from meta-learning with episodic tasks and can adapt quickly to the change of classes from training to testing. Because each task contains so few samples, the initial embedding network used for meta-learning becomes an essential component and can strongly affect classification performance in practice. Recent methods based on a pre-trained embedding network have significantly improved the state-of-the-art results on several few-shot classification datasets. These methods, however, rely heavily on the quality of the embedding network. In this paper, we propose a novel method that learns a more generalized embedding network with self-supervised learning (SSL) and tackles the 'curse of layers' issue in few-shot learning.
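To make the episodic setting concrete, here is a minimal Python sketch of sampling one N-way K-shot episode from a labeled pool. It is an illustration only, not code from this repository: the function name `sample_episode`, the list-of-labels input format, and the default of 15 query images per class are assumptions. Class labels are re-indexed to 0..way-1 inside each episode, which is what lets a meta-learned model handle classes it never saw during training.

```python
import random
from collections import defaultdict


def sample_episode(labels, way=5, shot=1, query=15, rng=None):
    """Sample one N-way K-shot episode (hypothetical helper, not repo code).

    labels: list where labels[i] is the class of example i.
    Returns (support, queries): lists of (example_index, episode_label)
    pairs, where episode_label is in range(way).
    """
    rng = rng or random.Random()
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    # Only classes with at least shot + query examples can be drawn.
    eligible = sorted(c for c, idxs in by_class.items() if len(idxs) >= shot + query)
    if len(eligible) < way:
        raise ValueError("not enough classes with shot + query examples")
    classes = rng.sample(eligible, way)
    support, queries = [], []
    for episode_label, c in enumerate(classes):
        # Draw disjoint support and query images for this class.
        picked = rng.sample(by_class[c], shot + query)
        support += [(i, episode_label) for i in picked[:shot]]
        queries += [(i, episode_label) for i in picked[shot:]]
    return support, queries


# Toy pool: 64 classes with 20 images each.
toy_labels = [i // 20 for i in range(64 * 20)]
s, q = sample_episode(toy_labels, way=5, shot=1, query=15, rng=random.Random(0))
print(len(s), len(q))  # 5 75
```

Drawing support and query images from the same `rng.sample` call keeps them disjoint, so the model is never evaluated on an image it was shown as a support example.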

[Figure: Self-Supervised Learning For Deeper Few-shot Image Classification]
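The self-supervised pre-training builds on AMDIM (listed under References), which maximizes mutual information between different augmented views of the same image with a contrastive, NCE-style objective. The sketch below is a generic single-scale InfoNCE loss in plain Python, shown only to illustrate the idea; it is not AMDIM's multi-scale objective nor this repository's training code, and the name `info_nce` and the default `tau` are hypothetical choices.

```python
import math


def info_nce(anchors, positives, tau=0.1):
    """Generic InfoNCE loss over a batch of paired views (illustrative sketch).

    anchors[i] and positives[i] are feature vectors of two augmented views of
    the same image; positives[j] for j != i serve as negatives for anchors[i].
    Returns the mean cross-entropy of identifying the matching pair.
    """
    def normalize(v):
        n = math.sqrt(sum(x * x for x in v)) or 1.0
        return [x / n for x in v]

    a = [normalize(v) for v in anchors]
    p = [normalize(v) for v in positives]
    total = 0.0
    for i, ai in enumerate(a):
        # Cosine similarity to every candidate, scaled by the temperature tau.
        scores = [sum(x * y for x, y in zip(ai, pj)) / tau for pj in p]
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_denom = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_denom - scores[i]
    return total / len(a)


aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned, swapped)
```

When each anchor is paired with a matching view the loss is close to 0; when the pairs are swapped it rises to about 10, which is the signal that pushes an encoder to map views of the same image close together.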

Prerequisites

The following packages are required to run the scripts:

Results

[Figure: results on miniImageNet]

Usage

Pre-Trained Models

MiniImagenet

Model Name         Model Arch                  Mode   Checkpoint
mini80_ssl         ndf=192, rkhs=1536, rd=8    SSL    mini_imagenet_ndf192_rkhs1536_rd8_ssl_cpt.pth
mini80_sl          ndf=192, rkhs=1536, rd=8    SL     mini_imagenet_ndf192_rkhs1536_rd8_sl_cpt.pth
imagenet900_ssl    ndf=192, rkhs=1536, rd=8    SSL    imagenet900_ndf192_rkhs_1536_rd8_sl_cpt.pth

CUB

Model Name         Model Arch                  Mode   Checkpoint
cub150_ssl         ndf=192, rkhs=1536, rd=8    SSL    cub_ndf192_rkhs1536_rd8_ssl_cpt.pth
cub150_sl          ndf=192, rkhs=1536, rd=8    SL     cub_ndf192_rkhs1536_rd8_sl_cpt.pth
imagenet1000_ssl   ndf=192, rkhs=1536, rd=8    SSL    imagenet1K_ndf192_rkhs_1536_rd8_ssl_cpt.pth

Training

Example: training a 5-way 1-shot ProtoNet on miniImageNet, initialized from the mini80_ssl checkpoint:

python train_protonet.py --lr 0.0001 --temperature 128   \
--max_epoch 100 --model_type AmdimNet --dataset MiniImageNet \
--init_weights ./saves/initialization/miniimagenet/mini_imagenet_ndf192_rkhs1536_rd8_ssl_cpt.pth  \
--save_path ./MINI_ProtoNet_MINI_1shot_5way/ \
--shot 1  --way 5 --gpu 4 --step_size 10 --gamma 0.5 \
--ndf 192 --rkhs 1536 --nd 8
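`train_protonet.py` trains a prototypical network (ProtoNet) on top of the pre-trained embedding. The sketch below assumes, as in FEAT's ProtoNet (on which this code is built), that each class prototype is the mean of its support embeddings and that `--temperature` divides the negative squared Euclidean distance before the softmax; check the script itself before relying on this. `proto_logits` and `softmax` are hypothetical helpers written in plain Python rather than PyTorch.

```python
import math


def proto_logits(support, support_labels, queries, way, temperature=128.0):
    """ProtoNet-style logits for one episode (illustrative sketch).

    support: list of embedding vectors; support_labels[i] is in range(way),
    and every class is assumed to have at least one support example.
    Logit for query q and class k is -||q - proto_k||^2 / temperature.
    """
    dim = len(support[0])
    sums = [[0.0] * dim for _ in range(way)]
    counts = [0] * way
    for v, y in zip(support, support_labels):
        counts[y] += 1
        for d in range(dim):
            sums[y][d] += v[d]
    # Each prototype is the mean embedding of that class's support examples.
    protos = [[s / counts[k] for s in sums[k]] for k in range(way)]
    return [
        [-sum((qd - pd) ** 2 for qd, pd in zip(q, p)) / temperature for p in protos]
        for q in queries
    ]


def softmax(row):
    m = max(row)  # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    z = sum(exps)
    return [e / z for e in exps]


# 2-way 1-shot toy episode with 2-D embeddings.
print(proto_logits([[0.0, 0.0], [4.0, 0.0]], [0, 1], [[1.0, 0.0]], way=2, temperature=1.0))
# [[-1.0, -9.0]]
```

With `temperature=1` the query is assigned to class 0 by a wide margin; dividing by a larger value such as the 128 used in the command above shrinks the gap between logits and flattens the softmax, which changes how sharply the cross-entropy loss penalizes mistakes.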

References

Learning Representations by Maximizing Mutual Information Across Views

Learning Embedding Adaptation for Few-Shot Learning

Citation

@article{chen2019selfsupervised,
    title={Self-Supervised Learning For Few-Shot Image Classification},
    author={Da Chen and Yuefeng Chen and Yuhong Li and Feng Mao and Yuan He and Hui Xue},
    journal={arXiv preprint arXiv:1911.06045},
    year={2019}
}

Contact

For questions please contact Yuefeng Chen at [email protected].

Acknowledgements

This code is built on FEAT (PyTorch) and AMDIM (PyTorch). We thank the authors for sharing their code.
