haoqing-wang / cdfsl-ata

[IJCAI 2021 & AIJ 2023] Cross-Domain Few-Shot Classification via Adversarial Task Augmentation

Topics: few-shot-learning, meta-learning, adversarial-learning, cross-domain

cdfsl-ata's Introduction

Cross-Domain Few-Shot Classification via Adversarial Task Augmentation

PyTorch implementation of

IJCAI 2021:
Cross-Domain Few-Shot Classification via Adversarial Task Augmentation
Haoqing Wang, Zhi-Hong Deng

Artificial Intelligence 2023:
Towards well-generalizing meta-learning via adversarial task augmentation
Haoqing Wang, Huiyu Mai, Yuhang Gong, Zhi-Hong Deng

Abstract

Few-shot classification aims to recognize unseen classes with few labeled samples from each class. Many meta-learning models for few-shot classification elaborately design various task-shared inductive biases (meta-knowledge) to solve such tasks, and achieve impressive performance. However, when there is a domain shift between the training tasks and the test tasks, the obtained inductive bias fails to generalize across domains, which degrades the performance of the meta-learning models. In this work, we aim to improve the robustness of the inductive bias through task augmentation. Concretely, we consider the worst-case problem around the source task distribution, and propose the adversarial task augmentation method, which generates 'challenging' tasks adapted to the current inductive bias. Our method can be used as a simple plug-and-play module for various meta-learning models, and improves their cross-domain generalization capability. We conduct extensive experiments under the cross-domain setting, using nine few-shot classification datasets: mini-ImageNet, CUB, Cars, Places, Plantae, CropDiseases, EuroSAT, ISIC and ChestX. Experimental results show that our method can effectively improve the few-shot classification performance of the meta-learning models under domain shift, and outperforms the existing works.
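
For intuition, the worst-case problem is approximated by a few gradient-ascent steps that perturb a sampled task to raise the meta-learner's loss; the model is then meta-updated on the perturbed, 'challenging' task. Below is a minimal, hypothetical sketch of that inner maximization. meta_model.episode_loss and the input-space perturbation are illustrative assumptions rather than the code in train_ATA.py, which additionally uses a distance regularizer (see the variants under Training); judging from the flag names, --max_lr and --T_max control the ascent step size and step count, and --prob how often the augmented task is used.

import torch

def augment_task(meta_model, support_x, support_y, query_x, query_y,
                 max_lr=80.0, T_max=5):
    # Hypothetical inner maximization: gradient-ascend the task inputs so
    # the episode becomes 'challenging' for the current inductive bias.
    adv_s, adv_q = support_x.clone(), query_x.clone()
    for _ in range(T_max):
        adv_s = adv_s.detach().requires_grad_(True)
        adv_q = adv_q.detach().requires_grad_(True)
        loss = meta_model.episode_loss(adv_s, support_y, adv_q, query_y)
        g_s, g_q = torch.autograd.grad(loss, [adv_s, adv_q])
        # Ascend the loss surface; a regularizer keeping the task near the
        # source distribution would be subtracted from `loss` above.
        adv_s = adv_s + max_lr * g_s / (g_s.norm() + 1e-12)
        adv_q = adv_q + max_lr * g_q / (g_q.norm() + 1e-12)
    return adv_s.detach(), adv_q.detach()

The meta-learner is then trained on the returned task exactly as on an ordinary episode.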

Citation

If you use this code for your research, please cite our paper:

@inproceedings{ijcai2021-149,
  title     = {Cross-Domain Few-Shot Classification via Adversarial Task Augmentation},
  author    = {Wang, Haoqing and Deng, Zhi-Hong},
  booktitle = {Proceedings of the Thirtieth International Joint Conference on
               Artificial Intelligence, {IJCAI-21}},
  publisher = {International Joint Conferences on Artificial Intelligence Organization},
  editor    = {Zhi-Hua Zhou},
  pages     = {1075--1081},
  year      = {2021},
  month     = {8},
  note      = {Main Track},
  doi       = {10.24963/ijcai.2021/149},
  url       = {https://doi.org/10.24963/ijcai.2021/149},
}

and

@article{wang2023towards,
  title={Towards well-generalizing meta-learning via adversarial task augmentation},
  author={Wang, Haoqing and Mai, Huiyu and Gong, Yuhang and Deng, Zhi-Hong},
  journal={Artificial Intelligence},
  pages={103875},
  year={2023},
  publisher={Elsevier}
}

Dependencies

Datasets

We use miniImageNet as the single source domain, and use CUB, Cars, Places, Plantae, CropDiseases, EuroSAT, ISIC and ChestX as the target domains.

For miniImageNet, CUB, Cars, Places and Plantae, download and process them separately with the following commands.

  • Set DATASET_NAME to: miniImagenet, cub, cars, places or plantae.
cd filelists
python process.py DATASET_NAME
cd ..

For CropDiseases, EuroSAT, ISIC and ChestX, download the datasets from their original sources and put them under their respective paths, e.g., 'filelists/CropDiseases', 'filelists/EuroSAT', 'filelists/ISIC', 'filelists/chestX', then process them with the following commands.

  • Set DATASET_NAME to: CropDiseases, EuroSAT, ISIC or chestX.
cd filelists/DATASET_NAME
python write_DATASET_NAME_filelist.py
cd ..
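
The processing scripts write filelist JSONs in the CloserLookFewShot format. The structure below is an assumption based on that upstream code, so check it against the generated files if you ever need to build a filelist by hand:

import json

# Assumed CloserLookFewShot filelist format: class names, image paths,
# and one integer label per image (verify against the generated files).
filelist = {
    "label_names": ["class_a", "class_b"],
    "image_names": ["filelists/miniImagenet/images/class_a/0001.jpg",
                    "filelists/miniImagenet/images/class_b/0001.jpg"],
    "image_labels": [0, 1],
}
with open("filelists/miniImagenet/base.json", "w") as f:
    json.dump(filelist, f)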

Pre-training

We adopt baseline pre-training from CloserLookFewShot for all models.

  • Download the pre-trained feature encoders from CloserLookFewShot.
  • Or pre-train your own feature encoder:
python pretrain.py --dataset miniImagenet --name Pretrain --train_aug
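
The Baseline pre-training in CloserLookFewShot trains the encoder with a plain linear classification head and cross-entropy over all miniImageNet base classes. A minimal sketch of one pre-training epoch, with encoder, classifier, loader and optimizer as placeholders rather than the repository's exact objects:

import torch.nn as nn

def pretrain_epoch(encoder, classifier, loader, optimizer, device="cuda"):
    # Standard supervised pre-training: cross-entropy over all base classes.
    criterion = nn.CrossEntropyLoss()
    encoder.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = classifier(encoder(images))  # linear head on top of features
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()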

Training

1. Train meta-learning models

Set method to MatchingNet, RelationNet, ProtoNet, GNN or TPN. For the MatchingNet, RelationNet and TPN models, we set the training shot to 5 for both 1-shot and 5-shot evaluation.

python train.py --model ResNet10 --method GNN --n_shot 5 --name GNN_5s --train_aug
python train.py --model ResNet10 --method TPN --n_shot 5 --name TPN --train_aug
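
These methods all minimize an episodic loss on sampled N-way K-shot tasks. As a reminder of what that objective looks like, here is a generic prototypical-network episode loss (ProtoNet is one of the supported methods), sketched independently of the repository's implementation, with encoder and the tensors as placeholders:

import torch
import torch.nn.functional as F

def protonet_episode_loss(encoder, support_x, support_y, query_x, query_y, n_way):
    # Generic ProtoNet episode: classify queries by (negative) distance
    # to class prototypes, i.e. the mean support embedding of each class.
    z_s, z_q = encoder(support_x), encoder(query_x)   # (N*K, d), (Q, d)
    protos = torch.stack([z_s[support_y == c].mean(0) for c in range(n_way)])
    logits = -torch.cdist(z_q, protos)                # nearer prototype => higher score
    return F.cross_entropy(logits, query_y)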

2. Train meta-learning models with feature-wise transformations.

Set method to MatchingNet, RelationNet, ProtoNet, GNN or TPN.

python train_FT.py --model ResNet10 --method GNN --n_shot 5 --name GNN_FWT_5s --train_aug
python train_FT.py --model ResNet10 --method TPN --n_shot 5 --name TPN_FWT --train_aug

3. Train meta-learning models with explanation-guided learning (LRP).

Set method to RelationNetLRP or GNNLRP.

python train.py --model ResNet10 --method GNNLRP --n_shot 5 --name GNN_LRP_5s --train_aug
python train.py --model ResNet10 --method RelationNetLRP --n_shot 5 --name RelationNet_LRP --train_aug

4. Train meta-learning models with Adversarial Task Augmentation.

Set method to MatchingNet, RelationNet, ProtoNet, GNN or TPN.

python train_ATA.py --model ResNet10 --method GNN --max_lr 80. --T_max 5 --prob 0.5 --n_shot 5 --name GNN_ATA_5s --train_aug
python train_ATA.py --model ResNet10 --method TPN --max_lr 20. --T_max 5 --prob 0.6 --n_shot 5 --name TPN_ATA --train_aug

To get the results of the inner maximization objective without the regularization term, with the sample-wise Euclidean distance regularization term, or with the maximum mean discrepancy (MMD) regularization term, run the commands below (an MMD sketch follows them):

python train_NR.py --model ResNet10 --method GNN --max_lr 80. --T_max 5 --n_shot 5 --name GNN_NR_5s --train_aug
python train_Euclid.py --model ResNet10 --method GNN --max_lr 40. --T_max 5 --lamb 1. --n_shot 5 --name GNN_Euclid_5s --train_aug
python train_MMD.py --model ResNet10 --method GNN --max_lr 80. --T_max 5 --lamb 1. --n_shot 5 --name GNN_MMD_5s --train_aug
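
For reference, the MMD variant penalizes the distribution gap between the original and the augmented task features, so the inner maximization cannot drift arbitrarily far from the source task distribution. A minimal RBF-kernel MMD^2 estimate between two feature batches (a standard formulation; the kernel and bandwidth choices in train_MMD.py may differ):

import torch

def mmd_rbf(x, y, sigma=1.0):
    # Biased MMD^2 estimate with an RBF kernel between feature batches
    # x of shape (n, d) and y of shape (m, d).
    def kernel(a, b):
        return torch.exp(-torch.cdist(a, b).pow(2) / (2 * sigma ** 2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()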

Evaluation and Fine-tuning

1. Test the trained model on the unseen domains.

  • Specify the target dataset with --dataset: cub, cars, places, plantae, CropDiseases, EuroSAT, ISIC or chestX.
  • Specify the saved model you want to evaluate with --name.
python test.py --dataset cub --n_shot 5 --model ResNet10 --method GNN --name GNN_ATA_5s
python test.py --dataset cub --n_shot 5 --model ResNet10 --method GNN --name GNN_LRP_5s

2. Fine-tune with a linear classifier. To get the results of traditional pre-training and fine-tuning, run

python finetune.py --dataset cub --n_shot 5 --finetune_epoch 50 --model ResNet10 --name Pretrain

3. Fine-tune the meta-learning models.

python finetune_ml.py --dataset cub --method GNN --n_shot 5 --finetune_epoch 50 --model ResNet10 --name GNN_ATA_5s

Note


cdfsl-ata's Issues

Plantae Dataset NOT Available?

Hello, I tried to get the Plantae dataset via the link you provided in process.py, but the website it points to responded with "404 Not Found".
Could you please provide a new link? Thanks!

More explanation of Lemma 1 and Lemma 2

Dear authors:
I could not find supplementary material for the paper "Cross-Domain Few-Shot Classification via Adversarial Task Augmentation".
The conclusions in Lemma 1 and Lemma 2 seem a little hard to deduce by myself, and the given references contain so many distractors that I can't grasp the key content.
Could you please give me more explanation or some proof material for Lemma 1 and Lemma 2?

Request for pretrained model weights.

Hello, your work has inspired me a lot! I hope to study and re-implement your results, but I found that using different pre-trained models (specifically the "399.tar" in your code) has a great impact on the results. I tried a variety of pre-trained weights but did not get the results reported in your paper. Could you please publish the pre-trained model weights you used for the different methods? Thanks a lot!

About results in Table 1

Hi, thanks for sharing the code. I have a question about which model exactly you used when testing on the cross-domain few-shot datasets in Table 1. For example, did you use the model that tested best on the mini-ImageNet test split, or do you have another evaluation for choosing the model that generalizes best across all these datasets? Looking forward to your response, thanks.

Dataset Download Error

Hello, the dataset download links for cars, miniImagenet and plantae in filelists/process.py are broken. Can you provide working links? Many thanks.

The results of GNN method

Nice work! However, I cannot reproduce the results in Table 1 using the GNN method. I notice that your finetune.py also covers the augmentation method. I fine-tuned the GNN checkpoint trained in step 1 and got very low accuracy using the finetune.py from https://github.com/IBM/cdfsl-benchmark, so can you give me some details on how to get the GNN results in Table 1? Also, I noticed that your code cannot handle more shots (5-way 20-shot or 5-way 50-shot), since this may cost more CUDA memory; have you solved this issue?
