Better Diffusion Models Further Improve Adversarial Training

Code for the paper Better Diffusion Models Further Improve Adversarial Training (ICML 2023).

Environment settings and libraries we used in our experiments

This project is tested under the following environment settings:

  • OS: Ubuntu 20.04.3
  • GPU: NVIDIA A100
  • CUDA: 11.1, cuDNN: v8.2
  • Python: 3.9.5
  • PyTorch: 1.8.0
  • Torchvision: 0.9.0

Acknowledgement

The adversarial training code is modified from the PyTorch implementation of Rebuffi et al., 2021. The generation code is modified from the official implementation of EDM. For data generation, please refer to edm/README.md for more details.

Requirements

pip install git+https://github.com/fra31/auto-attack
pip install git+https://github.com/ildoonet/pytorch-randaugment
  • Download the EDM-generated data to ./edm_data. For TinyImageNet, we provide data generated by an EDM model trained on ImageNet. Since the 20M and 50M data files are too large for a single download, they are split into several parts:
    dataset        size   link
    CIFAR-10       1M     npz
    CIFAR-10       5M     npz
    CIFAR-10       10M    npz
    CIFAR-10       20M    part1 part2
    CIFAR-10       50M    part1 part2 part3 part4
    CIFAR-100      1M     npz
    CIFAR-100      50M    part1 part2 part3 part4
    SVHN           1M     npz
    SVHN           50M    part1 part2 part3 part4 part5
    TinyImageNet   1M     npz
  • Merge the 20M and 50M generated data:
python merge-data.py
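
The exact merging logic lives in merge-data.py; the sketch below only illustrates the general idea of concatenating split .npz parts into one file (the file names and the 'image'/'label' keys are assumptions, not confirmed by this README):

import numpy as np

# Illustrative only: concatenate split .npz parts into a single file (key names are assumptions).
parts = ['edm_data/cifar10/50m_part1.npz', 'edm_data/cifar10/50m_part2.npz']  # hypothetical paths
images, labels = [], []
for p in parts:
    with np.load(p) as d:
        images.append(d['image'])
        labels.append(d['label'])
np.savez('edm_data/cifar10/50m.npz',
         image=np.concatenate(images, axis=0),
         label=np.concatenate(labels, axis=0))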

Training Commands

Run train-wa.py to reproduce the results reported in the paper. For example, to train a WideResNet-28-10 model via TRADES on CIFAR-10 with 1M additional generated data from EDM (Karras et al., 2022):

python train-wa.py --data-dir 'dataset-data' \
    --log-dir 'trained_models' \
    --desc 'WRN28-10Swish_cifar10s_lr0p2_TRADES5_epoch400_bs512_fraction0p7_ls0p1' \
    --data cifar10s \
    --batch-size 512 \
    --model wrn-28-10-swish \
    --num-adv-epochs 400 \
    --lr 0.2 \
    --beta 5.0 \
    --unsup-fraction 0.7 \
    --aux-data-filename 'edm_data/cifar10/1m.npz' \
    --ls 0.1
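
For context, --beta is the TRADES trade-off weight between clean accuracy and robustness. Below is a minimal, illustrative sketch of a TRADES-style loss (PGD-style inner maximization of a KL term, then cross-entropy plus a beta-weighted KL regularizer); the repository's actual training loop additionally handles weight averaging, label smoothing (--ls), and mixing of generated data (--unsup-fraction):

import torch
import torch.nn.functional as F

def trades_loss(model, x, y, beta=5.0, eps=8/255, step_size=2/255, steps=10):
    # Inner maximization: craft x_adv by maximizing KL(f(x) || f(x_adv)) inside the L-inf ball.
    model.eval()
    p_nat = F.softmax(model(x), dim=1).detach()
    x_adv = x.detach() + 0.001 * torch.randn_like(x)
    for _ in range(steps):
        x_adv.requires_grad_(True)
        kl = F.kl_div(F.log_softmax(model(x_adv), dim=1), p_nat, reduction='batchmean')
        grad = torch.autograd.grad(kl, x_adv)[0]
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - eps), x + eps).clamp(0.0, 1.0)
    model.train()
    # Outer minimization: natural cross-entropy plus the beta-weighted robustness term.
    logits_nat = model(x)
    loss_nat = F.cross_entropy(logits_nat, y)
    loss_rob = F.kl_div(F.log_softmax(model(x_adv), dim=1),
                        F.softmax(logits_nat, dim=1), reduction='batchmean')
    return loss_nat + beta * loss_rob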

Evaluation Commands

The trained models can be evaluated by running eval-aa.py, which uses AutoAttack to measure robust accuracy. Run the following command (taking the checkpoint above as an example):

python eval-aa.py --data-dir 'dataset-data' \
    --log-dir 'trained_models' \
    --desc 'WRN28-10Swish_cifar10s_lr0p2_TRADES5_epoch400_bs512_fraction0p7_ls0p1'

To evaluate the last-epoch checkpoint under AutoAttack, run:

python eval-last-aa.py --data-dir 'dataset-data' \
    --log-dir 'trained_models' \
    --desc 'WRN28-10Swish_cifar10s_lr0p2_TRADES5_epoch400_bs512_fraction0p7_ls0p1'
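
Both scripts rely on the auto-attack package installed above. For reference, here is a minimal standalone sketch of running AutoAttack on an arbitrary PyTorch classifier (the toy model and random data are placeholders, not the repository's model or data loading):

import torch
import torch.nn as nn
from autoattack import AutoAttack

# Placeholder classifier and data; substitute a real trained model and the CIFAR-10 test set.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device).eval()
x_test = torch.rand(64, 3, 32, 32)      # images in [0, 1]
y_test = torch.randint(0, 10, (64,))    # integer class labels

# L-inf threat model with radius 8/255, matching the CIFAR-10 / CIFAR-100 checkpoints below.
adversary = AutoAttack(model, norm='Linf', eps=8/255, version='standard', device=device)
x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=32)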

Pre-trained checkpoints

We provide state-of-the-art pre-trained checkpoints for WRN-28-10 (Swish) and WRN-70-16 (Swish); refer to the argtxt links below for the exact hyper-parameters. Clean and robust accuracies are measured on the full test set, with robust accuracy evaluated under AutoAttack.

dataset        norm   radius      architecture   clean    robust   link
CIFAR-10       ℓ∞     8 / 255     WRN-28-10      92.44%   67.31%   checkpoint argtxt
CIFAR-10       ℓ∞     8 / 255     WRN-70-16      93.25%   70.69%   checkpoint argtxt
CIFAR-10       ℓ2     128 / 255   WRN-28-10      95.16%   83.63%   checkpoint argtxt
CIFAR-10       ℓ2     128 / 255   WRN-70-16      95.54%   84.86%   checkpoint argtxt
CIFAR-100      ℓ∞     8 / 255     WRN-28-10      72.58%   38.83%   checkpoint argtxt
CIFAR-100      ℓ∞     8 / 255     WRN-70-16      75.22%   42.67%   checkpoint argtxt
SVHN           ℓ∞     8 / 255     WRN-28-10      95.56%   64.01%   checkpoint argtxt
TinyImageNet   ℓ∞     8 / 255     WRN-28-10      65.19%   31.30%   checkpoint argtxt

For evaluation under AutoAttack:

  1. Download the checkpoint to trained_models/mymodel/weights-best.pt
  2. Download the corresponding argtxt to trained_models/mymodel/args.txt
  3. Run the command:
python eval-aa.py --data-dir 'dataset-data' --log-dir 'trained_models' --desc 'mymodel'
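
Optionally, you can sanity-check a downloaded checkpoint before running the evaluation script. A minimal sketch, assuming the file stores either a raw state dict or a dictionary wrapping one (the 'model_state_dict' key is an assumption; eval-aa.py handles the actual loading):

import torch

# Hypothetical inspection of a downloaded checkpoint; eval-aa.py performs the real loading.
ckpt = torch.load('trained_models/mymodel/weights-best.pt', map_location='cpu')
state_dict = ckpt.get('model_state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
print(f'{len(state_dict)} parameter tensors loaded')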

We have uploaded the CIFAR-10/CIFAR-100 models to the RobustBench model zoo. See the RobustBench tour for how to evaluate them through its API.
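
For example, a minimal sketch of loading one of these models through RobustBench's Python API (the model_name string is an assumption; check the RobustBench model zoo for the exact identifier):

from robustbench.utils import load_model

# The model_name below is an assumed identifier; look it up in the RobustBench model zoo.
model = load_model(model_name='Wang2023Better_WRN-28-10',
                   dataset='cifar10',
                   threat_model='Linf')
model.eval()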

References

If you find the code useful for your research, please consider citing

@inproceedings{wang2023better,
  title={Better Diffusion Models Further Improve Adversarial Training},
  author={Wang, Zekai and Pang, Tianyu and Du, Chao and Lin, Min and Liu, Weiwei and Yan, Shuicheng},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2023}
}

and/or our related works

@inproceedings{pang2022robustness,
  title={Robustness and Accuracy Could be Reconcilable by (Proper) Definition},
  author={Pang, Tianyu and Lin, Min and Yang, Xiao and Zhu, Jun and Yan, Shuicheng},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2022}
}
@inproceedings{pang2021bag,
  title={Bag of Tricks for Adversarial Training},
  author={Pang, Tianyu and Yang, Xiao and Dong, Yinpeng and Su, Hang and Zhu, Jun},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2021}
}
