cross_modal_adaptation's Introduction

Cross-Modal Adaptation with Multimodal Models

This repository contains the code for the CVPR 2023 paper Multimodality Helps Unimodality: Cross-Modal Few-Shot Learning with Multimodal Models. It includes vision-language adaptation on 11 target image classification datasets and experiments on the ImageNet-ESC benchmark for audiovisual few-shot learning.

Motivation Figure

Environment Configuration

We recommend installing the environment with conda and pip. Create a new environment with python>=3.9, for example:

conda create -n cross_modal python=3.9

Next, install PyTorch following the official site, for example:

conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

Next, run pip install -r requirements.txt in this repo to install the additional packages required by CLIP.

Dataset Installation

Follow DATASETS.md to install the downstream datasets. We use the CoOp data splits (including the few-shot splits for seeds 1-3, except for ImageNet) to ensure a fair comparison.

Model Training

Method Figure

Path Configuration

Modify the dataset and result paths in engine/config/default.py; for example, set DATA_DIR to the directory where you installed all the datasets. By default, everything is saved under the current folder.

Sample few-shot train/val split

We already provide few-shot train/val splits for seeds (1, 2, 3) and shots (1, 2, 4, 8, 16) in indices/, generated as if by the original CoOp codebase (except for ImageNet, for which we sampled our own split). If you only intend to follow CoOp's protocol, you can proceed to the next step.

If you want to generate more splits with different shots and seeds, please refer to few_shot_split.py. For example, to generate a 1-shot train/val split for imagenet with seed 6, run:

python few_shot_split.py --dataset imagenet --train-shot 1 --seed 6
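The idea behind a seeded few-shot split can be sketched as follows. This is a minimal illustration, not the logic of few_shot_split.py: the helper name sample_few_shot_split and the val_shots parameter are hypothetical, and we assume the split draws disjoint per-class train and val samples from a seeded random generator.

```python
import random
from collections import defaultdict


def sample_few_shot_split(labels, train_shots, val_shots, seed):
    """Sample disjoint per-class train/val indices (hypothetical helper).

    labels: list of integer class labels, one per image in the full train set.
    Returns (train_indices, val_indices), each sorted.
    """
    rng = random.Random(seed)  # local RNG: the split depends only on `seed`
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train, val = [], []
    for label in sorted(by_class):  # fixed class order keeps results reproducible
        candidates = list(by_class[label])
        rng.shuffle(candidates)
        train.extend(candidates[:train_shots])
        val.extend(candidates[train_shots:train_shots + val_shots])
    return sorted(train), sorted(val)


# Toy example: 3 classes with 5 images each, 1-shot train, 2-shot val, seed 6.
labels = [0] * 5 + [1] * 5 + [2] * 5
train_idx, val_idx = sample_few_shot_split(labels, train_shots=1, val_shots=2, seed=6)
```

Using a dedicated random.Random(seed) instead of the global generator means the same seed always yields the same split, regardless of other code that consumes random numbers.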

Feature Extraction

For all the linear/partial/adapter experiments, we pre-extract the features to speed up training. You can use features.py to pre-extract image and text features from a frozen CLIP model. For example, run the script below to pre-extract last-layer features for imagenet-16-shot with the RN50 backbone. Note that these features are not L2-normalized yet:

python features.py --dataset imagenet --train-shot 16 --seed 1 --clip-encoder RN50 --image-layer-idx 0 --text-augmentation hand_crafted --image-augmentation none --image-views 0
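Because the saved features are not yet L2-normalized, they need normalizing before they are compared by cosine similarity. A minimal NumPy sketch, assuming features are stored as (num_samples, dim) arrays; the function name l2_normalize and the toy values are illustrative only:

```python
import numpy as np


def l2_normalize(features, eps=1e-12):
    """Scale each row (one feature vector) to unit L2 norm."""
    norms = np.linalg.norm(features, axis=-1, keepdims=True)
    return features / np.maximum(norms, eps)  # eps guards against all-zero rows


# Toy "features": 2 images and 3 classes, 2-dimensional for readability.
image_features = np.array([[3.0, 4.0], [0.0, 2.0]])
text_features = np.array([[1.0, 0.0], [0.0, 5.0], [1.0, 1.0]])

img = l2_normalize(image_features)
txt = l2_normalize(text_features)
cosine = img @ txt.T  # (2, 3) cosine similarities, each in [-1, 1]
```

After normalization, a plain dot product equals cosine similarity, which is how CLIP compares image and text embeddings.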

To reproduce the experiments in the main paper (with the flipped view and hand-crafted template), run the bash script below to extract features for all 11 datasets and 3 seeds. (Tip: you can also run the scripts in features.sh in parallel to speed this up.)

bash features.sh

Few-Shot Training

To perform cross-modal or uni-modal training, please refer to train.py. For example, if you want to run cross-modal adaptation for imagenet-16-shot, you can run:

python train.py --modality cross_modal --classifier_head linear --classifier_init zeroshot --logit 4.60517 --hyperparams linear --dataset imagenet --train-shot 16 --seed 1 --clip-encoder RN50 --image-layer-idx 0 --text-augmentation hand_crafted --image-augmentation flip --image-views 1
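The value passed to --logit matches ln(100), the usual CLIP logit scale. The sketch below assumes, without confirmation from the code, that --logit stores the log of the scale and that --classifier_init zeroshot fills each row of the linear head with the L2-normalized text embedding of its class. The helpers zeroshot_init_head and head_logits are hypothetical:

```python
import math

import numpy as np

# Assumption: --logit stores the log of the logit scale; exp(4.60517) ≈ 100,
# i.e. a softmax temperature of 0.01.
logit_scale = math.exp(4.60517)


def zeroshot_init_head(class_text_features):
    """Hypothetical: rows of the linear head = L2-normalized class text embeddings."""
    return class_text_features / np.linalg.norm(class_text_features, axis=-1, keepdims=True)


def head_logits(features, weight, scale):
    """Scaled cosine similarity between normalized inputs and head rows."""
    features = features / np.linalg.norm(features, axis=-1, keepdims=True)
    return scale * features @ weight.T


# Toy check: 3 classes with orthogonal 4-d text embeddings.
weight = zeroshot_init_head(np.eye(3, 4))
logits = head_logits(np.array([[0.0, 2.0, 0.0, 0.0]]), weight, logit_scale)
```

With this initialization, the untrained head already reproduces zero-shot CLIP predictions, so training starts from the zero-shot classifier rather than from random weights.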

To reproduce the numbers in the main paper, run linear_probe.sh, partial_finetuning.sh, and adapter.sh. To speed up the experiments, you can run the scripts in parallel if you have multiple GPUs. To check all the supported argparse arguments, please see this file.
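One way to run jobs in parallel across GPUs is to pin each process to one device with the CUDA_VISIBLE_DEVICES environment variable. A minimal sketch, assuming round-robin GPU assignment; assign_gpus and launch are hypothetical helpers, not part of this repository, and the commands reuse the train.py flags shown above with seeds 1-3:

```python
import os
import shlex
import subprocess


def assign_gpus(commands, num_gpus):
    """Pair each command with a GPU id in round-robin order."""
    return [(cmd, i % num_gpus) for i, cmd in enumerate(commands)]


def launch(commands, num_gpus):
    """Start all commands at once, each restricted to one GPU; wait for all.

    With more commands than GPUs, several jobs share a GPU.
    """
    procs = []
    for cmd, gpu in assign_gpus(commands, num_gpus):
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
        procs.append(subprocess.Popen(shlex.split(cmd), env=env))
    return [p.wait() for p in procs]  # exit codes, in launch order


base = ("python train.py --modality cross_modal --classifier_head linear "
        "--classifier_init zeroshot --logit 4.60517 --hyperparams linear "
        "--dataset imagenet --train-shot 16 --clip-encoder RN50 "
        "--image-layer-idx 0 --text-augmentation hand_crafted "
        "--image-augmentation flip --image-views 1")
commands = [f"{base} --seed {seed}" for seed in (1, 2, 3)]
# launch(commands, num_gpus=2)  # run from the repository root
```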

Evaluation

To perform hyperparameter search based on few-shot validation set performance, we provide eval.py. For example, to collect the results of cross-modal linear probing:

python eval.py --mode linear --modality cross_modal --classifier_init zeroshot --clip-encoder RN50 --text-augmentation hand_crafted --image-augmentation flip --image-views 1
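Conceptually, validation-based selection picks the configuration with the best few-shot validation accuracy and reports its test accuracy. A minimal sketch, assuming results are available as a dict of (val_acc, test_acc) pairs; select_by_validation is hypothetical, the config tags borrow the naming seen in the training log quoted in the issues below, and the numbers are placeholders, not results from the paper:

```python
def select_by_validation(results):
    """Pick the config with the highest validation accuracy.

    results: dict mapping a config tag to (val_acc, test_acc).
    Returns (best_tag, test_acc_of_best). Ties go to the first tag seen.
    """
    best_tag = max(results, key=lambda tag: results[tag][0])
    return best_tag, results[best_tag][1]


# Placeholder numbers for illustration only (not results from the paper).
results = {
    "optim_adamw-lr_0.001-wd_0.0-bs_8-iters_12800": (70.0, 68.0),
    "optim_adamw-lr_0.0001-wd_0.0-bs_8-iters_12800": (72.0, 69.5),
}
best_tag, test_acc = select_by_validation(results)
```

Selecting on validation accuracy, never on test accuracy, keeps the reported test number an unbiased estimate.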

Average over 11 datasets

To compute the average over the 11 datasets (for example, for the script above), run the following script to generate a csv file:

python average.py --name all_RN50_linear_hand_crafted_flip_1_cross_modal_text_wiseft_False
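The averaging step amounts to a mean over per-dataset accuracies written out as CSV. A minimal sketch with the standard csv module; write_average_csv and the column layout are assumptions, not the format produced by average.py, and the accuracies are placeholders:

```python
import csv
import io
from statistics import mean


def write_average_csv(per_dataset_acc, out):
    """Write one row per dataset plus a final 'average' row; return the mean."""
    writer = csv.writer(out, lineterminator="\n")
    writer.writerow(["dataset", "accuracy"])
    for name, acc in per_dataset_acc.items():
        writer.writerow([name, f"{acc:.2f}"])
    avg = mean(per_dataset_acc.values())
    writer.writerow(["average", f"{avg:.2f}"])
    return avg


# Placeholder accuracies, not numbers from the paper.
buffer = io.StringIO()
avg = write_average_csv({"dataset_a": 50.0, "dataset_b": 70.0}, buffer)
```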

Test-time robustness to domain shift (ImageNet)

To reproduce the domain shift experiments in the paper, please run domain_shift.py. All the argparse arguments follow those of train.py:

python domain_shift.py --modality cross_modal --classifier_head linear --classifier_init zeroshot --logit 4.60517 --hyperparams linear --dataset imagenet --train-shot 16 --clip-encoder RN50 --image-layer-idx 0 --text-augmentation hand_crafted --image-augmentation none --seed 1

After training, to evaluate over 3 seeds, use eval_domain_shift.py:

python eval_domain_shift.py --mode linear --modality cross_modal --classifier_init zeroshot --clip-encoder RN50 --text-augmentation hand_crafted --image-augmentation none

You can get the Cross-Modal WiSE-FT result by enabling the wise_ft flag:

python eval_domain_shift.py --mode linear --modality cross_modal --classifier_init zeroshot --clip-encoder RN50 --text-augmentation hand_crafted --image-augmentation none --wise_ft True
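WiSE-FT ensembles a zero-shot model and a fine-tuned model by linearly interpolating their weights. A minimal sketch applied to the classifier head only, assuming both heads share the same shape; the function name wise_ft and the default alpha=0.5 are assumptions, not values confirmed from eval_domain_shift.py:

```python
import numpy as np


def wise_ft(zeroshot_weight, finetuned_weight, alpha=0.5):
    """Interpolate classifier weights: alpha=0 -> zero-shot, alpha=1 -> fine-tuned."""
    if zeroshot_weight.shape != finetuned_weight.shape:
        raise ValueError("weights must have the same shape")
    return (1.0 - alpha) * zeroshot_weight + alpha * finetuned_weight


# Toy 1-class, 2-d heads.
zs = np.array([[1.0, 0.0]])
ft = np.array([[0.0, 1.0]])
mixed = wise_ft(zs, ft, alpha=0.5)
```

Intermediate alpha values trade some in-distribution accuracy of the fine-tuned head for the robustness of the zero-shot head under distribution shift.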

ImageNet-ESC Experiments

AudioCLIP feature extraction for ESC50

We follow the instructions in the official AudioCLIP codebase to extract the features. We noticed that the AudioCLIP head does not produce good audio features in eval() mode, so we extract the features in train() mode with a batch size of 10. The ESC-50 dataset recommends 5-fold cross-validation because the audio samples can be correlated within each of the 5 folds, so we follow this practice and provide 5 train/test splits of ESC-50. For each split, one fold is used as the training set (400 audio samples per fold), and the remaining 4 folds are used for evaluation.
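The split protocol above can be sketched in a few lines. This assumes the fold id (1-5) of every clip is already loaded into a list, for example from the fold column of ESC-50's metadata; esc50_splits is a hypothetical helper, not code from audio_features.py:

```python
def esc50_splits(fold_of_clip, num_folds=5):
    """For each fold k, train on fold k and evaluate on the other folds."""
    splits = []
    for k in range(1, num_folds + 1):
        train = [i for i, f in enumerate(fold_of_clip) if f == k]
        test = [i for i, f in enumerate(fold_of_clip) if f != k]
        splits.append((train, test))
    return splits


# 5 folds x 400 clips = 2000 clips.
fold_of_clip = [fold for fold in range(1, 6) for _ in range(400)]
splits = esc50_splits(fold_of_clip)
```

Note that this inverts standard 5-fold cross-validation (one fold for training instead of four), which suits the few-shot setting: each split trains on 400 clips and evaluates on the other 1600.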

To extract features into the ESC-50 folder (assuming you followed the dataset installation instructions), run the script below. Before running it, modify the PATH variable if you installed ESC-50 somewhere else.

cd audioclip/
python audio_features.py

Training on ImageNet-ESC

To reproduce all the experiments in the paper with 1/2/4-shot classification on both images and audio, run:

python imagenet_esc.py

Citation

If you use this code in your research, please kindly cite the following paper:

@misc{lin2023crossmodal,
  title={Multimodality Helps Unimodality: Cross-Modal Few-Shot Learning with Multimodal Models},
  author={Lin, Zhiqiu and Yu, Samuel and Kuang, Zhiyi and Pathak, Deepak and Ramanan, Deva},
  year={2023},
  eprint={2301.06267},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}

cross_modal_adaptation's People

Contributors

linzhiqiu, samuelyu2002

cross_modal_adaptation's Issues

ResNet50

Where can I find the raw CLIP ResNet-50 weights?

problem when running train.py

When running train.py, an error occurs. I need some help, please:

python train.py --modality cross_modal --classifier_head linear --classifier_init zeroshot --logit 4.60517 --hyperparams linear --dataset safety --train-shot 16 --seed 1 --clip-encoder RN50 --image-layer-idx 0 --text-augmentation hand_crafted --image-augmentation flip --image-views 1
/root/autodl-tmp/cross_modal_adaptation-main/engine/clip/clip.py:23: UserWarning: PyTorch version 1.7.1 or higher is recommended
  warnings.warn("PyTorch version 1.7.1 or higher is recommended")
Setting fixed seed: 1
Valid batch sizes: 2/2
Starting: optim_adamw-lr_0.001-wd_0.0-bs_8-iters_12800 1/12
Building text dataset per class...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:00<00:00, 28181.98it/s]
Traceback (most recent call last):
  File "/root/autodl-tmp/cross_modal_adaptation-main/train.py", line 564, in <module>
    main(args)
  File "/root/autodl-tmp/cross_modal_adaptation-main/train.py", line 451, in main
    result_dict = train(
  File "/root/autodl-tmp/cross_modal_adaptation-main/train.py", line 162, in train
    image, image_label = next(image_loader_iter)
  File "/root/miniconda3/envs/cross_modal/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
    data = self._next_data()
  File "/root/miniconda3/envs/cross_modal/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 694, in _next_data
    data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
  File "/root/miniconda3/envs/cross_modal/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 65, in pin_memory
    return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
  File "/root/miniconda3/envs/cross_modal/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 65, in <listcomp>
    return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
  File "/root/miniconda3/envs/cross_modal/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 50, in pin_memory
    return data.pin_memory(device)
RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned

About Cross-modal ensembles

Hi, I have a question about cross-modal ensembles. I read the paper, but I couldn't find the cross-modal ensemble part in this repository, so I guess it's not used in this code, also because the code uses cross-entropy loss rather than softmax loss. Do you know anything about this? Thank you in advance!

Where is the path to the "text_encoder" defined?

When running domain_shift.py, I get the following error:
Traceback (most recent call last):
  File "\cross_modal_adaptation-main\domain_shift.py", line 425, in <module>
    main(args)
  File "\cross_modal_adaptation-main\domain_shift.py", line 215, in main
    head, num_classes, in_features = make_classifier_head(
TypeError: make_classifier_head() missing 1 required positional argument: 'text_encoder'
Could you tell me exactly where the text_encoder path is defined, so that I can modify it and avoid the error?

Question about the applicability of cross-modal adaptation in normal classification settings

Hi! Thanks for the interesting insights and impressive work.
I understand that your paper focuses on few-shot learning experiments, where you use different modalities as additional few-shot samples.
I wonder if your method can also improve the performance of normal classification tasks, where you have enough labeled data for each modality. Do you think cross-modal information can still help the model learn better features and generalize better?
I would appreciate it if you could share your insights or point me to some relevant references. Thank you very much for your time and attention.

About image encoder.

Great work!
Since my image data has a very large number of bands, I have to use my own image encoder, while the text encoder is still the one from CLIP. But the results were only so-so. Why? Do you have any good suggestions?
Thank you very much!

Validation set size

Hi,

Thank you for the very interesting repository, and congrats on the paper!
I have a question regarding the number of validation images used for hyperparameter selection. How many images per class do you use?
Thank you in advance,

Best regards,

Ask for inference code

Great job on CLIP fine-tuning! Is there end-to-end inference code available to directly use the fine-tuned model on downstream tasks?

About repeated encoding

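Hello, I read your code carefully. I found that in feature.py you extract features for images and text and save them to files. But in the train file, you turn these features into a tensor dataset. The loader built from this tensor dataset is then passed to the train function, where you again use the image encoder and text encoder to encode these features and then perform gradient updates. Doesn't this mean the image data is encoded twice?
I'm not sure whether my reading of the code is correct. If it really works this way, why is it done like this?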

About cross-modal inputs in the same batch

    if image_feature is not None and text_feature is not None:
        feature = torch.cat([image_feature, text_feature], dim=0)
        label = torch.cat([image_label, text_label], dim=0)
    elif image_feature is not None:
        feature = image_feature
        label = image_label
    elif text_feature is not None:
        feature = text_feature
        label = text_label
    else:
        raise ValueError("Both image_feature and text_feature are None")

So the actual batch size is twice the one set in the config. Is this understanding correct?
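The effect of the concatenation quoted above can be checked with a NumPy mirror of the same branches (np.concatenate along axis 0 behaves like torch.cat with dim=0). The function name combine_modalities and the sizes (8 samples per loader, 1024-d features) are assumptions; whether the text loader actually uses the config batch size is exactly what the question asks and is not confirmed here. The sketch only shows that when each loader yields bs rows, the combined batch has 2 * bs rows:

```python
import numpy as np


def combine_modalities(image_feature, image_label, text_feature, text_label):
    """NumPy mirror of the torch.cat branches quoted above."""
    if image_feature is not None and text_feature is not None:
        feature = np.concatenate([image_feature, text_feature], axis=0)
        label = np.concatenate([image_label, text_label], axis=0)
    elif image_feature is not None:
        feature, label = image_feature, image_label
    elif text_feature is not None:
        feature, label = text_feature, text_label
    else:
        raise ValueError("Both image_feature and text_feature are None")
    return feature, label


# Assumed sizes: 8 samples from each loader, 1024-d features.
bs, dim = 8, 1024
feature, label = combine_modalities(
    np.zeros((bs, dim)), np.arange(bs), np.ones((bs, dim)), np.arange(bs)
)
```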

Where do the results cited in the paper comparing to SOTA methods come from?

I read your paper carefully. For the methods you compare against (Table 1 and Table 10 in your paper), such as CoOp, I could not find the cited numbers in their original papers. The CoOp paper only provides line plots for each of the 11 datasets and a line plot of the average over these 11 datasets. Can you tell me how you obtained these specific values? Thank you so much.

Problems with training on ImageNet-ESC

After running audio_features.py successfully, features.pt is generated, but when I later run imagenet_esc.py, it reports that the required classname.pth cannot be found. How can I generate this file?
