
The official source code to: Uncertainty Estimates of Predictions via a General Bias-Variance Decomposition (AISTATS'23)

License: MIT License



Uncertainty Estimates of Predictions via a General Bias-Variance Decomposition


Code for Bregman divergence and Bregman Information generated by LogSumExp

For convenience, the following PyTorch implementation can be copied and pasted as is. It differs slightly from the experiment code, because additional favorable mathematical results were found after the experiments had finished.

import torch

def BI_LSE(zs, axis=0, class_axis=-1):
    '''
    Bregman Information of random variable Z generated by G = LSE
    BI_G [ Z ] = E[ G( Z ) ] - G( E[ Z ] )
    We estimate with dataset zs = [Z_1, ..., Z_n] via
    1/n sum_i G( Z_i ) - G( 1/n sum_i Z_i )
    
    Arg zs: Tensor with shape length >= 2
    Arg axis: Axis of the samples to average over
    Arg class_axis: Axis of the class logits
    Output: Tensor with shape length reduced by two
    '''
    E_of_LSE = zs.logsumexp(axis=class_axis).mean(axis)
    LSE_of_E = zs.mean(axis).unsqueeze(axis).logsumexp(axis=class_axis).squeeze(axis)
    return E_of_LSE - LSE_of_E

def D_LSE(a, b):
    '''
    Bregman divergence generated by G = LSE
    D_G (a, b) = G(b) - G(a) - < grad G(a), b - a >
    We assume the classes are in the last axis.
    
    Arg a: Tensor with shape (batch size, classes)
    Arg b: Tensor with shape (batch size, classes)
    Output: Tensor with shape (batch size,)
    '''
    def inner_product(a, b):
        ''' Batch wise inner product of last axis in a and b'''
        n_size, n_classes = a.shape
        return torch.bmm(a.view(n_size, 1, n_classes), b.view(n_size, n_classes, 1)).squeeze(-1).squeeze(-1)

    G_of_a = a.logsumexp(axis=-1)
    G_of_b = b.logsumexp(axis=-1)
    grad_G_of_a = a.softmax(axis=-1)
    return G_of_b - G_of_a - inner_product(grad_G_of_a, b - a)

Example usage:

n_samples = 100
p_classes = 10
batch_size = 5
# sample logit tensors
zs_1 = torch.randn((n_samples, batch_size, p_classes))
zs_2 = torch.randn((batch_size, p_classes))

# Arg (default): First axis samples, middle axis batch size, last axis classes
print(BI_LSE(zs_1))

# Args (fixed): First axis batch size, last axis classes
print(D_LSE(zs_1.mean(0), zs_2))
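
The two functions are linked: with the argument order used in D_LSE, the Bregman Information equals the average divergence from the mean logits to each sample, since the inner-product term averages out to zero. The following is a minimal sketch of that sanity check, not part of the experiment code. The function bodies are compact restatements of the ones above, and the float64 dtype and the seed are assumptions chosen only to keep the comparison numerically tight.

```python
import torch


def BI_LSE(zs, axis=0, class_axis=-1):
    # Restated from above: E[LSE(Z)] - LSE(E[Z])
    E_of_LSE = zs.logsumexp(dim=class_axis).mean(axis)
    LSE_of_E = zs.mean(axis).unsqueeze(axis).logsumexp(dim=class_axis).squeeze(axis)
    return E_of_LSE - LSE_of_E


def D_LSE(a, b):
    # Restated compactly: G(b) - G(a) - <softmax(a), b - a>
    inner = (a.softmax(dim=-1) * (b - a)).sum(dim=-1)
    return b.logsumexp(dim=-1) - a.logsumexp(dim=-1) - inner


torch.manual_seed(0)  # assumption: fixed seed for reproducibility
zs = torch.randn(100, 5, 10, dtype=torch.float64)  # (samples, batch, classes)
mean_logits = zs.mean(0)

# Average divergence from the mean logits to every individual sample
avg_div = torch.stack([D_LSE(mean_logits, z) for z in zs]).mean(0)
bi = BI_LSE(zs)

# The two estimates should agree up to floating point error
print(torch.allclose(bi, avg_div))
# LSE is convex, so the Bregman Information is non-negative (Jensen)
print(bool((bi >= 0).all()))
```

Both checks should print True. If they do not after modifying either function, the argument order of D_LSE is a likely place to look first.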

Experiments

All experiments are run and plotted in Jupyter notebooks. Installing the full environment might only be necessary for the CIFAR10 and ImageNet experiments.

Environment Setup

The following commands create and activate a Python environment with all required dependencies using Miniconda:

conda env create -f environment.yml
conda activate UQ

Iris Confidence Regions (Figure 4)

The experiments for the confidence regions of the Iris classifier (Figure 4) can be found in CR_iris.ipynb. They use PyTorch and are computationally lightweight (they should run on any laptop).

Classifiers on Toy Simulations (Figure 5 & 6)

We train and evaluate scikit-learn classifiers on toy simulations in toy_simulations.ipynb. They are feasible to run locally on a laptop (they should finish in less than an hour).

ResNet on CIFAR10 and CIFAR10-C (Figure 1)

These experiments can be found in CIFAR10_ResNet.ipynb. They are expensive to evaluate and require a basic GPU (they needed >1 hour on a single RTX5000). The weight initialization ensembles are locally trained. The data is supposed to be stored in ../data/.

ResNet on ImageNet and ImageNet-C (Figure 7)

These experiments can be found in ImageNet_ResNet.ipynb. They are very expensive to evaluate and require an advanced GPU (they needed >10 hours on a single RTX5000). The weight initialization ensembles are taken from https://github.com/SamsungLabs/pytorch-ensembles. For this, download the folder deepens_imagenet from here and extract it into the folder ../saved_models/. This can be done either manually or with the following commands:

pip3 install wldhx.yadisk-direct

# create the folder if it does not exist yet
mkdir -p ../saved_models

# ImageNet
curl -L "$(yadisk-direct 'https://yadi.sk/d/rdk6ylF5mK8ptw?w=1')" -o ../saved_models/deepens_imagenet.zip
unzip ../saved_models/deepens_imagenet.zip -d ../saved_models/

The data is supposed to be stored in ../data/.
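
Before starting the long-running notebooks, it may help to confirm that the expected folders are in place. The following is a small sketch using only the standard library. The helper name missing_paths is hypothetical, and the path saved_models/deepens_imagenet is an assumption about how the archive extracts, so adjust it if the layout differs.

```python
from pathlib import Path


def missing_paths(root="..", required=("data", "saved_models/deepens_imagenet")):
    """Return the required paths (relative to root) that do not exist.

    The default entries are assumptions based on the folder names
    mentioned in this README; they are not read from the notebooks.
    """
    root = Path(root)
    return [str(root / rel) for rel in required if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All expected folders found.")
```

Run it from the repository folder, so that the relative root .. resolves to the same parent directory the notebooks use.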

Attribution

Citation

@misc{gruber2023uncertainty,
      title={Uncertainty Estimates of Predictions via a General Bias-Variance Decomposition}, 
      author={Sebastian G. Gruber and Florian Buettner},
      year={2023},
      eprint={2210.12256},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
