hila-chefer / robustvit

[NeurIPS 2022] Official PyTorch implementation of Optimizing Relevance Maps of Vision Transformers Improves Robustness. This code finetunes the explainability maps of Vision Transformers to improve robustness.


robustvit's Introduction

Official PyTorch implementation of Optimizing Relevance Maps of Vision Transformers Improves Robustness [NeurIPS 2022]

A HuggingFace Space and a Colab notebook are available for running examples of the finetuned vs. the original models:

Open In Colab | Hugging Face Spaces | Open In YouTube

Updates:

06/05/2022 Added a HuggingFace Spaces demo.

Method overview:

The method applies loss functions directly to the explainability maps so that the model focuses mostly on the foreground of the image.

Using a short finetuning process with only 3 labeled examples from each of 500 classes, our method improves the robustness of ViT models across different model sizes and training techniques, even when data augmentation or regularization is applied.
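The idea can be illustrated with a minimal plain-Python sketch. It assumes a relevance map and a binary foreground mask flattened to equal-length lists with values in [0, 1]. The squared-error form of each term, the way the terms are combined, and the function names are illustrative assumptions, not the repository's implementation; only the four weight names (lambda_seg, lambda_acc, lambda_background, lambda_foreground) come from the finetuning script's arguments.

```python
def relevance_loss(relevance, mask, lambda_background=2.0, lambda_foreground=0.3):
    """Penalize relevance on the background and encourage it on the foreground.

    relevance: per-pixel relevance scores in [0, 1] (flattened).
    mask: per-pixel labels, 1 for foreground and 0 for background.
    The squared-error terms are an assumed form, chosen for illustration.
    """
    if len(relevance) != len(mask):
        raise ValueError("relevance and mask must have the same length")
    n = len(relevance)
    # Background term: relevance on background pixels should be close to 0.
    background = sum((r * (1 - m)) ** 2 for r, m in zip(relevance, mask)) / n
    # Foreground term: relevance on foreground pixels should be close to 1.
    foreground = sum((m * (1 - r)) ** 2 for r, m in zip(relevance, mask)) / n
    return lambda_background * background + lambda_foreground * foreground


def total_loss(relevance, mask, classification_loss, lambda_seg=0.8, lambda_acc=0.2):
    """Weighted sum of the relevance and classification terms (assumed combination)."""
    return lambda_seg * relevance_loss(relevance, mask) + lambda_acc * classification_loss
```

With these assumed terms, a relevance map that matches the mask exactly contributes nothing, while relevance placed on the background is weighted more heavily than relevance missing from the foreground.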

Model zoo

Below are links to download the finetuned base models for ViT AugReg (the model that also appears in timm), vanilla ViT, and DeiT. These are also the weights used in our Colab notebook.

Path      Description
AugReg-B  Finetuned ViT AugReg base model.
ViT-B     Finetuned vanilla ViT base model.
DeiT-B    Finetuned DeiT base model.

Requirements

  • pytorch==1.7.1
  • torchvision==0.8.2
  • timm==0.4.12
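A quick way to confirm that an environment matches these pins is sketched below, using only the standard library (Python 3.8+). The helper names are hypothetical, and the sketch assumes the "pytorch" entry refers to the distribution published on PyPI as "torch".

```python
from importlib import metadata

# Pins from the requirements list. The list writes "pytorch", but the
# distribution name on PyPI is "torch", so that name is used here.
PINNED = {"torch": "1.7.1", "torchvision": "0.8.2", "timm": "0.4.12"}


def installed_version(name):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package whose version differs from its pin to (expected, found)."""
    mismatches = {}
    for name, expected in pinned.items():
        found = lookup(name)
        # Local build tags such as "1.7.1+cu110" still count as a match.
        if found is None or found.split("+")[0] != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for name, (expected, found) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {found}")
```

An empty result means all three packages match the pinned versions.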

Producing Segmentation Data

Using ImageNet-S

To use the ImageNet-S labeled data, download the ImageNetS919 dataset.

Using TokenCut for unsupervised segmentation

  1. Clone the TokenCut project
    git clone https://github.com/YangtaoWANG95/TokenCut.git
    
  2. Install the dependencies: Python 3.7, PyTorch 1.7.1, and CUDA 11.2 (please refer to the official installation instructions). If CUDA 10.2 has been properly installed, run:
    pip install torch==1.7.1 torchvision==0.8.2
    
    Followed by:
    pip install -r TokenCut/requirements.txt
    
    
  3. Use the following command to extract the segmentation maps:
    python tokencut_generate_segmentation.py --img_path <PATH_TO_IMAGE> --out_dir <PATH_TO_OUTPUT_DIRECTORY>    
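Since the command above takes a single image path, processing a whole folder needs a loop. The following is a minimal sketch that only builds the per-image commands; the helper name, the set of image extensions, and the choice to send every output to one directory are assumptions made for illustration.

```python
import sys
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed set of image extensions


def build_tokencut_commands(image_dir, out_dir, script="tokencut_generate_segmentation.py"):
    """Return one command (as an argument list) per image found under image_dir."""
    commands = []
    for path in sorted(Path(image_dir).rglob("*")):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            commands.append([
                sys.executable, script,
                "--img_path", str(path),
                "--out_dir", str(out_dir),
            ])
    return commands
```

Each returned list can be passed to `subprocess.run(cmd, check=True)` from the directory that contains the script.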
    

Finetuning ViT models

To finetune a pretrained ViT model, use the imagenet_finetune.py script. Remember to uncomment the import line for the pretrained model you wish to finetune.

Usage example:

python imagenet_finetune.py --seg_data <PATH_TO_SEGMENTATION_DATA> --data <PATH_TO_IMAGENET> --gpu 0  --lr <LR> --lambda_seg <SEG> --lambda_acc <ACC> --lambda_background <BACK> --lambda_foreground <FORE>

Notes:

  • For all models we use:
    • lambda_seg=0.8
    • lambda_acc=0.2
    • lambda_background=2
    • lambda_foreground=0.3
  • For DeiT models, a temperature is required as follows:
    • temperature=0.65 for DeiT-B
    • temperature=0.55 for DeiT-S
  • The learning rates per model are:
    • ViT-B: 3e-6
    • ViT-L: 9e-7
    • AR-S: 2e-6
    • AR-B: 6e-7
    • AR-L: 9e-7
    • DeiT-S: 1e-6
    • DeiT-B: 8e-7

Baseline methods

Remember to uncomment, in the code, the import line for the pretrained model you wish to finetune.

GradMask

Run the following command:

python imagenet_finetune_gradmask.py --seg_data <PATH_TO_SEGMENTATION_DATA> --data <PATH_TO_IMAGENET> --gpu 0  --lr <LR> --lambda_seg <SEG> --lambda_acc <ACC>

All hyperparameters for the different models can be found in section D of the supplementary material.

Right for the Right Reasons

Run the following command:

python imagenet_finetune_rrr.py --seg_data <PATH_TO_SEGMENTATION_DATA> --data <PATH_TO_IMAGENET> --gpu 0  --lr <LR> --lambda_seg <SEG> --lambda_acc <ACC>

All hyperparameters for the different models can be found in section D of the supplementary material.

Evaluation

Robustness Evaluation

  1. Download the evaluation datasets:

  2. Run the following script to evaluate:

python imagenet_eval_robustness.py --data <PATH_TO_ROBUSTNESS_DATASET> --batch-size <BATCH_SIZE> --evaluate --checkpoint <PATH_TO_FINETUNED_CHECKPOINT>
  • Remember to uncomment, in the code, the import line for the pretrained model you wish to evaluate.
  • To evaluate the original model simply omit the checkpoint parameter.
  • For the INet-v2 dataset add --isV2.
  • For the ObjectNet dataset add --isObjectNet.
  • For the SI datasets add --isSI.
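The flag rules above can be captured in a small helper that builds the evaluation command. This is only a sketch: the helper name and the dataset keys ("INet-v2", "ObjectNet", "SI") are labels chosen here for illustration, while the script name and flags are taken from the command and notes above.

```python
import sys

# Extra flag per dataset, as listed in the notes above. Datasets that are
# not in this table are evaluated without an extra flag.
DATASET_FLAGS = {"INet-v2": "--isV2", "ObjectNet": "--isObjectNet", "SI": "--isSI"}


def build_eval_command(data, batch_size, dataset=None, checkpoint=None):
    """Assemble the imagenet_eval_robustness.py argument list."""
    cmd = [sys.executable, "imagenet_eval_robustness.py",
           "--data", data, "--batch-size", str(batch_size), "--evaluate"]
    # Omitting the checkpoint evaluates the original (non-finetuned) model.
    if checkpoint is not None:
        cmd += ["--checkpoint", checkpoint]
    flag = DATASET_FLAGS.get(dataset)
    if flag:
        cmd.append(flag)
    return cmd
```

For example, `build_eval_command("/data/objectnet", 64, dataset="ObjectNet")` yields a command for the original model on ObjectNet, since no checkpoint is given.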

Segmentation Evaluation

Our segmentation tests are based on the test in the official implementation of Transformer Interpretability Beyond Attention Visualization.

  1. Download the ImageNet segmentation test set.
  2. Run the following script to evaluate:
PYTHONPATH=./:$PYTHONPATH python SegmentationTest/imagenet_seg_eval.py  --imagenet-seg-path <PATH_TO_gtsegs_ijcv.mat>
  • Remember to uncomment, in the code, the import line for the pretrained model you wish to evaluate.

Credits

We would like to sincerely thank the authors for their great work.

Citing our paper

If you make use of our work, please cite our paper:

@inproceedings{
chefer2022optimizing,
title={Optimizing Relevance Maps of Vision Transformers Improves Robustness},
author={Hila Chefer and Idan Schwartz and Lior Wolf},
booktitle={Thirty-Sixth Conference on Neural Information Processing Systems},
year={2022},
url={https://openreview.net/forum?id=upuYKQiyxa_}
}


robustvit's Issues

Cannot reproduce the main table from the paper; the results are always off, especially on ImageNet-A

I'm trying to reproduce part of the main table from the paper.
However, the results always seem to be off, especially the ImageNet-A score, which always lies around 20-22.

Here are the results from 3 seeds, using the same versions of the given dependencies (Python 3.8):

Run names: vitb_robustvit_environment_seed_bckg_2.0_fgd_0.3_num_epochs_50_seed_{1,27,42}

metric                                        seed_1  seed_27  seed_42
val_top1                                      81.69   81.586   81.598
val_top5                                      96.078  96.066   96.088
imagenet-a_top1                               20.787  21.147   20.653
imagenet-a_top5                               43.987  44.227   43.933
imagenet-r_top1                               35.233  35.053   35.26
imagenet-r_top5                               50.2    49.967   49.923
sketch_top1                                   35.788  35.56    35.825
sketch_top5                                   57.684  57.399   57.682
imagenetv2-matched-frequency-format-val_top1  71.17   71.28    71.41
imagenetv2-matched-frequency-format-val_top5  90.49   90.45    90.29
imagenet-style_top1                           17.842  17.656   17.78
imagenet-style_top5                           31.726  31.644   31.634

Here are the results from the same 3 seeds, using different versions of the dependencies (similar to the results above):

Run names: vitb_robustvit_seed_bckg_2.0_fgd_0.3_num_epochs_50_seed_{1,27,42}

metric                                        seed_1  seed_27  seed_42
val_top1                                      81.676  81.63    81.66
val_top5                                      96.13   96.108   96.108
imagenet-a_top1                               18.36   20.56    20.013
imagenet-a_top5                               41.08   43.587   42.84
imagenet-r_top1                               34.863  35.21    35.27
imagenet-r_top5                               49.9    50.053   49.93
sketch_top1                                   35.803  35.827   35.837
sketch_top5                                   57.893  57.661   57.832
imagenetv2-matched-frequency-format-val_top1  71.31   71.35    71.2
imagenetv2-matched-frequency-format-val_top5  90.37   90.35    90.29
imagenet-style_top1                           17.388  17.676   17.708
imagenet-style_top5                           31.048  31.65    31.476
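To compare against a single number reported in a paper, it can help to reduce the per-seed scores to a mean and a spread. The sketch below does this for the ImageNet-A top-1 values copied from the two result sets above; the dictionary labels are descriptive names chosen here, not run names.

```python
from statistics import mean, stdev

# ImageNet-A top-1 accuracy per seed (1, 27, 42), copied from the results above.
IMAGENET_A_TOP1 = {
    "same dependency versions": [20.787, 21.147, 20.653],
    "different dependency versions": [18.36, 20.56, 20.013],
}


def summarize(scores):
    """Return (mean, sample standard deviation) of a list of scores."""
    return mean(scores), stdev(scores)


for name, scores in IMAGENET_A_TOP1.items():
    avg, spread = summarize(scores)
    print(f"{name}: mean={avg:.2f}, std={spread:.2f}")
```

If the gap to the published score is much larger than the spread across seeds, seed noise alone is unlikely to explain it.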

Here are the settings I used:
{
"data": "Dataset/CV/imagenet/train",
"seg_data": "work/data/general/imagenet-s/ImageNetS919/train-semi-segmentation",
"workers": 4,
"epochs": 50,
"start_epoch": 0,
"batch_size": 8,
"lr": 3e-06,
"momentum": 0.9,
"weight_decay": 0.0001,
"print_freq": 10,
"resume": "",
"evaluate": false,
"pretrained": false,
"world_size": -1,
"rank": -1,
"dist_url": "tcp://224.66.41.62:23456",
"dist_backend": "nccl",
"gpu": 1,
"save_interval": 20,
"num_samples": 3,
"multiprocessing_distributed": false,
"lambda_seg": 0.8,
"lambda_acc": 0.2,
"experiment_folder": "experiment/vitb_robustvit_environment_seed/lr_3e-06_seg_0.8_acc_0.2_bckg_2.0_fgd_0.3_num_epochs_50_seed_1",
"dilation": 0,
"lambda_background": 2.0,
"lambda_foreground": 0.3,
"num_classes": 500,
"temperature": 1.0,
"class_seed": 1, # or 27, 42
"folder_name": "vitb_robustvit_environment_seed"
}

I used model_best.pth.tar for the evaluation. Is there anything I should do or try to bring the results closer to those in the paper?
