Avoiding spurious correlations via logit correction

This repository provides the official PyTorch implementation of the following paper:

Avoiding spurious correlations via logit correction
Sheng Liu (NYU), Xu Zhang (Amazon), Nitesh Sekhar (Amazon), Yue Wu (Amazon), Prateek Singhal (Amazon), Carlos Fernandez-Granda (NYU)
ICLR 2023

Paper: arXiv

Abstract: Empirical studies suggest that machine learning models trained with empirical risk minimization (ERM) often rely on attributes that may be spuriously correlated with the class labels. Such models typically lead to poor performance during inference for data lacking such correlations. In this work, we explicitly consider a situation where potential spurious correlations are present in the majority of training data. In contrast with existing approaches, which use the ERM model outputs to detect the samples without spurious correlations and either heuristically upweight or upsample those samples, we propose the logit correction (LC) loss, a simple yet effective improvement on the softmax cross-entropy loss, to correct the sample logit. We demonstrate that minimizing the LC loss is equivalent to maximizing the group-balanced accuracy, so the proposed LC could mitigate the negative impacts of spurious correlations. Our extensive experimental results further reveal that the proposed LC loss outperforms state-of-the-art solutions on multiple popular benchmarks by a large margin, an average 5.5% absolute improvement, without access to spurious attribute labels. LC is also competitive with oracle methods that make use of the attribute labels.
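To make the idea of correcting logits before the softmax cross-entropy more concrete, here is a minimal sketch in plain Python. It assumes the correction is an additive per-class log-prior term, as in generic logit adjustment. The function name `corrected_cross_entropy`, the `log_prior` argument, and the example numbers are hypothetical and are not taken from this repository, so the loss actually implemented here may differ.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def corrected_cross_entropy(logits, label, log_prior):
    """Cross-entropy computed on logits shifted by a per-class correction.

    Assumption: the correction is log_prior[c] added to the logit of class c.
    This illustrates generic logit adjustment, not the paper's exact loss.
    """
    shifted = [z + b for z, b in zip(logits, log_prior)]
    return -log_softmax(shifted)[label]


# Hypothetical example: two classes, the sample belongs to class 1, and the
# assumed prior for samples like it strongly favors class 0.
logits = [1.0, 1.0]
prior = [0.95, 0.05]
log_prior = [math.log(p) for p in prior]

plain = corrected_cross_entropy(logits, 1, [0.0, 0.0])  # no correction
corrected = corrected_cross_entropy(logits, 1, log_prior)
print(plain, corrected)
```

With equal logits, the uncorrected loss is log 2, while the corrected loss for the under-represented class is -log 0.05. Under this assumption, samples that go against the prior contribute a larger training signal without any explicit upweighting or upsampling.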

PyTorch Implementation

Installation

Clone this repository.

git clone --recursive https://github.com/shengliu66/LC-private
cd LC-private
pip install -r requirements.txt

Datasets

Running LC

Waterbird

python train.py --dataset waterbird --exp=waterbird_ours --lr=1e-3 --weight_decay=1e-4 --curr_epoch=50 --lr_decay_epoch 50 --use_lr_decay --lambda_dis_align=2.0 --ema_alpha 0.5 --tau 0.1 --train_ours --q 0.8 --avg_type batch --data_dir /dir/to/data/

CivilComments

python train.py --dataset civilcomments --exp=civilcomments_ours --lr=1e-5 --q 0.7 --log_freq 400 --valid_freq 400 --weight_decay=1e-2 --curr_step 400 --use_lr_decay --num_epochs 3 --lr_decay_step 4000 --lambda_dis_align=0.1 --ema_alpha 0.9 --tau 1.0 --avg_type mv_batch --train_ours --data_dir /dir/to/data/

Monitoring Performance

We use Weights & Biases (W&B) to monitor training. Follow the W&B documentation to install it and log in, then add the argument --wandb to the training command.

Evaluate Models

To test our pretrained models, run the following command:

python test.py --pretrained_path=<path_to_pretrained_ckpt> --dataset=<dataset_to_test> 
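The abstract relates the LC loss to group-balanced accuracy. For readers who want to compute that metric from their own predictions, the following is a minimal sketch. It assumes each sample carries a group identifier (for example, a label and spurious-attribute pair). The helper names and toy data are hypothetical and are not part of test.py.

```python
from collections import defaultdict


def group_accuracies(preds, labels, groups):
    """Return a dict mapping each group to its accuracy."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return {g: correct[g] / total[g] for g in total}


def group_balanced_accuracy(preds, labels, groups):
    """Unweighted mean of the per-group accuracies."""
    accs = group_accuracies(preds, labels, groups)
    return sum(accs.values()) / len(accs)


def worst_group_accuracy(preds, labels, groups):
    """Accuracy of the group on which the model performs worst."""
    return min(group_accuracies(preds, labels, groups).values())


# Hypothetical toy data: group "a" is the majority, group "b" the minority.
preds = [0, 0, 0, 0, 1, 0]
labels = [0, 0, 0, 0, 1, 1]
groups = ["a", "a", "a", "a", "b", "b"]

balanced = group_balanced_accuracy(preds, labels, groups)
worst = worst_group_accuracy(preds, labels, groups)
print(balanced, worst)
```

In this toy example the overall accuracy is 5/6, but the minority group is only half correct, which the group-balanced and worst-group figures make visible.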

Citations

BibTeX

@inproceedings{Liu2023,
  author = {Sheng Liu and Xu Zhang and Nitesh Sekhar and Yue Wu and Prateek Singhal and Carlos Fernandez-Granda},
  title = {Avoiding spurious correlations via logit correction},
  year = {2023},
  url = {https://www.amazon.science/publications/avoiding-spurious-correlations-via-logit-correction},
  booktitle = {ICLR 2023}
}

Contact

Sheng Liu ([email protected])

Acknowledgments

This work was mainly done during an internship at Amazon Science. Our PyTorch implementation is based on Disentangled. We thank the authors for making their code public.
