
IPN

Introduction

A PyTorch implementation of the IJCAI 2020 paper "Few-shot Visual Learning with Contextual Memory and Fine-grained Calibration". The code is based on Edge-labeling Graph Neural Network for Few-shot Learning and Revisiting Local Descriptor based Image-to-Class Measure for Few-shot Learning.

Authors: Yuqing Ma, Wei Liu, Shihao Bai, Qingyu Zhang, Aishan Liu, Weimin Chen and Xianglong Liu

Abstract: Few-shot learning aims to learn a model that can be readily adapted to new unseen classes (concepts) by accessing one or few examples. Despite the successful progress, most of the few-shot learning approaches, concentrating on either global or local characteristics of examples, still suffer from weak generalization abilities. Inspired by the inverted pyramid theory, to address this problem, we propose an inverted pyramid network (IPN) that imitates the human's coarse-to-fine cognition paradigm. The proposed IPN consists of two consecutive stages, namely global stage and local stage. At the global stage, a class-sensitive contextual memory network (CCMNet) is introduced to learn discriminative support-query relation embeddings and predict the query-to-class similarity based on the contextual memory. Then at the local stage, a fine-grained calibration is further appended to complement the coarse relation embeddings, targeting more precise query-to-class similarity evaluation. To the best of our knowledge, IPN is the first work that simultaneously integrates both global and local characteristics in few-shot learning, approximately imitating the human cognition mechanism. Our extensive experiments on multiple benchmark datasets demonstrate the superiority of IPN, compared to a number of state-of-the-art approaches.
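The code builds on the local descriptor based image-to-class measure from "Revisiting Local Descriptor based Image-to-Class Measure for Few-shot Learning" (DN4). As a rough illustration of that cited measure only (not IPN's fine-grained calibration, whose exact form is described in the paper), here is a minimal NumPy sketch: for every local descriptor of a query image, take the cosine similarities to its k nearest descriptors pooled from all support images of a class, and sum them. The function names, array layout and default k are illustrative assumptions, not identifiers from this repository.

```python
import numpy as np

def image_to_class_similarity(query_desc, support_desc, k=3):
    """DN4-style image-to-class measure (illustrative sketch).

    query_desc:   (m, d) local descriptors of one query image.
    support_desc: (n, d) local descriptors pooled from all support images of one class.
    """
    q = np.asarray(query_desc, dtype=float)
    s = np.asarray(support_desc, dtype=float)
    # L2-normalise so that dot products are cosine similarities.
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    s = s / np.linalg.norm(s, axis=1, keepdims=True)
    cos = q @ s.T                      # (m, n) pairwise cosine similarities
    k = min(k, cos.shape[1])
    topk = np.sort(cos, axis=1)[:, -k:]  # k nearest support descriptors per query descriptor
    return float(topk.sum())

def classify(query_desc, class_descs, k=3):
    """Assign the query to the class with the highest image-to-class score."""
    scores = [image_to_class_similarity(query_desc, s, k) for s in class_descs]
    return int(np.argmax(scores)), scores

# Usage: a toy 2-way task with 2-d "local descriptors".
pred, scores = classify([[1, 0], [0, 1]],
                        [[[1, 0], [0, 1], [1, 1]], [[-1, 0], [0, -1], [-1, -1]]],
                        k=1)
print(pred, scores)
```

Pooling descriptors across all support images of a class (rather than averaging them into a prototype) is what makes the measure "image-to-class" instead of "image-to-image".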

Requirements

  • Python 3
  • Python packages
    • pytorch 1.0.0
    • torchvision 0.2.2
    • matplotlib
    • numpy
    • pillow
    • tensorboardX

An NVIDIA GPU and CUDA 9.0 or higher.

Getting started

mini-ImageNet

You can download the miniImageNet dataset from here.

tiered-ImageNet

You can download the tieredImageNet dataset from here.

Because WRN has a large number of parameters, you can save the features extracted before the classification layer to speed up training and testing. Here we provide the features extracted by WRN:

You can also use our pretrained WRN model to generate the features for miniImageNet or tieredImageNet yourself.
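The idea of caching features is to run the expensive backbone once and reuse its penultimate-layer output afterwards. The following NumPy sketch shows that pattern under loud assumptions: the "backbone" is a hypothetical random linear layer plus ReLU standing in for WRN, and the `.npz` layout, file name and helper names are our own, not the format of the provided feature files.

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the WRN backbone: a random 32 -> 8 linear layer
# followed by ReLU. The classification layer that would normally follow is
# deliberately never applied, so we keep the features "before" it.
W_backbone = rng.standard_normal((32, 8))

def extract_features(images):
    """Return activations from just before the (omitted) classification layer."""
    return np.maximum(np.asarray(images) @ W_backbone, 0.0)

def cache_features(images, labels, path):
    """Run the expensive backbone once and store features and labels on disk."""
    np.savez(path, features=extract_features(images), labels=np.asarray(labels))

def load_features(path):
    """Load cached features so training/testing can skip the backbone."""
    with np.load(path) as data:
        return data["features"], data["labels"]

# Usage: cache features for 10 toy "images" (flattened 32-d vectors).
images = rng.standard_normal((10, 32))
labels = np.arange(10) % 5
path = os.path.join(tempfile.mkdtemp(), "toy_features.npz")
cache_features(images, labels, path)
feats, labs = load_features(path)
print(feats.shape)  # (10, 8)
```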

Training

# ************************** miniImagenet, 5way 1shot  *****************************
$ python3 train.py --dataset mini --num_ways 5 --num_shots 1 

# ************************** miniImagenet, 5way 5shot *****************************
$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 

# ************************** tieredImagenet, 5way 1shot *****************************
$ python3 train.py --dataset tiered --num_ways 5 --num_shots 1 

# ************************** tieredImagenet, 5way 5shot *****************************
$ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 
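The `--num_ways` and `--num_shots` flags select the few-shot setting: each episode draws N classes with K labelled support examples per class, plus query examples to classify. To make the terminology concrete, here is a minimal sketch of how an N-way K-shot episode is commonly sampled. It is not taken from `train.py`; the function name, the `num_queries=15` default and the dictionary-of-lists data layout are hypothetical.

```python
import random

def sample_episode(items_by_class, num_ways, num_shots, num_queries=15, seed=None):
    """Sample one N-way K-shot episode (hypothetical helper).

    items_by_class maps a class name to a list of item ids (e.g. image paths).
    Returns (support, query) lists of (item, episode_label) pairs, where the
    sampled classes are re-labelled 0..num_ways-1 for this episode.
    """
    rng = random.Random(seed)
    classes = rng.sample(sorted(items_by_class), num_ways)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        # Draw support and query items together so they never overlap.
        picked = rng.sample(items_by_class[cls], num_shots + num_queries)
        support += [(item, episode_label) for item in picked[:num_shots]]
        query += [(item, episode_label) for item in picked[num_shots:]]
    return support, query

# Usage: a toy dataset with 10 classes of 20 items each, 5-way 1-shot.
toy = {f"class{c}": [f"class{c}/img{i}.jpg" for i in range(20)] for c in range(10)}
support, query = sample_episode(toy, num_ways=5, num_shots=1, seed=0)
print(len(support), len(query))  # 5 75
```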

You can download our pretrained model from here to reproduce the results of the paper.

Testing

# ************************** miniImagenet, C-way K-shot *****************************
$ python3 eval.py --test_model your_path --dataset mini --num_ways C --num_shots K 
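Few-shot results are commonly reported as the mean accuracy over many test episodes together with a 95% confidence interval. Whether `eval.py` reports exactly this is not stated here; the following is a minimal sketch assuming you have already collected per-episode accuracies, using the normal approximation (1.96 times the standard error). The function name is hypothetical.

```python
import math
import statistics

def mean_and_ci95(episode_accuracies):
    """Mean accuracy and 95% confidence half-width (normal approximation)."""
    n = len(episode_accuracies)
    if n < 2:
        raise ValueError("need at least two episodes to estimate a spread")
    mean = statistics.fmean(episode_accuracies)
    # Sample standard deviation divided by sqrt(n) gives the standard error.
    half_width = 1.96 * statistics.stdev(episode_accuracies) / math.sqrt(n)
    return mean, half_width

# Usage with three toy episode accuracies.
mean, ci = mean_and_ci95([0.6, 0.7, 0.8])
print(f"{100 * mean:.2f} +- {100 * ci:.2f}")
```

In practice the interval only becomes tight with many episodes (hundreds or more), since the half-width shrinks with the square root of the episode count.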
