[IJCAI 2021 & AIJ 2023] Cross-Domain Few-Shot Classification via Adversarial Task Augmentation

Cross-Domain Few-Shot Classification via Adversarial Task Augmentation

PyTorch implementation of

IJCAI 2021:
Cross-Domain Few-Shot Classification via Adversarial Task Augmentation
Haoqing Wang, Zhi-hong Deng

Artificial Intelligence 2023:
Towards well-generalizing meta-learning via adversarial task augmentation
Haoqing Wang, Huiyu Mai, Yuhang Gong, Zhi-hong Deng

Abstract

Few-shot classification aims to recognize unseen classes with few labeled samples from each class. Many meta-learning models for few-shot classification elaborately design various task-shared inductive bias (meta-knowledge) to solve such tasks, and achieve impressive performance. However, when there exists the domain shift between the training tasks and the test tasks, the obtained inductive bias fails to generalize across domains, which degrades the performance of the meta-learning models. In this work, we aim to improve the robustness of the inductive bias through task augmentation. Concretely, we consider the worst-case problem around the source task distribution, and propose the adversarial task augmentation method which can generate the inductive bias-adaptive 'challenging' tasks. Our method can be used as a simple plug-and-play module for various meta-learning models, and improve their cross-domain generalization capability. We conduct extensive experiments under the cross-domain setting, using nine few-shot classification datasets: mini-ImageNet, CUB, Cars, Places, Plantae, CropDiseases, EuroSAT, ISIC and ChestX. Experimental results show that our method can effectively improve the few-shot classification performance of the meta-learning models under domain shift, and outperforms the existing works.
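A worst-case problem of this kind is commonly written in a distributionally robust form. The notation below is assumed for illustration and is not taken from this README: \theta denotes the model parameters, \mathcal{T}_0 the source task distribution, D a distance between task distributions, \rho the radius of the neighbourhood, and \mathcal{L} the meta-learning loss on a task T.

```latex
\min_{\theta} \; \sup_{\mathcal{T} :\, D(\mathcal{T}, \mathcal{T}_0) \le \rho} \;
\mathbb{E}_{T \sim \mathcal{T}} \big[ \mathcal{L}(\theta; T) \big]
```

The inner supremum searches for the hardest task distribution within distance \rho of the source, and the outer minimization trains the model against it.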

Citation

If you use this code for your research, please cite our paper:

@inproceedings{ijcai2021-149,
  title     = {Cross-Domain Few-Shot Classification via Adversarial Task Augmentation},
  author    = {Wang, Haoqing and Deng, Zhi-Hong},
  booktitle = {Proceedings of the Thirtieth International Joint Conference on
               Artificial Intelligence, {IJCAI-21}},
  publisher = {International Joint Conferences on Artificial Intelligence Organization},
  editor    = {Zhi-Hua Zhou},
  pages     = {1075--1081},
  year      = {2021},
  month     = {8},
  note      = {Main Track},
  doi       = {10.24963/ijcai.2021/149},
  url       = {https://doi.org/10.24963/ijcai.2021/149},
}

and

@article{wang2023towards,
  title={Towards well-generalizing meta-learning via adversarial task augmentation},
  author={Wang, Haoqing and Mai, Huiyu and Gong, Yuhang and Deng, Zhi-Hong},
  journal={Artificial Intelligence},
  pages={103875},
  year={2023},
  publisher={Elsevier}
}

Dependencies

Datasets

We use miniImageNet as the single source domain, and use CUB, Cars, Places, Plantae, CropDiseases, EuroSAT, ISIC and ChestX as the target domains.

For miniImageNet, CUB, Cars, Places and Plantae, download and process them separately with the following commands.

  • Set DATASET_NAME to: miniImagenet, cub, cars, places or plantae.
cd filelists
python process.py DATASET_NAME
cd ..
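To process all five datasets in one pass, the commands above can be generated in a loop. The following is a minimal sketch, not part of the repository: the helper names `build_process_commands` and `run_process_commands` are hypothetical, and it assumes `process.py` is run from inside `filelists` as shown above. It only prints the commands unless `dry_run=False` is passed.

```python
import shlex
import subprocess

# Dataset names accepted by process.py, as listed above.
DATASET_NAMES = ["miniImagenet", "cub", "cars", "places", "plantae"]


def build_process_commands(names=DATASET_NAMES):
    """Return one `python process.py NAME` argument list per dataset."""
    return [["python", "process.py", name] for name in names]


def run_process_commands(commands, workdir="filelists", dry_run=True):
    """Print each command; execute it inside `workdir` only when dry_run is False."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # cwd replaces the manual `cd filelists` / `cd ..` pair.
            subprocess.run(cmd, cwd=workdir, check=True)


# Dry run: show what would be executed without downloading anything.
run_process_commands(build_process_commands())
```

Call `run_process_commands(build_process_commands(), dry_run=False)` from the repository root to actually run the downloads.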

For CropDiseases, EuroSAT, ISIC and ChestX, download them from

and put them under their respective paths, e.g., 'filelists/CropDiseases', 'filelists/EuroSAT', 'filelists/ISIC', 'filelists/chestX', then process them with the following commands.

  • Set DATASET_NAME to: CropDiseases, EuroSAT, ISIC or chestX.
cd filelists/DATASET_NAME
python write_DATASET_NAME_filelist.py
cd ../..

Pre-training

We adopt baseline pre-training from CloserLookFewShot for all models.

  • Download the pre-trained feature encoders from CloserLookFewShot.
  • Or train your own pre-trained feature encoder.
python pretrain.py --dataset miniImagenet --name Pretrain --train_aug

Training

1. Train meta-learning models.

Set method to MatchingNet, RelationNet, ProtoNet, GNN or TPN. For the MatchingNet, RelationNet and TPN models, we set the training shot to 5 for both 1-shot and 5-shot evaluation.

python train.py --model ResNet10 --method GNN --n_shot 5 --name GNN_5s --train_aug
python train.py --model ResNet10 --method TPN --n_shot 5 --name TPN --train_aug
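The shot rule above can be made explicit when generating training commands. This is a minimal sketch with hypothetical helper names (`training_shot`, `build_train_command`). It assumes that methods other than MatchingNet, RelationNet and TPN are trained with the same shot they are evaluated with, which the text implies but does not state; any run name other than the two shown above would likewise be your own choice.

```python
import shlex

# Methods that the text above trains with 5 shots regardless of the evaluation shot.
FIXED_5_SHOT = {"MatchingNet", "RelationNet", "TPN"}


def training_shot(method, eval_shot):
    """Return the --n_shot value to train with for a given evaluation shot."""
    if method in FIXED_5_SHOT:
        return 5
    # Assumption: the remaining methods train with the evaluation shot.
    return eval_shot


def build_train_command(method, eval_shot, name, script="train.py"):
    """Build the argument list for one training run."""
    return [
        "python", script,
        "--model", "ResNet10",
        "--method", method,
        "--n_shot", str(training_shot(method, eval_shot)),
        "--name", name,
        "--train_aug",
    ]


# Reproduces the two commands shown above.
print(shlex.join(build_train_command("GNN", 5, "GNN_5s")))
print(shlex.join(build_train_command("TPN", 1, "TPN")))
```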

2. Train meta-learning models with feature-wise transformations.

Set method to MatchingNet, RelationNet, ProtoNet, GNN or TPN.

python train_FT.py --model ResNet10 --method GNN --n_shot 5 --name GNN_FWT_5s --train_aug
python train_FT.py --model ResNet10 --method TPN --n_shot 5 --name TPN_FWT --train_aug

3. Train meta-learning models with explanation-guided training.

Set method to RelationNetLRP or GNNLRP.

python train.py --model ResNet10 --method GNNLRP --n_shot 5 --name GNN_LRP_5s --train_aug
python train.py --model ResNet10 --method RelationNetLRP --n_shot 5 --name RelationNet_LRP --train_aug

4. Train meta-learning models with Adversarial Task Augmentation.

Set method to MatchingNet, RelationNet, ProtoNet, GNN or TPN.

python train_ATA.py --model ResNet10 --method GNN --max_lr 80. --T_max 5 --prob 0.5 --n_shot 5 --name GNN_ATA_5s --train_aug
python train_ATA.py --model ResNet10 --method TPN --max_lr 20. --T_max 5 --prob 0.6 --n_shot 5 --name TPN_ATA --train_aug
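The idea behind these flags can be illustrated with a toy example. The sketch below is not the code in `train_ATA.py`: it uses a hand-written quadratic loss in plain Python instead of a meta-learning model, and it omits any regularization term. It assumes that `--max_lr` is the gradient-ascent step size, `--T_max` the number of ascent steps, and `--prob` the probability of replacing the original task with the augmented one. These readings of the flag names are assumptions; check the script for the actual semantics.

```python
import random


def adversarial_task(x, loss_grad, max_lr, t_max):
    """Run t_max steps of gradient ascent on the task inputs to raise the loss."""
    x_adv = list(x)
    for _ in range(t_max):
        grad = loss_grad(x_adv)
        # Ascent: move the inputs in the direction that increases the loss.
        x_adv = [xi + max_lr * gi for xi, gi in zip(x_adv, grad)]
    return x_adv


def maybe_augment(x, loss_grad, max_lr, t_max, prob, rng=random):
    """Return the adversarial task with probability prob, else a copy of x (assumed semantics)."""
    if rng.random() < prob:
        return adversarial_task(x, loss_grad, max_lr, t_max)
    return list(x)


# Toy stand-in for a task loss: squared distance of the inputs to a fixed prototype.
PROTOTYPE = [0.0, 0.0]


def toy_loss(x):
    return sum((xi - pi) ** 2 for xi, pi in zip(x, PROTOTYPE))


def toy_loss_grad(x):
    return [2.0 * (xi - pi) for xi, pi in zip(x, PROTOTYPE)]


task = [1.0, -1.0]
harder = adversarial_task(task, toy_loss_grad, max_lr=0.1, t_max=5)
print(toy_loss(task), toy_loss(harder))
```

In the toy case the augmented inputs always have a higher loss than the originals, which is the sense in which the generated tasks are 'challenging' for the current model.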

To get the results of the iteration goal without the regularization term, with the sample-wise Euclidean distance regularization term, and with the maximum mean discrepancy (MMD) regularization term, run the following three commands, respectively.

python train_NR.py --model ResNet10 --method GNN --max_lr 80. --T_max 5 --n_shot 5 --name GNN_NR_5s --train_aug
python train_Euclid.py --model ResNet10 --method GNN --max_lr 40. --T_max 5 --lamb 1. --n_shot 5 --name GNN_Euclid_5s --train_aug
python train_MMD.py --model ResNet10 --method GNN --max_lr 80. --T_max 5 --lamb 1. --n_shot 5 --name GNN_MMD_5s --train_aug

Evaluation and Fine-tuning

1. Test the trained model on the unseen domains.

  • Specify the target dataset with --dataset: cub, cars, places, plantae, CropDiseases, EuroSAT, ISIC or chestX.
  • Specify the saved model you want to evaluate with --name.
python test.py --dataset cub --n_shot 5 --model ResNet10 --method GNN --name GNN_ATA_5s
python test.py --dataset cub --n_shot 5 --model ResNet10 --method GNN --name GNN_LRP_5s
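To evaluate one saved model on every target domain, the test command can be generated once per dataset. This is a minimal sketch with a hypothetical helper name (`build_test_commands`). It only prints the commands, using the flags and dataset names listed above.

```python
import shlex

# Target datasets accepted by --dataset, as listed above.
TARGET_DATASETS = [
    "cub", "cars", "places", "plantae",
    "CropDiseases", "EuroSAT", "ISIC", "chestX",
]


def build_test_commands(name, method, n_shot, model="ResNet10"):
    """Return one test.py argument list per target dataset for the saved model `name`."""
    return [
        [
            "python", "test.py",
            "--dataset", dataset,
            "--n_shot", str(n_shot),
            "--model", model,
            "--method", method,
            "--name", name,
        ]
        for dataset in TARGET_DATASETS
    ]


# Print the eight evaluation commands for the GNN model trained with ATA.
for cmd in build_test_commands("GNN_ATA_5s", "GNN", 5):
    print(shlex.join(cmd))
```

Each printed line can be pasted into a shell, or passed to `subprocess.run(cmd, check=True)` from the repository root.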

2. Fine-tune with a linear classifier. To get the results of traditional pre-training and fine-tuning, run

python finetune.py --dataset cub --n_shot 5 --finetune_epoch 50 --model ResNet10 --name Pretrain

3. Fine-tune the meta-learning models.

python finetune_ml.py --dataset cub --method GNN --n_shot 5 --finetune_epoch 50 --model ResNet10 --name GNN_ATA_5s

Note
