[ICLR 2021] "Robust Overfitting may be mitigated by properly learned smoothening" by Tianlong Chen*, Zhenyu Zhang*, Sijia Liu, Shiyu Chang, Zhangyang Wang

License: MIT License

Robust Overfitting may be mitigated by properly learned smoothening

Code for the paper "Robust Overfitting may be mitigated by properly learned smoothening".

Tianlong Chen*, Zhenyu Zhang*, Sijia Liu, Shiyu Chang, Zhangyang Wang

Overview

To alleviate robust overfitting, we investigate two empirical ways to inject more learned smoothening during adversarial training (AT): one leverages knowledge distillation (KD) and self-training to smooth the logits; the other performs stochastic weight averaging (SWA) to smooth the weights.
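
As a rough illustration of what smoothing the logits with KD can look like, the sketch below blends a cross-entropy term with temperature-softened distillation terms from two teachers. It is a pure-Python toy, not the repository's implementation: the function names, the temperature, and the mixing weights `w_std` and `w_adv` are assumptions chosen for the example, and the two-teacher setup only mirrors the two `--t_weight` checkpoints passed in the usage section.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_loss(student_logits, teacher_logits, temperature=2.0):
    """Distillation term: KL between softened teacher and student outputs."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return kl_divergence(p_teacher, p_student)


def smoothed_objective(ce_loss, student_logits, std_teacher_logits,
                       adv_teacher_logits, w_std=0.25, w_adv=0.5,
                       temperature=2.0):
    """Blend cross-entropy with two KD terms (weights are hypothetical)."""
    kd_std = kd_loss(student_logits, std_teacher_logits, temperature)
    kd_adv = kd_loss(student_logits, adv_teacher_logits, temperature)
    return (1.0 - w_std - w_adv) * ce_loss + w_std * kd_std + w_adv * kd_adv
```

A higher temperature flattens the teacher distribution, so the student is pulled toward softer targets rather than one-hot labels.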

Highlights:

  • Smoothening mitigates robust overfitting: After adopting KD and SWA in AT, we mitigate robust overfitting and achieve a better trade-off between standard test accuracy and robustness than early stopping.
  • Rich ablation experiments: We conduct extensive ablation experiments and visualizations to investigate why these smoothening approaches mitigate robust overfitting.
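
The weight-smoothing side can be sketched in the same spirit. The snippet below keeps a running mean of model weights across checkpoints, which is the core idea of SWA. It is a minimal pure-Python sketch under stated assumptions: weights are plain dictionaries of float lists, and `swa_update` and `average_checkpoints` are hypothetical helper names, not functions from this repository.

```python
def swa_update(swa_weights, new_weights, n_averaged):
    """Fold one more checkpoint into a running mean of weights.

    n_averaged is the number of checkpoints already in swa_weights.
    """
    updated = {}
    for name, avg in swa_weights.items():
        new = new_weights[name]
        updated[name] = [a + (w - a) / (n_averaged + 1) for a, w in zip(avg, new)]
    return updated


def average_checkpoints(checkpoints):
    """Average a list of checkpoints using the incremental update."""
    swa = {name: list(values) for name, values in checkpoints[0].items()}
    for n, checkpoint in enumerate(checkpoints[1:], start=1):
        swa = swa_update(swa, checkpoint, n)
    return swa
```

With real networks, batch-normalization statistics usually have to be recomputed for the averaged weights before evaluation; that step is omitted here.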

Experiment Results

Training with KD and SWA to mitigate robust overfitting

Flattening the rugged input space

Prerequisites

  • pytorch 1.5.1
  • torchvision 0.6.1
  • advertorch 0.2.3
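
Since the versions above are pinned, a quick environment check can save a failed run. The following is a small sketch, assuming Python 3.8+ (for `importlib.metadata`) and that PyTorch is installed under its pip name `torch`; `check_versions` is a hypothetical helper, not part of this repository.

```python
from importlib import metadata

# Versions listed in the prerequisites; "torch" is PyTorch's pip package name.
PINNED = {"torch": "1.5.1", "torchvision": "0.6.1", "advertorch": "0.2.3"}


def check_versions(pinned=PINNED, lookup=metadata.version):
    """Return {package: (wanted, found, ok)} for each pinned package."""
    report = {}
    for package, wanted in pinned.items():
        try:
            found = lookup(package)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed
        # Local build tags such as "1.5.1+cu101" still count as a match.
        ok = found is not None and found.split("+")[0] == wanted
        report[package] = (wanted, found, ok)
    return report
```

Calling `check_versions()` and printing the entries whose last field is `False` shows which packages are missing or differ from the pinned versions.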

Usage

Standard Training:

python -u main_std.py \
	--data [dataset directory] \
	--dataset cifar10 \
	--arch resnet18 \
	--save_dir std_cifar10_resnet18

PGD Adversarial Training:

python -u main_adv.py \
	--data [dataset directory] \
	--dataset cifar10 \
	--arch resnet18 \
	--save_dir AT_cifar10_resnet18

Adversarial Training with KD and SWA:

python -u main_adv.py \
	--data [dataset directory] \
	--dataset cifar10 \
	--arch resnet18 \
	--save_dir KDSWA_cifar10_resnet18 \
	--swa \
	--lwf \
	--t_weight1 pretrained_models/cifar10_resnet18_std_SA_best.pt \
	--t_weight2 pretrained_models/cifar10_resnet18_adv_RA_best.pt

Testing under PGD-20 (L∞, eps = 8/255):

python -u main_adv.py \
	--data [dataset directory] \
	--dataset cifar10 \
	--arch resnet18 \
	--eval \
	--pretrained pretrained_models/**.pt \
	--swa  # include only when testing the SWA model
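
For readers unfamiliar with the attack named above, here is a minimal sketch of a 20-step L∞ PGD attack with eps = 8/255. It runs on a toy linear classifier with logistic loss so that the input gradient can be written by hand. It is not the evaluation code in `main_adv.py`: the step size of 2/255, the random start, and all function names are assumptions made for the example.

```python
import math
import random


def loss_and_grad(x, y, w, b):
    """Logistic loss of a linear model and its gradient w.r.t. the input x.

    The label y is +1 or -1.
    """
    margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
    loss = math.log1p(math.exp(-margin))
    coeff = -y / (1.0 + math.exp(margin))
    return loss, [coeff * wi for wi in w]


def sign(value):
    return (value > 0) - (value < 0)


def pgd_linf(x, y, w, b, eps=8 / 255, step=2 / 255, iters=20, seed=0):
    """Maximize the loss within an L-infinity ball of radius eps around x."""
    rng = random.Random(seed)
    # Random start inside the eps-ball (a common choice, assumed here).
    x_adv = [min(1.0, max(0.0, xi + rng.uniform(-eps, eps))) for xi in x]
    for _ in range(iters):
        _, grad = loss_and_grad(x_adv, y, w, b)
        # Ascent step along the sign of the input gradient.
        stepped = [xa + step * sign(g) for xa, g in zip(x_adv, grad)]
        # Project back onto the eps-ball, then clip to the valid range [0, 1].
        x_adv = [min(1.0, max(0.0, min(xi + eps, max(xi - eps, s))))
                 for xi, s in zip(x, stepped)]
    return x_adv
```

Robust accuracy is then the fraction of test inputs that are still classified correctly after such a perturbation.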

Citation

@inproceedings{
	chen2021robust,
	title={Robust Overfitting may be mitigated by properly learned smoothening},
	author={Tianlong Chen and Zhenyu Zhang and Sijia Liu and Shiyu Chang and Zhangyang Wang},
	booktitle={International Conference on Learning Representations},
	year={2021},
	url={https://openreview.net/forum?id=qZzy5urZw9}
}

alleviate-robust-overfitting's Issues

Pretrained models

Hi, where are the pretrained models? Could you provide them?

--t_weight1 pretrained_models/cifar10_resnet18_std_SA_best.pt
--t_weight2 pretrained_models/cifar10_resnet18_adv_RA_best.pt
