Robustness Verification for Contrastive Learning

This repository contains the code for the ICML 2022 long presentation "Robustness Verification for Contrastive Learning" by Zekai Wang and Weiwei Liu.

Dependencies

You'll need a working Python environment to run the code. The code is based on:

  • python 3.7
  • torch 1.6.0
  • torchvision 0.7
  • CUDA 10.2
  • numpy == 1.19.2
  • pandas == 1.2.4
  • diffdist == 0.1
  • appdirs == 1.4.4
  • oslo.config == 8.7.0

The recommended way to set up your environment is through the Anaconda Python distribution which provides the conda package manager.

The required dependencies are specified in the file requirements.txt.

Run the following command in the repository folder (where requirements.txt is located) to create a separate environment and install all required dependencies in it:

conda create -n env_name python=3.7   # create new environment
source activate env_name
pip install -r requirements.txt
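
After installation, it can help to confirm that the environment matches the versions listed above. The following is a minimal sketch, not part of the repository: the names PINNED, matches_pin, find_mismatches and installed_version are hypothetical, and the pins are copied from the dependency list in this README rather than from requirements.txt itself.

```python
# Hypothetical helper: compare installed package versions with the versions
# listed in this README. Not part of the repository.
PINNED = {
    "torch": "1.6.0",
    "torchvision": "0.7",
    "numpy": "1.19.2",
    "pandas": "1.2.4",
    "diffdist": "0.1",
    "appdirs": "1.4.4",
    "oslo.config": "8.7.0",
}


def matches_pin(installed, pinned):
    """Return True if the installed version starts with the pinned components.

    A local suffix such as "+cu102" is ignored, and "0.7.0" satisfies "0.7".
    """
    installed_parts = installed.split("+")[0].split(".")
    pinned_parts = pinned.split(".")
    return installed_parts[:len(pinned_parts)] == pinned_parts


def find_mismatches(installed_versions, pinned=PINNED):
    """Map each mismatched package to a (wanted, found) pair; found may be None."""
    problems = {}
    for name, wanted in pinned.items():
        found = installed_versions.get(name)
        if found is None or not matches_pin(found, wanted):
            problems[name] = (wanted, found)
    return problems


def installed_version(name):
    """Look up an installed version, or return None if the package is missing."""
    try:
        from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
    except ImportError:
        import pkg_resources  # fallback for Python 3.7
        try:
            return pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version(name)
    except PackageNotFoundError:
        return None
```

Calling find_mismatches({name: installed_version(name) for name in PINNED}) inside the activated environment returns an empty dictionary when every listed package is present at the listed version.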

Code Overview

Our code builds on several open-source projects: RoCL (Kim et al., 2020), auto_LiRPA (Xu et al., 2020), and Beta-CROWN (Wang et al., 2021).

.
├── beta_crown                          # Code for beta-CROWN
├── rocl                                # Code for RoCL
├── models
│   ├── linear_evaluate                 # Checkpoints for linear evaluation
│   └── unsupervised                    # Checkpoints trained by RoCL
├── binary_search_unsupervised.py       # ACR_CL computed by CROWN
├── binary_beta_unsupervised.py         # ACR_CL computed by beta-CROWN
├── binary_search_supervised.py         # ACR_LE
├── unsupervised_verification.py        # Certified instance accuracy
├── unsupervised_adv_train.py           # Contrastive adversarial training using RoCL
├── unsupervised_linear_eval.py         # Linear evaluation
├── robust_test.py                      # Robust test
└── requirements.txt                    # Required packages

Reproducing the results

The checkpoints used in the paper are provided in models. Results produced by the code are saved in results, which is created automatically after training.

Compute ACR

Use binary search to compute ACR_CL for an unsupervised contrastive model. For example, to certify cifar10_cnn_4layer_b_adv2 with CROWN on 100 images using GPU 7, run:

python binary_search_unsupervised.py --gpuno '7' --mini_batch 10 --ver_total 100 --dataset 'cifar-10' --model 'cnn_4layer_b' --load_checkpoint './models/unsupervised/cifar10_cnn_4layer_b_adv2.pkl'

To use beta-CROWN instead, with a timeout of 0.3, run:

python binary_beta_unsupervised.py --gpuno '7' --timeout 0.3 --mini_batch 10 --ver_total 100 --dataset 'cifar-10' --model 'cnn_4layer_b' --load_checkpoint './models/unsupervised/cifar10_cnn_4layer_b_adv2.pkl'

After linear evaluation, we can compute ACR_LE:

python binary_search_supervised.py --gpuno '0' --ver_total 100 --dataset 'cifar-10' --model 'cnn_4layer_b' --load_checkpoint './models/linear_evaluate/cifar10_cnn_4layer_b_adv2.pkl'
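
The binary search behind these scripts can be pictured with the following minimal sketch. It is an illustration under stated assumptions, not the repository's implementation: is_certified is a hypothetical stand-in for a CROWN or beta-CROWN call, certification is assumed to be monotone in the radius (if a radius is certified, every smaller radius is too), and the average certified radius is read as the plain mean of the per-image radii.

```python
def certified_radius(is_certified, upper=1.0, tol=1e-3):
    """Binary-search the largest radius in [0, upper] accepted by is_certified.

    is_certified is a hypothetical oracle: it takes a radius and returns True
    if the verifier certifies the image at that radius. The result is within
    tol of the true threshold and never exceeds it.
    """
    low, high = 0.0, upper
    while high - low > tol:
        mid = (low + high) / 2
        if is_certified(mid):
            low = mid   # certified: the threshold is at least mid
        else:
            high = mid  # not certified: the threshold is below mid
    return low


def average_certified_radius(oracles, upper=1.0, tol=1e-3):
    """Mean of the per-image certified radii (one oracle per image)."""
    radii = [certified_radius(oracle, upper, tol) for oracle in oracles]
    return sum(radii) / len(radii)


# Toy usage: two "images" whose true certified radii are 0.25 and 0.5.
toy_oracles = [lambda eps: eps <= 0.25, lambda eps: eps <= 0.5]
print(average_certified_radius(toy_oracles))
```

Each search needs about log2(upper / tol) verifier calls per image, which is why a cheap bound such as CROWN suits this loop.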

Certified instance accuracy

To reproduce our certified instance accuracy results, for example on MNIST CNN-A with epsilon_test 0.1, epsilon_neg 0.3 and epsilon_train 0.2, run:

python unsupervised_verification.py --gpuno '0' --model 'mnist_cnn_4layer' --dataset 'mnist' --mode 'incomplete' --ver_total 100 --timeout 180 --epsilon 0.1 --target_eps 0.3 --alpha 0.005 --load_checkpoint './models/unsupervised/mnist_cnn_4layer_a_adv2.pkl'

To use verified-acc mode, for example on CIFAR-10 CNN-B with epsilon_test 4/255, epsilon_neg 16/255 and epsilon_train 4.4/255, run:

python unsupervised_verification.py --gpuno '0' --model 'cnn_4layer_b' --dataset 'cifar-10' --mode 'verified-acc' --ver_total 100 --timeout 180 --load_checkpoint './models/unsupervised/cifar10_cnn_4layer_b_adv4.pkl'

Note that --mode has three options: complete, incomplete and verified-acc. With --mode set to incomplete, incomplete verification runs with the specified timeout: it keeps tightening the verified lower bound until the timeout is reached, trying to find a lower bound that is as tight as possible. With verified-acc, the algorithm stops as soon as the property is verified.
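
The difference between the incomplete and verified-acc modes can be sketched as follows. This is an illustration only: tighten_lower_bound and its arguments are hypothetical, the timeout is modelled as a step budget instead of seconds, and the property is assumed to count as verified once the lower bound is strictly positive. The complete mode is not covered.

```python
def tighten_lower_bound(bounds, mode, budget):
    """Consume successively tighter lower bounds under a step budget.

    bounds: iterable of lower bounds produced by a (hypothetical) verifier,
            one per refinement step.
    mode:   "incomplete" keeps refining until the budget is spent;
            "verified-acc" stops as soon as the bound becomes positive.
    budget: maximum number of refinement steps (stand-in for the timeout).
    Returns (best lower bound found, number of steps used).
    """
    if mode not in ("incomplete", "verified-acc"):
        raise ValueError("this sketch only covers incomplete and verified-acc")
    best = float("-inf")
    steps = 0
    for bound in bounds:
        if steps >= budget:
            break  # budget exhausted, analogous to hitting the timeout
        best = max(best, bound)
        steps += 1
        if mode == "verified-acc" and best > 0:
            break  # property verified, no need to tighten further
    return best, steps


toy_bounds = [-0.5, -0.1, 0.2, 0.4]
print(tighten_lower_bound(toy_bounds, "incomplete", budget=4))
print(tighten_lower_bound(toy_bounds, "verified-acc", budget=4))
```

On the toy sequence, incomplete uses the whole budget to reach the tightest bound, while verified-acc returns at the first positive bound.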

Linear evaluation & Robust test

After unsupervised training, a standard way to evaluate the quality of the representations learned by contrastive learning is linear evaluation. For cifar10_cnn_4layer_b_adv4, run:

python unsupervised_linear_eval.py --gpuno '0' --epoch 100 --batch-size 256 --train_type 'linear_eval' --dataset 'cifar-10' --trans=True --clean=True --model 'cnn_4layer_b' --load_checkpoint './models/unsupervised/cifar10_cnn_4layer_b_adv4.pkl'

Then we can measure the accuracy on adversarial samples for different attack strengths on the supervised downstream task. For example, with epsilon_test 0.2 on mnist_cnn_4layer_a_adv3, run:

python robust_test.py --gpuno '0' --epsilon 0.2 --alpha 0.02 --test_mode 'unsupervised' --dataset 'mnist' --model 'mnist_cnn_4layer' --load_checkpoint './models/linear_evaluate/mnist_cnn_4layer_a_adv3.pkl'
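
To illustrate how --epsilon and --alpha might interact, here is a minimal sketch of one step of an L-infinity, PGD-style attack. It rests on assumptions the README does not confirm: that the robust test uses such an attack, that --epsilon is the perturbation radius and --alpha the step size, and that inputs are scaled to [0, 1]. The function linf_attack_step is hypothetical and works on plain Python lists to stay dependency-free.

```python
def linf_attack_step(x_adv, x_clean, grad, epsilon, alpha):
    """One signed-gradient step followed by projection.

    x_adv, x_clean and grad are equally long lists of floats. Each value moves
    by alpha in the direction of the gradient sign, is clipped back into the
    epsilon-ball around the clean value, and then into the range [0, 1].
    """
    stepped = []
    for adv, clean, g in zip(x_adv, x_clean, grad):
        sign = (g > 0) - (g < 0)  # -1, 0 or 1
        value = adv + alpha * sign
        value = min(max(value, clean - epsilon), clean + epsilon)  # stay in the ball
        value = min(max(value, 0.0), 1.0)                          # stay a valid input
        stepped.append(value)
    return stepped


# Toy usage with the values from the command above: epsilon 0.2, alpha 0.02.
clean = [0.5, 0.1, 0.95]
adv = list(clean)
for _ in range(20):
    adv = linf_attack_step(adv, clean, grad=[1.0, -1.0, 1.0], epsilon=0.2, alpha=0.02)
print(adv)
```

Under these assumptions, a larger epsilon allows a stronger perturbation, while alpha only controls how quickly the attack moves toward the boundary of the allowed region.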

Contrastive adversarial training

To train with epsilon_train 0.1 on MNIST with mnist_cnn_4layer, run:

python unsupervised_adv_train.py --gpuno '0' --batch-size 256 --model 'mnist_cnn_4layer' --epsilon 0.1 --alpha 0.05 --dataset 'mnist' --train_type 'contrastive' --no_load_weight

To train with epsilon_train 0 on CIFAR-10 with cnn_4layer_b, run:

python unsupervised_adv_train.py --gpuno '0' --batch-size 256 --model 'cnn_4layer_b' --advtrain_type 'None' --dataset 'cifar-10' --train_type 'contrastive' --no_load_weight

Models provided

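model type       dataset   epsilon_train  model         output dimensions  file name
contrastive      mnist     0              base          100                mnist_base
contrastive      mnist     0.1            base          100                mnist_base_adv1
contrastive      mnist     0.3            base          100                mnist_base_adv3
contrastive      mnist     0              cnn_4layer_a  100                mnist_cnn_4layer_a
contrastive      mnist     0.1            cnn_4layer_a  100                mnist_cnn_4layer_a_adv1
contrastive      mnist     0.2            cnn_4layer_a  100                mnist_cnn_4layer_a_adv2
contrastive      mnist     0.3            cnn_4layer_a  100                mnist_cnn_4layer_a_adv3
contrastive      cifar-10  4.4/255        base          100                cifar10_base_adv4
contrastive      cifar-10  4.4/255        deep          100                cifar10_deep_adv4
contrastive      cifar-10  4.4/255        cnn_4layer_a  100                cifar10_cnn_4layer_a_adv4
contrastive      cifar-10  0              cnn_4layer_b  100                cifar10_cnn_4layer_b
contrastive      cifar-10  2.2/255        cnn_4layer_b  100                cifar10_cnn_4layer_b_adv2
contrastive      cifar-10  4.4/255        cnn_4layer_b  100                cifar10_cnn_4layer_b_adv4
contrastive      cifar-10  8.8/255        cnn_4layer_b  100                cifar10_cnn_4layer_b_adv8
contrastive      cifar-10  4.4/255        cnn_4layer_b  50                 cifar10_cnn_4layer_b_adv4_dim50
contrastive      cifar-10  4.4/255        cnn_4layer_b  150                cifar10_cnn_4layer_b_adv4_dim150
contrastive      cifar-10  4.4/255        cnn_4layer_b  200                cifar10_cnn_4layer_b_adv4_dim200
contrastive      cifar-10  4.4/255        cnn_4layer_b  250                cifar10_cnn_4layer_b_adv4_dim250
linear evaluate  mnist     0              cnn_4layer_a  100                mnist_cnn_4layer_a
linear evaluate  mnist     0.1            cnn_4layer_a  100                mnist_cnn_4layer_a_adv1
linear evaluate  mnist     0.2            cnn_4layer_a  100                mnist_cnn_4layer_a_adv2
linear evaluate  mnist     0.3            cnn_4layer_a  100                mnist_cnn_4layer_a_adv3
linear evaluate  cifar-10  0              cnn_4layer_b  100                cifar10_cnn_4layer_b
linear evaluate  cifar-10  2.2/255        cnn_4layer_b  100                cifar10_cnn_4layer_b_adv2
linear evaluate  cifar-10  4.4/255        cnn_4layer_b  100                cifar10_cnn_4layer_b_adv4
linear evaluate  cifar-10  8.8/255        cnn_4layer_b  100                cifar10_cnn_4layer_b_adv8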
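
The file names in the table map onto the --load_checkpoint paths used in the commands above. The helper below is a small sketch of that mapping, not part of the repository: checkpoint_path and CHECKPOINT_DIRS are hypothetical names, and it assumes that every listed file lives under models/unsupervised or models/linear_evaluate with a .pkl extension, as in the example commands.

```python
# Hypothetical mapping from the "model type" column to the checkpoint folder.
CHECKPOINT_DIRS = {
    "contrastive": "unsupervised",
    "linear evaluate": "linear_evaluate",
}


def checkpoint_path(model_type, file_name, root="./models"):
    """Build a --load_checkpoint value from a table row."""
    if model_type not in CHECKPOINT_DIRS:
        raise ValueError("unknown model type: " + repr(model_type))
    return "{}/{}/{}.pkl".format(root, CHECKPOINT_DIRS[model_type], file_name)


print(checkpoint_path("contrastive", "cifar10_cnn_4layer_b_adv2"))
print(checkpoint_path("linear evaluate", "mnist_cnn_4layer_a_adv3"))
```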

References

If you find the code useful for your research, please consider citing:

@inproceedings{wang2022robustness,
  title = {Robustness Verification for Contrastive Learning},
  author = {Wang, Zekai and Liu, Weiwei},
  booktitle = {International Conference on Machine Learning (ICML)},
  volume = {162},
  pages = {22865--22883},
  year = {2022}
}

and/or the journal extension of this paper, which has been accepted by JMLR:

@article{wang2024rvcl,
  title = {RVCL: Evaluating the Robustness of Contrastive Learning via Verification},
  author = {Wang, Zekai and Liu, Weiwei},
  journal = {Journal of Machine Learning Research},
  year = {2024}
}
