
jiazhi412 / learning-not-to-learn


This project forked from feidfoe/learning-not-to-learn


[CVPR 2019] Learning Not to Learn: An adversarial method to train deep neural networks with biased data



Learning-Not-to-Learn

This repository is the official PyTorch implementation of https://arxiv.org/abs/1812.10352, which was published at CVPR 2019.

Conceptual Illustration

[Figure: conceptual illustration]

Since a neural network efficiently learns the data distribution, it is also likely to learn the bias information; the network can be as biased as the data it is given.

In the figure above, points colored with high saturation indicate samples provided during training, while points with low saturation would appear in the test scenario. Although the classifier is well trained to categorize the training data, it performs poorly on test samples because it has learned the latent bias in the training samples.

In this paper (and repo), we propose an iterative algorithm to unlearn the bias information.
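The paper frames unlearning as adversarial training between a feature extractor and a network that predicts the bias. One common way to implement this kind of adversarial step is a gradient reversal layer (Ganin and Lempitsky), which passes features through unchanged in the forward pass and flips the sign of the bias gradient in the backward pass. The sketch below is a hypothetical one-dimensional toy, not code from this repository, and the paper's actual objective (based on mutual information) differs in detail. It assumes f, g and h are single scalar weights and uses squared error in place of cross-entropy so the gradients fit on one line each; `reverse_grad`, `adversarial_step` and `lam` are illustrative names.

```python
# Hypothetical 1-D toy of adversarial unlearning via gradient reversal.
# NOT the repository's implementation.
#   f: feature extractor  z     = w * x
#   g: label head         y_hat = a * z   (loss Lc = (y_hat - y)^2)
#   h: bias head          b_hat = c * z   (loss Lb = (b_hat - b)^2)
# Squared error stands in for cross-entropy to keep the gradients short.


def reverse_grad(grad, lam):
    """Gradient reversal: identity forward, multiplies the gradient by -lam backward."""
    return -lam * grad


def adversarial_step(params, x, y, b, lr=0.01, lam=1.0):
    """One simultaneous SGD update of f, g and h on a single sample (x, y, b)."""
    w, a, c = params
    z = w * x
    y_hat, b_hat = a * z, c * z

    dLc_dyhat = 2.0 * (y_hat - y)
    dLb_dbhat = 2.0 * (b_hat - b)

    # g and h each descend their own loss as usual.
    grad_a = dLc_dyhat * z
    grad_c = dLb_dbhat * z

    # f gets the label gradient unchanged and the bias gradient reversed,
    # so it is pushed toward features that h cannot use to predict the bias.
    dz = dLc_dyhat * a + reverse_grad(dLb_dbhat * c, lam)
    grad_w = dz * x

    return (w - lr * grad_w, a - lr * grad_a, c - lr * grad_c)


new_params = adversarial_step((1.0, 1.0, 1.0), x=1.0, y=2.0, b=0.0)
print(new_params)  # approximately (1.04, 1.02, 0.98)
```

Setting `lam=0.0` removes the adversarial signal and reduces the update of f to ordinary training on the label loss, which is a quick way to see what the reversal contributes.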

Requirements

  1. NVIDIA Docker: This code requires NVIDIA Docker to run. If NVIDIA Docker is installed, the Docker image will be pulled automatically. All other required libraries are installed in the Docker image.

  2. Pretrained model: If you don't want to use the pretrained parameters, remove the 'use_pretrain' and 'checkpoint' flags from the 'train.sh' script. The pretrained parameters were trained without the unlearning algorithm. The checkpoint file can be found here

  3. Dataset: The Colored-MNIST dataset is constructed following the protocol proposed in https://arxiv.org/abs/1812.10352. It can be found here. More details about the dataset are in the dataset directory.
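To make the bias concrete, the sketch below shows one way to assign colors in a Colored-MNIST style setup: in training, each digit's color is drawn around a mean color tied to its class, so color predicts the label; at test time the color is decoupled from the label. This is a hypothetical illustration, not the dataset's actual construction code. The ten mean colors, the Gaussian jitter with variance `sigma2`, clipping to [0, 1], and drawing the test mean color uniformly regardless of class are all assumptions; see the paper and the dataset directory for the real protocol.

```python
import random

# Hypothetical mean RGB colors (values in [0, 1]) for digit classes 0..9.
# Placeholders only; the real dataset defines its own.
MEAN_COLORS = [
    (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), (1.0, 1.0, 0.0),
    (1.0, 0.0, 1.0), (0.0, 1.0, 1.0), (1.0, 0.5, 0.0), (0.5, 0.0, 1.0),
    (0.5, 0.5, 0.5), (0.0, 0.5, 0.5),
]


def clip01(v):
    """Clamp a channel value to the valid [0, 1] range."""
    return min(1.0, max(0.0, v))


def sample_color(label, sigma2, rng, train=True):
    """Return an RGB color for a digit with the given class label.

    train=True:  color is jittered around the class's own mean color (biased).
    train=False: assumption, the mean color is picked at random, ignoring the label.
    """
    mean = MEAN_COLORS[label] if train else rng.choice(MEAN_COLORS)
    std = sigma2 ** 0.5  # sigma2 is a variance, gauss() expects a std deviation
    return tuple(clip01(rng.gauss(m, std)) for m in mean)


rng = random.Random(0)
print(sample_color(3, 0.02, rng))               # close to yellow (1, 1, 0)
print(sample_color(3, 0.02, rng, train=False))  # any of the ten mean colors
```

A smaller `sigma2` makes the training colors cluster more tightly around each class's mean, which makes color an easier shortcut for the network to learn.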

Usage

First, download the pretrained model and the dataset. You can set the dataset directory in 'option.py' (data_dir).

To train, set the path to the pretrained checkpoint in train.sh, then run the script with bash:

bash train.sh

For evaluation, modify the paths in test.sh, then run the script:

bash test.sh

Results

[Figure: mean colors per digit class and confusion matrices]

The top row shows the mean colors and their corresponding digit classes in the training data. The confusion matrices of the baseline model show that the network is biased owing to the biased data. In contrast, the networks trained with our algorithm are not biased toward color, even though they were trained on the same training data as the baseline.
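For readers who want to reproduce this kind of plot from their own predictions, the sketch below builds a confusion matrix and the corresponding accuracy in plain Python. It is a generic illustration, not the repository's evaluation code; the function names are hypothetical, and it assumes rows index the true digit and columns the predicted digit. A color-biased model shows up as large off-diagonal entries, because on test data the color no longer matches the digit.

```python
def confusion_matrix(labels, preds, num_classes=10):
    """Count predictions: rows are true classes, columns are predicted classes."""
    mat = [[0] * num_classes for _ in range(num_classes)]
    for true_cls, pred_cls in zip(labels, preds):
        mat[true_cls][pred_cls] += 1
    return mat


def accuracy(mat):
    """Fraction of samples on the diagonal (correct predictions)."""
    correct = sum(mat[i][i] for i in range(len(mat)))
    total = sum(sum(row) for row in mat)
    return correct / total if total else 0.0


# Toy example with 3 classes: two of four samples are misclassified.
mat = confusion_matrix([0, 1, 2, 2], [0, 2, 2, 1], num_classes=3)
print(mat)            # [[1, 0, 0], [0, 0, 1], [0, 1, 1]]
print(accuracy(mat))  # 0.5
```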

For comparison, we provide the experimental results on colored-MNIST. The table below is an alternative to Fig. 4 in the paper.

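| σ²       | 0.020  | 0.025  | 0.030  | 0.035  | 0.040  | 0.045  | 0.050  |
|----------|--------|--------|--------|--------|--------|--------|--------|
| Baseline | 0.4055 | 0.4813 | 0.5996 | 0.6626 | 0.7333 | 0.7973 | 0.8450 |
| BlindEye | 0.6741 | 0.7123 | 0.7883 | 0.8203 | 0.8638 | 0.8927 | 0.9159 |
| Gray     | 0.8374 | 0.8751 | 0.8996 | 0.9166 | 0.9325 | 0.9472 | 0.9596 |
| Ours     | 0.8185 | 0.8854 | 0.9137 | 0.9306 | 0.9406 | 0.9555 | 0.9618 |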
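The short script below reads the accuracies from the table above and reports which method is best at each color-variance level and how much the proposed method gains over the baseline. It only restates the published numbers; the dictionary layout and function names are illustrative. It makes one detail easy to see: at the smallest variance (0.020) the grayscale baseline is slightly ahead, while the proposed method leads at every other level.

```python
# Accuracies copied from the colored-MNIST table above.
SIGMA2 = [0.020, 0.025, 0.030, 0.035, 0.040, 0.045, 0.050]
RESULTS = {
    "Baseline": [0.4055, 0.4813, 0.5996, 0.6626, 0.7333, 0.7973, 0.8450],
    "BlindEye": [0.6741, 0.7123, 0.7883, 0.8203, 0.8638, 0.8927, 0.9159],
    "Gray":     [0.8374, 0.8751, 0.8996, 0.9166, 0.9325, 0.9472, 0.9596],
    "Ours":     [0.8185, 0.8854, 0.9137, 0.9306, 0.9406, 0.9555, 0.9618],
}


def best_per_column(results):
    """Name of the highest-accuracy method for each sigma^2 column."""
    names = list(results)
    return [max(names, key=lambda n: results[n][i]) for i in range(len(SIGMA2))]


def gain_over(results, method, reference):
    """Per-column accuracy difference method - reference, rounded to 4 digits."""
    return [round(m - r, 4) for m, r in zip(results[method], results[reference])]


for s2, best, gain in zip(SIGMA2, best_per_column(RESULTS),
                          gain_over(RESULTS, "Ours", "Baseline")):
    print(f"sigma^2={s2:.3f}  best={best:<8}  Ours-Baseline={gain:+.4f}")
```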

Notes

  1. This is an example of unlearning using colored-MNIST data.

  2. The main purpose of this code is to remove color information from the extracted features.

  3. Since this algorithm uses adversarial training, it is not very stable. If you suffer from instability, try pre-training the f (feature extractor) and g (label classifier) networks with the h (bias predictor) network detached, so that the networks learn the bias. Then, include the h network in the training loop (adversarial training).
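The two-stage schedule in Note 3 can be written down as a per-epoch switch. The sketch below is a hypothetical helper, not part of this repository: `Phase`, `phase_for_epoch` and `pretrain_epochs` are illustrative names, and the number of pre-training epochs is an assumption you would tune. In stage 1 only f and g are trained, as in ordinary (biased) training; in stage 2 h and the adversarial loss join the loop.

```python
from dataclasses import dataclass


@dataclass
class Phase:
    name: str
    train_f: bool  # update the feature extractor
    train_g: bool  # update the label classifier
    use_h: bool    # bias predictor h and the adversarial loss are in the loop


def phase_for_epoch(epoch, pretrain_epochs):
    """Two-stage schedule from Note 3; pretrain_epochs is a hypothetical knob."""
    if epoch < pretrain_epochs:
        # Stage 1: plain training of f and g with h detached.
        return Phase("pretrain", train_f=True, train_g=True, use_h=False)
    # Stage 2: bring h into the training loop for adversarial training.
    return Phase("adversarial", train_f=True, train_g=True, use_h=True)


for epoch in range(4):
    print(epoch, phase_for_epoch(epoch, pretrain_epochs=2).name)
# 0 pretrain / 1 pretrain / 2 adversarial / 3 adversarial
```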

Contact

Byungju Kim([email protected])

BibTeX for Citation

@InProceedings{Kim_2019_CVPR,
author = {Kim, Byungju and Kim, Hyunwoo and Kim, Kyungsu and Kim, Sungjin and Kim, Junmo},
title = {Learning Not to Learn: Training Deep Neural Networks With Biased Data},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2019}
}

