jiweitian / adversarial-explanations-cifar

This project was forked from wwoods/adversarial-explanations-cifar.


Code example for the paper, "Adversarial Explanations for Understanding Image Classification Decisions and Improved Neural Network Robustness."

License: MIT License


Introduction

This code demonstrates the techniques from the paper above; a pre-print is available on arXiv. Note that this is not the exact code used in the research, but a cleaned-up reproduction of the paper's key insights.

Installation

Starting from scratch, with no Python environment, installation takes 10-20 minutes; with Python already installed, it takes only a few minutes.

Install PyTorch, torchvision, and click, for example via Miniconda with Python 3:

$ conda install -c pytorch pytorch torchvision
$ pip install click

Code was tested with:

  • Python 3.6
  • PyTorch 1.1 + torchvision 0.2.2
  • click 7.0

Any operating system that supports the above libraries should work; we tested on Ubuntu 18.04.

An NVIDIA GPU is not required, but one or more GPUs will greatly accelerate network training.

Usage

This repository contains several pre-built networks, corresponding to the CIFAR-10 networks highlighted in the paper.

The application has two modes: explaining a trained model, and training a model from scratch.

When the application runs, the CIFAR-10 dataset is downloaded automatically via the torchvision library; set the environment variable CIFAR10_PATH to the desired download location for the data.
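For scripts that want to load the dataset the same way, the path lookup can be sketched as follows. The helper name and the fallback directory are illustrative assumptions, not taken from the repository:

```python
import os

def cifar10_root(default="./cifar10-data"):
    """Resolve the CIFAR-10 download directory from CIFAR10_PATH."""
    # Fall back to a local directory when the variable is unset
    # (the fallback name here is an assumption, not the repo's behavior).
    return os.environ.get("CIFAR10_PATH", default)

# torchvision would then download into that directory, roughly:
#   from torchvision.datasets import CIFAR10
#   train_set = CIFAR10(root=cifar10_root(), train=True, download=True)
```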

Prebuilt Networks

The repository contains four prebuilt networks:

  1. prebuilt/resnet44-standard.pt: A standard ResNet-44 with no special training.
  2. prebuilt/resnet44-adv-train.pt: A ResNet-44 trained with --adversarial-training.
  3. prebuilt/resnet44-all.pt: A ResNet-44 trained with --robust-additions, --adversarial-training, and --l2-min.
  4. prebuilt/resnet44-robust.pt: A ResNet-44 trained with --robust-additions.

These correspond with, but are not the same as, the networks denoted N1, N2, N3, and N4 in the paper. The training of these networks resulted in the following statistics:

| Network               | Final Training Loss | Final Psi | Test Accuracy | Attack ARA | BTR ARA |
|-----------------------|---------------------|-----------|---------------|------------|---------|
| resnet44-standard.pt  | 0.0075              | N/A       | 0.9384        | 0.0013     | 0.0015  |
| resnet44-adv-train.pt | 0.5313              | N/A       | 0.8643        | 0.0100     | 0.0157  |
| resnet44-all.pt       | 1.4799              | 14240     | 0.679         | 0.0188     | 0.0414  |
| resnet44-robust.pt    | 1.4799              | 33778     | 0.6758        | 0.0142     | 0.0395  |

(The original table also includes, per network, example explanation images for the Ship, Frog, Cat, and Automobile classes; the images are omitted here.)

See the paper or the "github-prebuilt-images" command in main.py for additional information on the above table and its images.

Calculate ARA

Attack and BTR ARAs may be calculated via the calculate-ara command. For example, to use a pre-built network with both adversarial training and the robustness additions from the paper:

$ python main.py calculate-ara prebuilt/resnet44-all.pt [--n-images 1000] [--eps 20] [--steps 450] [--momentum 0.9]

Note that arguments in [brackets] are optional. This produces textual output indicating the calculated attack and BTR ARAs as per Section III.A of the paper. The resulting ARAs for all prebuilt networks are shown in the table above. Calculating both ARAs as in the original paper (default settings) takes around 30 minutes per network, depending on the GPU.
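The --eps, --steps, and --momentum options suggest an iterative, momentum-accelerated gradient attack. The numpy sketch below shows what one update of such an attack typically looks like; this is a generic MI-FGSM-style step under that assumption, not the repository's exact attack:

```python
import numpy as np

def momentum_attack_step(x, grad, velocity, eps_step=0.01, momentum=0.9):
    """One step of a momentum-based iterative attack (illustrative only).

    x        -- current adversarial image, values in [0, 1]
    grad     -- gradient of the attack loss w.r.t. x
    velocity -- accumulated gradient direction from previous steps
    """
    # Accumulate an L1-normalized gradient into the velocity term
    velocity = momentum * velocity + grad / (np.abs(grad).sum() + 1e-12)
    # Step in the sign of the accumulated direction; stay in image range
    x = np.clip(x + eps_step * np.sign(velocity), 0.0, 1.0)
    return x, velocity
```

Running this for --steps iterations, with the total perturbation capped by --eps, is the general shape of such attacks; see the repo's main.py for the actual loss and projection used.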

Explain

To generate explanations on the first 10 CIFAR-10 testing examples with a trained network, use the explain command. For example, to use a pre-built network with both adversarial training and the robustness additions from the paper:

$ python main.py explain prebuilt/resnet44-all.pt [--eps 0.1]

This will create images in the output/ folder, designed to be viewed in alphabetical order. For example, output/0-cat will contain:

  • _input.png: the unmodified input image;
  • _real_was_xxx.png: an explanation using g_{explain+} from the paper on the real class (cat);
  • _second_dog_was_xxx.png: an explanation using g_{explain+} on the most confident class that was not the correct class;
  • 0_airplane_was_xxx.png, 1_automobile_was_xxx.png, 2_bird_was_xxx.png, ..., 9_truck_was_xxx.png: an explanation targeted at each class of CIFAR-10, as indicated in the filename.

In all cases, the _xxx preceding the .png extension indicates the post-softmax confidence of that class on the original image. The images look like this:

                   
[Image grid: the input image, the real-target and second-target explanations, and one explanation per CIFAR-10 class (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck).]

Note that arguments in [brackets] are optional. --eps X specifies that the adversarial explanations should be built with rho=X. The process could be further optimized, but presently takes a minute or two.
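Conceptually, a targeted explanation of this kind comes from gradient ascent on the confidence of a chosen class while keeping the perturbation within a small ball around the input. The sketch below illustrates that idea on a toy linear model; the function names, the objective, and the projection are simplifications for illustration, not the paper's actual g_{explain+}:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def targeted_explanation(x, W, target, steps=100, lr=0.1, rho=0.1):
    """Gradient-ascent sketch on a toy linear 'model' logits = W @ x.

    The perturbation is kept within an L-infinity ball of radius rho,
    mirroring the --eps option; a real network replaces W @ x.
    """
    x0 = x.copy()
    for _ in range(steps):
        p = softmax(W @ x)
        # Gradient of log p[target] w.r.t. x for the linear model
        grad = W[target] - p @ W
        x = x + lr * grad
        x = np.clip(x, x0 - rho, x0 + rho)  # stay near the original image
        x = np.clip(x, 0.0, 1.0)            # stay a valid image
    return x
```

The post-softmax confidence reported in the filenames corresponds to `softmax(logits)[class]` evaluated on the original, unmodified image.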

Train

To train a new network:

$ python main.py train path/to/model.pt [--adversarial-training] [--robust-additions] [--l2-min]

See python main.py train --help for additional information on these options.

Training time varies greatly with the available GPU(s). With both adversarial training and the robustness additions from the paper, training can take up to several days on a single computer; turning off either option yields a significant speedup.

At the top of main.py are many CAPITAL_CASE variables that may be modified to affect the training process. Their definitions match those in the paper.
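Since these are plain module-level constants, an experiment can either edit main.py directly or patch the module before invoking training. The snippet below only illustrates the pattern; the variable names are hypothetical stand-ins, not the repository's actual settings:

```python
import types

# Toy stand-in for main.py's module-level CAPITAL_CASE settings;
# the names here are hypothetical, not taken from the repository.
main = types.SimpleNamespace(N_EPOCHS=200, BATCH_SIZE=128)

# Override a setting before kicking off training
main.N_EPOCHS = 50
```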

