
ood-bench's Introduction

OoD-Bench

OoD-Bench is a benchmark for both datasets and algorithms of out-of-distribution generalization. It positions datasets along two dimensions of distribution shift: diversity shift and correlation shift, unifying the disjoint threads of research from the perspective of data distribution. OoD algorithms are then evaluated and compared on two groups of datasets, each dominated by one kind of distribution shift. See our paper (CVPR 2022 oral) for more details.

This repository contains the code to produce the benchmark, which has two main components:

  • a framework for quantifying distribution shift that benchmarks the datasets, and
  • a modified version of DomainBed that benchmarks the algorithms.

Environment requirements

  • Python 3.6 or above
  • The packages listed in requirements.txt, which you can install via pip install -r requirements.txt. The package torch_scatter may require manual installation
  • The submodules added to PYTHONPATH:
export PYTHONPATH="$PYTHONPATH:$(pwd)/external/DomainBed/"
export PYTHONPATH="$PYTHONPATH:$(pwd)/external/wilds/"

Data preparation

Please follow these instructions.

Quantifying diversity and correlation shift

The quantification process consists of three main steps: (1) training an environment classifier, (2) extracting features from the trained classifier, and (3) measuring the shifts with the extracted features. The module ood_bench.scripts.main will handle the whole process for you. For example, to quantify the distribution shift between the training environments (indexed by 0 and 1) and the test environment (indexed by 2) of Colored MNIST with 16 trials, you can simply run:

python -m ood_bench.scripts.main \
       --n_trials 16 \
       --data_dir /path/to/my/data \
       --dataset ColoredMNIST_IRM \
       --envs_p 0 1 \
       --envs_q 2 \
       --backbone mlp \
       --output_dir /path/to/store/outputs

In other cases where pretrained models are used, --pretrained_model_path must be specified. For models in the torchvision model zoo, you can pass auto to the argument and the pretrained weights will be downloaded automatically.

These two optional arguments are also useful:

  • --parallel: utilize multiple GPUs to conduct the trials in parallel. The maximum number of parallel trials equals the number of visible GPUs, which can be controlled by setting CUDA_VISIBLE_DEVICES.
  • --calibrate: calibrate the thresholds eps_div and eps_cor so that the estimated diversity and correlation shift are guaranteed to stay close to 0 under the i.i.d. condition. An example combining these options is shown below.
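
For example, a hypothetical invocation combining these options with a pretrained torchvision backbone might look like the following. This is a sketch only: the dataset, environment split, and backbone name (resnet18) are illustrative and may need to be adapted to the backbones actually available in ood_bench/networks.py.

python -m ood_bench.scripts.main \
       --n_trials 16 \
       --data_dir /path/to/my/data \
       --dataset PACS \
       --envs_p 0 1 2 \
       --envs_q 3 \
       --backbone resnet18 \
       --pretrained_model_path auto \
       --parallel \
       --calibrate \
       --output_dir /path/to/store/outputs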

Results

The following results are produced by the scripts under ood_bench/examples, all with automatic calibration.

Dataset            Diversity shift      Correlation shift
PACS               0.6715 ± 0.0392*     0.0338 ± 0.0156*
Office-Home        0.0657 ± 0.0147*     0.0699 ± 0.0280*
Terra Incognita    0.9846 ± 0.0935*     0.0002 ± 0.0003*
DomainNet          0.3740 ± 0.0343*     0.1061 ± 0.0181*
WILDS-Camelyon     0.9632 ± 0.1907      0.0000 ± 0.0000
Colored MNIST      0.0013 ± 0.0006      0.5468 ± 0.0278
CelebA             0.0031 ± 0.0017      0.1868 ± 0.0530
NICO               0.0176 ± 0.0158      0.1968 ± 0.0888
ImageNet-A †       0.0435 ± 0.0123      0.0222 ± 0.0192
ImageNet-R †       0.1024 ± 0.0188      0.1180 ± 0.0311
ImageNet-V2 †      0.0079 ± 0.0017      0.2362 ± 0.0607

* averaged over all leave-one-domain-out splits     † with respect to the original ImageNet

Note: there is some difference between the results shown above and those reported in our paper, mainly because we reworked the original implementation to ease public use and to improve quantification stability. One of the main improvements is the use of calibration. Previously, the same empirically sound thresholds were used across all the datasets studied in our paper, which may not hold for other datasets.

Extending OoD-Bench

  • New datasets must first be added to external/DomainBed/domainbed/datasets.py as a subclass of MultipleDomainDataset, for example:
class MyDataset(MultipleDomainDataset):
    ENVIRONMENTS = ['env0', 'env1']        # at least two environments
    def __init__(self, root, test_envs, hparams):
        super().__init__()

        # you may change the transformations below
        transform = get_transform()
        augment_scheme = hparams.get('data_augmentation_scheme', 'default')
        augment_transform = get_augment_transform(augment_scheme)

        self.datasets = []                 # required
        for i, env_name in enumerate(self.ENVIRONMENTS):
            if hparams['data_augmentation'] and (i not in test_envs):
                env_transform = augment_transform
            else:
                env_transform = transform
            # load the environments, not necessarily as ImageFolders;
            # you may write a specialized class to load them; the class
            # must possess an attribute named `samples`, a sequence of
            # 2-tuples where the second elements are the labels
            dataset = ImageFolder(Path(root, env_name), transform=env_transform)
            self.datasets.append(dataset)

        self.input_shape = (3, 224, 224,)  # required
        self.num_classes = 2               # required
  • New network backbones must first be added to ood_bench/networks.py as a subclass of Backbone, for example (a sanity-check snippet follows the class definition below):
class MyBackbone(Backbone):
    def __init__(self, hdim, pretrained_model_path=None):
        self._hdim = hdim
        super(MyBackbone, self).__init__(pretrained_model_path)

    @property
    def hdim(self):
        return self._hdim

    def _load_modules(self):
        self.modules_ = nn.Sequential(
            nn.Linear(3 * 14 * 14, self.hdim),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.modules_(x)
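
As a quick sanity check, you can feed a dummy batch through the new backbone and verify the feature dimensionality. This is a hypothetical snippet, assuming the Backbone base class invokes _load_modules during initialization, as the example above relies on:

import torch

backbone = MyBackbone(hdim=128)
x = torch.randn(8, 3 * 14 * 14)              # dummy batch of flattened 3x14x14 inputs
features = backbone(x)
assert features.shape == (8, backbone.hdim)  # features should have hdim dimensions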

Benchmarking OoD algorithms

Please refer to this repository.

Citing

If you find the code useful or find our paper relevant to your research, please consider citing:

@inproceedings{ye2022ood,
    title={OoD-Bench: Quantifying and Understanding Two Dimensions of Out-of-Distribution Generalization},
    author={Ye, Nanyang and Li, Kaican and Bai, Haoyue and Yu, Runpeng and Hong, Lanqing and Zhou, Fengwei and Li, Zhenguo and Zhu, Jun},
    booktitle={CVPR},
    year={2022}
}


ood-bench's Issues

wrong proof of theorem 1 in the paper?

I was reading the CVPR paper and was confused by the following highlighted statements in the proof of Theorem 1:

[screenshot of the highlighted statements in the proof]

Could you explain why there must exist some x′ ≠ x such that p(x′) < q(x′) with f(x′) = f(x) and g(x′) = g(x)?

In particular, I don't think we can ensure that g() will ever map two inputs to the same vector. I feel like some assumptions are missing here.

Questions about Environment requirements

Hi, thanks for your great work!

There are some questions about the environment requirements for your benchmark. I used your benchmark to run ERM on several datasets, but all performances are 1%-2% lower than the results in the paper, so I think it may be caused by the environment. I then tried to follow requirements.txt, but got this:

ERROR: Cannot install -r requirements.txt (line 6) and torch==1.10.2 because these package versions have conflicting dependencies.

The conflict is caused by:
    The user requested torch==1.10.2
    torchvision 0.11.2 depends on torch==1.10.1

I also found the correspondence between torch and torchvision at https://github.com/pytorch/vision#installation. I am not sure if this is a mistake, and I would like to know the detailed virtual environment requirements for running your benchmark with the OoD algorithms. Thanks a lot.

Questions about the data

Hello, very nice and inspiring paper.
I would like to try your two proposed metrics on my own dataset and a well-trained model.

However, I have a question about this code in the file "quantify.py":
data = np.load(Path(args.feature_dir, 'data.npz'))
y_p, z_p, y_q, z_q = data['y_p'], data['z_p'], data['y_q'], data['z_q']

How should I calculate y_p, z_p and y_q, z_q based on my own dataset?
Thank you very much.
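
For reference, here is a minimal sketch (not the repository's own script) of how such a file could be produced for a custom dataset, assuming y_p/y_q are integer class labels and z_p/z_q are the feature vectors your trained model extracts for samples from the two environment groups p and q:

import numpy as np
import torch
from pathlib import Path

@torch.no_grad()
def dump_features(model, loader_p, loader_q, feature_dir):
    # model: your trained feature extractor; loader_p/loader_q: dataloaders
    # over the samples of environment group p and environment group q
    def run(loader):
        ys, zs = [], []
        for x, y in loader:
            zs.append(model(x).cpu().numpy())   # feature vectors
            ys.append(y.numpy())                # class labels
        return np.concatenate(ys), np.concatenate(zs)

    y_p, z_p = run(loader_p)
    y_q, z_q = run(loader_q)
    np.savez(Path(feature_dir, 'data.npz'),
             y_p=y_p, z_p=z_p, y_q=y_q, z_q=z_q)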

Questions about model selection in CelebA_Blond

Excuse me, I have some questions about model selection for CelebA_Blond.

In your paper, CelebA uses test-domain validation, which means we choose the model with the best 'env2_out_acc'. The distribution of the test environment is as follows:

                  Male       Female
blond             362         362
not blond         362         362

In the experiment, holdout_fraction is set to 0.1. However, the test environment is randomly split 9:1, and I think this may cause the distributions of the two resulting splits to be inconsistent. For example:

env2_in           Male       Female
blond             222         360
not blond         360         362
env2_out          Male       Female
blond             140         2
not blond         2           0

I'm not sure whether this makes a difference, or whether it just doesn't matter.

Looking forward to your reply, thanks.
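
As a purely illustrative aside (not repository code), a small simulation of the concern above shows how the four blond/gender cells of a perfectly balanced test environment (362 samples each) end up distributed by a random 9:1 split:

import numpy as np

rng = np.random.default_rng(0)
cells = ['blond/male', 'blond/female', 'not blond/male', 'not blond/female']
labels = np.repeat(np.arange(4), 362)               # 4 cells x 362 samples each
holdout = rng.permutation(len(labels))[: len(labels) // 10]
counts = np.bincount(labels[holdout], minlength=4)  # cell counts in the 10% split
for name, count in zip(cells, counts):
    print(f'{name:>16}: {count}')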

Questions about the results and executions

Hello! Great paper, congrats!
I'm trying to use your work in my master's research and have some questions:

  1. Could you release the results of every split for PACS, OfficeHome and DomainNet? My results are close to yours on average (1~2% difference at most), but I want to see if any of the split results have a much larger difference.

  2. I was having a problem with the dataloader: sometimes (at random) it gets stuck trying to fetch the next batch and I have to terminate the execution and start again. Have you seen this kind of problem? I found some issues saying that it may be associated with parallelism when n_workers > 0. When I set n_workers = 0, it worked, but for some datasets like DomainNet it took 9897 minutes to execute the Sketch split, which is a lot.

  3. Do you have the execution times for each of your run.sh examples? I would like to compare them with the times I'm getting.

  4. Have you tried updating torch, torchvision and other libraries to see whether the results change? With the newer version of PyTorch (1.12.0) I could execute with parallelism, while with the paper's version I was getting the random dataloader-stuck problem; the results also change a little:
    Results with torch 1.10: div 1.0367 +/- 0.0536 | cor 0.0002 +/- 0.0003
    Results with torch 1.12: div 0.9695 +/- 0.0159 | cor 0.0000 +/- 0.0000

I'm running on an NVIDIA RTX 5000 with 16 GiB, using the Docker versions of PyTorch 1.10.2 and 1.12.0.
For each version, I installed the corresponding version of torchvision following https://github.com/pytorch/vision#installation

I know it's a lot of information; I'd be really happy if we could talk, and I hope to hear from you soon.
Thank you!
