
cambridge-mlg / clue

35 stars · 11 watchers · 4 forks · 632 KB

Code for the paper "Getting a CLUE: A Method for Explaining Uncertainty Estimates"

License: MIT License

Languages: Python 99.89%, TeX 0.11%
Topics: interpretability, research, uncertainty, deep-learning, neural-networks, aleatoric-uncertainty, epistemic-uncertainty, machine-learning, publication, paper

clue's Introduction

Getting a CLUE: A Method for Explaining Uncertainty Estimates, ICLR 2021

[Motivation figure]

Both uncertainty estimation and interpretability are important factors for trustworthy machine learning systems. However, there is little work at the intersection of these two areas. We address this gap by proposing a novel method for interpreting uncertainty estimates from differentiable probabilistic models, like Bayesian Neural Networks (BNNs). Our method, Counterfactual Latent Uncertainty Explanations (CLUE), indicates how to change an input, while keeping it on the data manifold, such that a BNN becomes more confident about the input’s prediction. We validate CLUE through 1) a novel framework for evaluating counterfactual explanations of uncertainty, 2) a series of ablation experiments, and 3) a user study. Our experiments show that CLUE outperforms baselines and enables practitioners to better understand which input patterns are responsible for predictive uncertainty.

You can find the paper at: openreview.net/pdf?id=XSLF1XFq5h or arXiv
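Schematically, CLUE searches the latent space of an auxiliary deep generative model (a VAE or VAEAC) with decoder μ_θ for a point whose decoding stays close to the original input x₀ while receiving low predictive uncertainty from the BNN; roughly (see the paper for the exact weighting and choice of uncertainty measure):

z* = argmin_z H(y | μ_θ(z)) + d(μ_θ(z), x₀),    x_CLUE = μ_θ(z*)

where H denotes the BNN's predictive uncertainty and d is a distance penalty that keeps the counterfactual close to the original input; decoding from the latent space is what keeps it on the data manifold.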

Index:
Dependencies
Downloading Notebooks and Pre-trained Models
Training Models
Generating Counterfactuals
Ablation Experiments
Artificial Data Experiments
Human Experiment Notebook
Citation

Dependencies

python 2.7.17
torch 1.3.1
torchvision 0.4.2
urllib3 1.25.3
scikit-image 0.14.2
scikit-learn 0.20.3
scipy 1.2.1
numpy 1.16.5
matplotlib 2.2.5

Lime and Shap are only used for experiments from appendix F:
lime 0.1.1.37
shap 0.28.5
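As a convenience (not part of the repository), a minimal sketch that checks whether the pinned versions above are installed, using setuptools' pkg_resources:

```python
# Sanity-check the pinned dependency versions listed above.
# Works on Python 2.7; keys are the PyPI distribution names.
import pkg_resources

expected = {
    "torch": "1.3.1", "torchvision": "0.4.2", "urllib3": "1.25.3",
    "scikit-image": "0.14.2", "scikit-learn": "0.20.3",
    "scipy": "1.2.1", "numpy": "1.16.5", "matplotlib": "2.2.5",
}

for name, want in expected.items():
    try:
        have = pkg_resources.get_distribution(name).version
        status = "OK" if have == want else "have %s, want %s" % (have, want)
    except pkg_resources.DistributionNotFound:
        status = "MISSING"
    print("%-13s %s" % (name, status))
```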

Downloading Notebooks and Pre-trained Models

Due to their large size, the pre-trained models and the notebooks we use to run all experiments are available at: https://drive.google.com/file/d/1wZqEUn0TylpSEpKOTRD4kIVOAa3iNb0T/view?usp=sharing
The zip file (4.6 GB) extracts to a folder called notebooks, which contains the notebooks described in the rest of this README as well as the datasets we use and our pre-trained models.

We recommend replicating our experiments using the provided pre-trained models. These should be stored in the notebooks/saves folder. They will be loaded automatically by the experiment notebooks.
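For example, a short sketch for unpacking the archive and checking the layout the notebooks expect (the local filename clue_notebooks.zip is hypothetical; use whatever name your download has):

```python
# Unpack the downloaded archive and verify the expected folders exist.
import os
import zipfile

archive = "clue_notebooks.zip"  # hypothetical name for the downloaded zip
with zipfile.ZipFile(archive) as zf:
    zf.extractall(".")  # creates the notebooks/ folder described above

for path in ("notebooks", os.path.join("notebooks", "saves")):
    print("%s: %s" % (path, "found" if os.path.isdir(path) else "NOT found"))
```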

Training Models

The BNN folder contains the source code for BNNs trained with scale-adapted SG-HMC (for other BNN inference methods, please refer to JavierAntoran/Bayesian-Neural-Networks). The NN folder contains the source for training regular NNs, used in appendix H1. The VAE and VAEAC folders contain the source for training VAEs and VAEACs, respectively.

We use the notebooks contained in notebooks/train_models to train all models. Running the notebooks as provided will train and save models in the notebooks/saves directory. Models will be saved with the correct names to be loaded directly by experiment notebooks.

We also include scripts to train VAEs of multiple depths for our experiments from section 5.3. These are found in train_scripts/train_VAEs.

Experiments

The interpret folder contains the code for CLUE, U-FIDO, our computational evaluation framework and auxiliary functions.

Generating Counterfactuals

MNIST CLUEs can be generated using notebooks/experiments/CLUE_pythonclass_testing_MNIST.ipynb
U-FIDO counterfactuals on MNIST can be generated using notebooks/experiments/MNIST_FIDO.ipynb
COMPAS counterfactuals can be generated using notebooks/experiments/CLUE_pythonclass_wiewing_Tabular_COMPAS.ipynb
LSAT counterfactuals can be generated using notebooks/experiments/CLUE_pythonclass_wiewing_Tabular_LSAT.ipynb

CLUEs on modified MNIST (479)

To generate CLUEs for our modified MNIST dataset, used in the additional human subject experiments from appendix J, use notebooks/human_experiments/human_MNIST_479.ipynb.

Ablation Experiments

Tabular ablation experiments are contained in notebooks/experiments/art_data/ablation
MNIST ablation experiments are contained in notebooks/experiments/art_data/MNIST_ablation

Artificial Data Experiments

notebooks/experiments/art_data/Artificial_data_Tabular.ipynb contains the code to reproduce our computational evaluation framework results for LSAT, COMPAS, Credit, and Wine.

notebooks/experiments/art_data/Artificial_data_MNIST.ipynb contains the code to reproduce our computational evaluation framework results for MNIST.

Human Experiment Notebook

The notebook used to generate the questions for the user studies is found at notebooks/human_experiments/Tabular_Data_Main_Survey.ipynb

Citation

If this code was useful, please cite

Javier Antorán, Umang Bhatt, Tameem Adel, Adrian Weller & José Miguel Hernández-Lobato (2021). Getting a CLUE: A Method for Explaining Uncertainty Estimates. [bibtex]

@inproceedings{
  antoran2021getting,
  title={Getting a {CLUE}: A Method for Explaining Uncertainty Estimates},
  author={Javier Antoran and Umang Bhatt and Tameem Adel and Adrian Weller and Jos{\'e} Miguel Hern{\'a}ndez-Lobato},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=XSLF1XFq5h}
}

clue's People

Contributors

javierantoran


clue's Issues

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [200, 2]], which is output 0 of AsStridedBackward0, is at version 1401; expected version 1400 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Hi!
I'm currently trying to implement CLUE in Python 3. I ran into this error when using CLUE on the LSAT example (CLUE_pythonclass_wiewing_Tabular_LSAT.ipynb) you provided.

It seems that the backward pass in clue.py changes the value of a tensor in place, which results in the error below.

Did you encounter a similar problem?

Here is the full Traceback:

/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:266: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 2235, in <module>
    main()
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 2217, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1527, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1534, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/lukasscholz/repositorys/counterfactuals/examples/test.py", line 210, in <module>
    main()
  File "/Users/lukasscholz/repositorys/counterfactuals/examples/test.py", line 189, in main
    z_vec, x_vec, uncertainty_vec, epistemic_vec, aleatoric_vec, cost_vec, dist_vec = CLUE_explainer.optimise(
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/methods/clue.py", line 191, in optimise
    total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, x, preds = self.uncertainty_from_z()
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/methods/clue.py", line 124, in uncertainty_from_z
    mu_vec, std_vec = self.BNN.sample_predict(to_BNN, num_samples=0, grad=True)
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/utils/clue/bnn/gaussian_bnn.py", line 134, in sample_predict
    mu, std = self.model(x)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/utils/clue/gaussian_mlp.py", line 27, in forward
    x = self.block(x)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:118.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1534, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/lukasscholz/repositorys/counterfactuals/examples/test.py", line 210, in <module>
    main()
  File "/Users/lukasscholz/repositorys/counterfactuals/examples/test.py", line 189, in main
    z_vec, x_vec, uncertainty_vec, epistemic_vec, aleatoric_vec, cost_vec, dist_vec = CLUE_explainer.optimise(
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/methods/clue.py", line 194, in optimise
    objective.sum(dim=0).backward()  # backpropagate
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [200, 2]], which is output 0 of AsStridedBackward0, is at version 1401; expected version 1400 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
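For context, this class of error comes from autograd's version-counter check: a tensor that was saved for the backward pass was modified in place before backward() ran. A minimal, generic reproduction (not this repository's code):

```python
import torch

a = torch.randn(5, requires_grad=True)
b = torch.exp(a)        # exp saves its output b for use in the backward pass
b += 1                  # in-place edit bumps b's version counter
b.sum().backward()      # RuntimeError: ... modified by an inplace operation
```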

Algorithm 1 CLUE updating z incrementally

I was wondering how the latent variable z in Algorithm 1 of the paper is updated. In line 6 of Algorithm 1, I understand that you update z and then use the updated z in line 3. However, the code here just takes the initial z value at each iteration through the function CLUE.uncertainty_from_z():

def uncertainty_from_z(self):
Can you provide a small clarification on this please?
Thanks in advance :)
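For reference, a minimal sketch of the kind of gradient-based latent update Algorithm 1 describes; decoder and uncertainty_fn are hypothetical stand-ins, not this repository's API:

```python
import torch

def clue_search(z0, decoder, uncertainty_fn, x0, steps=200, lr=0.1, lam=1.0):
    # z is the optimised variable; it is re-decoded at every step, so the
    # uncertainty term is evaluated at the current z, not the initial one.
    z = z0.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([z], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        x = decoder(z)                 # decode the current z (cf. line 3)
        loss = uncertainty_fn(x).sum() + lam * (x - x0).abs().sum()
        loss.backward()                # gradients reach z through the decoder
        opt.step()                     # gradient step on z (cf. line 6)
    return decoder(z).detach(), z.detach()
```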

Licence

Hey there,

thanks for publishing this work!

Could you please provide a licence (e.g. the MIT Licence)? Otherwise, we cannot use this work for further research.

Thank you very much in advance.

Pascal
