
yunkai696 / constellationnet


This project forked from mlpc-ucsd/constellationnet


(ICLR 2021) ConstellationNet: Attentional Constellation Nets for Few-Shot Learning

License: Apache License 2.0

Python 99.34% Shell 0.66%


Attentional Constellation Nets For Few-shot Learning

Introduction

This repository contains the official code and pretrained models for Attentional Constellation Nets For Few-shot Learning (ICLR 2021). In this paper, we tackle the few-shot learning problem and make an effort to enhance structured features by expanding CNNs with a constellation model, which performs cell feature clustering and encoding with a dense part representation; the relationships among the cell features are further modeled by an attention mechanism. With the additional constellation branch to increase the awareness of object parts, our method is able to attain the advantages of the CNNs while making the overall internal representations more robust in the few-shot learning setting. Our approach attains a significant improvement over the existing methods in few-shot learning on the CIFAR-FS, FC100, and mini-ImageNet benchmarks.

For more details, please refer to Attentional Constellation Nets For Few-shot Learning by Weijian Xu*, Yifan Xu*, Huaijin Wang*, and Zhuowen Tu.

Changelog

09/17/2021: Code for ConstellationNet is released.

Usage

Environment Preparation

  1. Set up a new conda environment and activate it.

    # Create an environment with Python 3.8.
    conda create -n constells python==3.8
    conda activate constells
  2. Install required packages.

    # Install PyTorch 1.8.0 w/ CUDA 10.2.
    # Note: CUDA 10.2 runs without errors on the server; with CUDA 11.1, an error is raised saying no GPU can be found.
    conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch -c conda-forge
    
    # Install yaml
    conda install -c anaconda pyyaml
    
    # Install tensorboardX and tqdm.
    conda install -c conda-forge tensorboardx tqdm
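Given the note above about CUDA 11.1 failing to find a GPU, it can help to check what the installed PyTorch actually sees before training. The sketch below is a hypothetical helper (`environment_report` and `REQUIRED_MODULES` are not part of the repository); it only uses `importlib.util.find_spec`, `torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()` and `torch.cuda.device_count()`, and it still runs when PyTorch is not installed.

```python
import importlib.util

# Import names for the packages installed above (pyyaml is imported as "yaml").
REQUIRED_MODULES = ("torch", "torchvision", "yaml", "tensorboardX", "tqdm")


def environment_report():
    """Summarize missing modules and what PyTorch reports about CUDA."""
    missing = [m for m in REQUIRED_MODULES if importlib.util.find_spec(m) is None]
    report = {
        "torch": None,
        "cuda_build": None,
        "cuda_available": False,
        "gpu_count": 0,
        "missing": missing,
    }
    if "torch" not in missing:
        import torch

        report["torch"] = torch.__version__
        # CUDA version PyTorch was built against (None for CPU-only builds).
        report["cuda_build"] = torch.version.cuda
        report["cuda_available"] = torch.cuda.is_available()
        report["gpu_count"] = torch.cuda.device_count()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

If `cuda_build` reports a CUDA version but `cuda_available` is `False`, the build and the machine disagree; one plausible cause (not confirmed by the note above) is an NVIDIA driver too old for the CUDA 11.1 runtime, which would be consistent with CUDA 10.2 working on the same server.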

Code and Datasets Preparation

  1. Clone the repo.

    git clone https://github.com/mlpc-ucsd/ConstellationNet.git
    cd ConstellationNet
  2. Download the datasets and arrange them as follows.

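# Note: change the `with open('<absolute path>', 'rb') as f:` calls in the files under datasets/ to use absolute paths.
# Relative paths raise an error (the cause has not been found yet).
materials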
├── mini-imagenet
│   ├── miniImageNet_category_split_test.pickle
│   ├── miniImageNet_category_split_train_phase_test.pickle
│   ├── miniImageNet_category_split_train_phase_train.pickle
│   ├── miniImageNet_category_split_train_phase_val.pickle
│   ├── miniImageNet_category_split_val.pickle
├── cifar-fs
│   ├── CIFAR_FS_test.pickle
│   ├── CIFAR_FS_train.pickle
│   ├── CIFAR_FS_val.pickle
├── fc100
│   ├── FC100_test.pickle
│   ├── FC100_train.pickle
│   ├── FC100_val.pickle
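A likely (unconfirmed) explanation for the relative-path error noted above is that Python resolves a relative path in `open()` against the current working directory, not against the location of the script, so the same path works from one directory and fails from another. The sketch below checks the layout above and loads a split through an absolute path. `EXPECTED_FILES`, `missing_files` and `load_split` are hypothetical names, not the repository's loader; the `latin1` fallback assumes the split files may be Python 2 pickles, which is common for these benchmark files but not stated here.

```python
import pickle
from pathlib import Path

# File layout expected under materials/, copied from the tree above.
EXPECTED_FILES = {
    "mini-imagenet": [
        "miniImageNet_category_split_test.pickle",
        "miniImageNet_category_split_train_phase_test.pickle",
        "miniImageNet_category_split_train_phase_train.pickle",
        "miniImageNet_category_split_train_phase_val.pickle",
        "miniImageNet_category_split_val.pickle",
    ],
    "cifar-fs": ["CIFAR_FS_test.pickle", "CIFAR_FS_train.pickle", "CIFAR_FS_val.pickle"],
    "fc100": ["FC100_test.pickle", "FC100_train.pickle", "FC100_val.pickle"],
}


def missing_files(materials_root):
    """Return the absolute paths of expected split files that do not exist."""
    root = Path(materials_root).expanduser().resolve()
    return [
        str(root / subdir / name)
        for subdir, names in EXPECTED_FILES.items()
        for name in names
        if not (root / subdir / name).is_file()
    ]


def load_split(materials_root, subdir, name):
    """Load one split file through an absolute path."""
    path = Path(materials_root).expanduser().resolve() / subdir / name
    with open(path, "rb") as f:
        try:
            return pickle.load(f)
        except UnicodeDecodeError:
            # Assumption: a Python 2 pickle, decodable only with latin1.
            f.seek(0)
            return pickle.load(f, encoding="latin1")
```

Note that `resolve()` still interprets a relative `materials_root` against the working directory. To avoid editing hard-coded absolute paths, one option is to anchor the root to a source file, e.g. `Path(__file__).resolve().parent / "materials"`, adjusting the number of `.parent` steps to where `materials/` actually sits.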

Pre-trained Checkpoints

We provide ConstellationNet checkpoints pre-trained on mini-ImageNet, CIFAR-FS, and FC100.

| Dataset | Model | Backbone | Acc. 5-way 1-shot (%) | Acc. 5-way 5-shot (%) | #Params | SHA-256 (first 8 chars) | URL |
|---|---|---|---|---|---|---|---|
| Mini-ImageNet | ConstellationNets | Conv-4 | 59.67 ± 0.23 | 75.98 ± 0.18 | 200K | d76075a5 | model |
| Mini-ImageNet | ConstellationNets | ResNet-12 | 65.53 ± 0.23 | 80.55 ± 0.16 | 8.4M | cf716d90 | model |
| CIFAR-FS | ConstellationNets | Conv-4 | 69.1 ± 0.3 | 83.0 ± 0.2 | 200K | 4ea590f9 | model |
| CIFAR-FS | ConstellationNets | ResNet-12 | 76.1 ± 0.2 | 87.4 ± 0.2 | 8.4M | dc5d56fa | model |
| FC100 | ConstellationNets | ResNet-12 | 43.9 ± 0.2 | 59.7 ± 0.2 | 8.4M | d9a829f7 | model |
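The SHA-256 column can be used to check that a downloaded checkpoint is intact. A minimal sketch, assuming each listed value is the first 8 hex characters of the SHA-256 digest of the whole downloaded file; `sha256_prefix` and `verify_checkpoint` are hypothetical helpers built on the standard `hashlib` module.

```python
import hashlib


def sha256_prefix(path, length=8, chunk_size=1 << 20):
    """Hash a file in chunks and return the first `length` hex characters."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in 1 MiB chunks so large checkpoints are not loaded into memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()[:length]


def verify_checkpoint(path, expected_prefix):
    """Compare a file's digest prefix with a value from the table (case-insensitive)."""
    return sha256_prefix(path, len(expected_prefix)) == expected_prefix.lower()
```

For example, `verify_checkpoint("path/to/checkpoint.pth", "d76075a5")` should return `True` for an intact Mini-ImageNet Conv-4 checkpoint; the file path here is a placeholder.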

Train

The following command shows an example of training a ConstellationNet.

# Usage: bash ./scripts/train.sh [Dataset (mini, cifar-fs, fc100)] [Backbone (conv4, res12)] [GPU index] [Tag]
bash ./scripts/train.sh mini conv4 0 trial1

Evaluate

The following command shows an example of evaluating the checkpoint after training.

# Usage: bash ./scripts/eval.sh [Dataset (mini, cifar-fs, fc100)] [Backbone (conv4, res12)] [GPU index] [Tag]
bash ./scripts/eval.sh mini conv4 0 trial1
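Both scripts take the same four positional arguments, so a run can be scripted from Python. A minimal sketch, assuming it is launched from the repository root and that evaluating with the same Tag picks up that training run's checkpoint (inferred from both examples reusing `trial1`); `build_command` and `train_then_eval` are hypothetical helpers, and no flags beyond the documented positional arguments are used.

```python
import subprocess

# Allowed values, taken from the usage comments of train.sh and eval.sh.
DATASETS = ("mini", "cifar-fs", "fc100")
BACKBONES = ("conv4", "res12")
SCRIPTS = ("train.sh", "eval.sh")


def build_command(script, dataset, backbone, gpu, tag):
    """Return the argv list for one call of ./scripts/<script>."""
    if script not in SCRIPTS:
        raise ValueError(f"unknown script: {script!r}")
    if dataset not in DATASETS:
        raise ValueError(f"dataset must be one of {DATASETS}, got {dataset!r}")
    if backbone not in BACKBONES:
        raise ValueError(f"backbone must be one of {BACKBONES}, got {backbone!r}")
    return ["bash", f"./scripts/{script}", dataset, backbone, str(gpu), tag]


def train_then_eval(dataset, backbone, gpu, tag):
    """Run training, then evaluation with the same tag; stop if a step fails."""
    for script in SCRIPTS:
        # check=True raises CalledProcessError on a non-zero exit code,
        # so evaluation is skipped when training fails.
        subprocess.run(build_command(script, dataset, backbone, gpu, tag), check=True)
```

For example, `train_then_eval("mini", "conv4", 0, "trial1")` runs the two commands shown above in order.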

Dataset Test Results

| Dataset | Backbone | Acc. 5-way 1-shot (%) | Acc. 5-way 5-shot (%) |
|---|---|---|---|
| miniImageNet | ResNet-12 | 64.89 ± 0.23 | 79.95 ± 0.17 |
| miniImageNet | Conv-4 | 58.82 ± 0.23 | 75.00 ± 0.18 |
| CIFAR-FS | ResNet-12 | 75.4 ± 0.2 | 86.8 ± 0.2 |
| CIFAR-FS | Conv-4 | 69.3 ± 0.3 | 82.7 ± 0.2 |
| FC100 | ResNet-12 | 43.8 ± 0.2 | 59.7 ± 0.2 |
| FC100 | Conv-4 | not yet tested | not yet tested |
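This README does not define the ± values. Few-shot benchmarks usually report the mean accuracy over many test episodes together with a 95% confidence interval, 1.96 · s / √N, where s is the standard deviation of per-episode accuracy and N is the number of episodes. Assuming that convention, the sketch below shows how such a cell would be computed; `mean_and_ci95` and `format_result` are hypothetical names, and implementations differ on whether s is the sample or the population standard deviation (this one uses the sample version).

```python
import math
import statistics


def mean_and_ci95(episode_accuracies):
    """Mean and 95% confidence half-width (normal approximation) over episodes."""
    n = len(episode_accuracies)
    if n < 2:
        raise ValueError("need at least two episodes")
    mean = statistics.fmean(episode_accuracies)
    std = statistics.stdev(episode_accuracies)  # sample standard deviation
    return mean, 1.96 * std / math.sqrt(n)


def format_result(episode_accuracies):
    """Format as 'mean ± half-width', like the table cells."""
    mean, half_width = mean_and_ci95(episode_accuracies)
    return f"{mean:.2f} ± {half_width:.2f}"
```

For example, `format_result([50.0, 60.0, 70.0, 80.0])` returns `"65.00 ± 12.65"`. Real evaluations use many more episodes, which is why the intervals in the tables are much narrower.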

Citation

@inproceedings{xu2021attentional,
title={Attentional Constellation Nets for Few-Shot Learning},
author={Weijian Xu and Yifan Xu and Huaijin Wang and Zhuowen Tu},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=vujTf_I8Kmc}
}

License

This repository is released under the Apache License 2.0. The license text can be found in the LICENSE file.

Acknowledgment

Contributors

ptrnn, yix081, yunkai696
