This project forked from timrudner/s-fsvi


Code for the paper 'Continual Learning via Sequential Function-Space Variational Inference'

Home Page: https://timrudner.com/sfsvi

License: MIT License



Continual Learning via Sequential Function-Space Variational Inference (S-FSVI)

This repository contains the official implementation for

Continual Learning via Sequential Function-Space Variational Inference; Tim G. J. Rudner, Freddie Bickford Smith, Qixuan Feng, Yee Whye Teh, Yarin Gal. ICML 2022.

Abstract: Sequential Bayesian inference over predictive functions is a natural framework for continual learning from streams of data. However, applying it to neural networks has proved challenging in practice. Addressing the drawbacks of existing techniques, we propose an optimization objective derived by formulating continual learning as sequential function-space variational inference. In contrast to existing methods that regularize neural network parameters directly, this objective allows parameters to vary widely during training, enabling better adaptation to new tasks. Compared to objectives that directly regularize neural network predictions, the proposed objective allows for more flexible variational distributions and more effective regularization. We demonstrate that, across a range of task sequences, neural networks trained via sequential function-space variational inference achieve better predictive accuracy than networks trained with related methods while depending less on maintaining a set of representative points from previous tasks.


In particular, this codebase includes:

  • An implementation of the sequential function-space variational objective [1];
  • Notebooks that reproduce the results in the paper;
  • A general, easy-to-extend continual learning training and evaluation protocol;
  • A set of framework-agnostic dataloader methods for widely used continual learning tasks.

[1] The implementation is based on the approximation proposed in Tractable Function-Space Variational Inference in Bayesian Neural Networks (Rudner et al., 2022).
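As a rough illustration of the kind of objective described above, the sketch below combines an expected log-likelihood term with a KL penalty between the current and previous distributions over function values at a set of input points. It assumes diagonal Gaussians over those function values, which is a simplification made here for readability; the names `gaussian_kl_diag` and `function_space_objective` are hypothetical and do not come from this codebase.

```python
import math


def gaussian_kl_diag(mean_q, var_q, mean_p, var_p):
    """KL(q || p) between two diagonal Gaussians given as equal-length lists."""
    kl = 0.0
    for mq, vq, mp, vp in zip(mean_q, var_q, mean_p, var_p):
        kl += 0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
    return kl


def function_space_objective(expected_log_lik, mean_q, var_q, mean_prev, var_prev):
    """Expected log-likelihood minus a function-space KL penalty (to be maximized).

    mean_q, var_q: current distribution over function values at some inputs.
    mean_prev, var_prev: distribution from the previous task at the same inputs.
    """
    return expected_log_lik - gaussian_kl_diag(mean_q, var_q, mean_prev, var_prev)


# The penalty vanishes when the function distribution is unchanged and grows
# as the new predictions move away from the previous ones.
unchanged = function_space_objective(-1.0, [0.0], [1.0], [0.0], [1.0])
shifted = function_space_objective(-1.0, [1.0], [1.0], [0.0], [1.0])
```

Because the penalty acts on function values rather than on parameters, two parameter settings that produce the same predictions at the chosen inputs incur the same penalty.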



Figure 1. Schematic of sequential function-space variational inference.

Installation

To install requirements:

$ conda env update -f environment.yml
$ conda activate fsvi

This environment includes all necessary dependencies.

To create an fsvi executable for running experiments, run:

$ pip install -e .

Reproducing results

Split MNIST, Permuted MNIST, and Split FashionMNIST

| Method | Split MNIST (MH) | Split FashionMNIST (MH) | Permuted MNIST (SH) | Split MNIST (SH) |
|---|---|---|---|---|
| S-FSVI (ours) | 99.54% ± 0.04 | 99.05% ± 0.03 | 95.76% ± 0.02 | 92.87% ± 0.14 |
| S-FSVI (larger networks) | 99.76% ± 0.00 | 98.50% ± 0.11 | 97.50% ± 0.01 | 93.38% ± 0.10 |
| S-FSVI (no coreset) | 99.62% ± 0.01 | 99.17% ± 0.06 | 84.06% ± 0.46 | 20.15% ± 0.52 |
| S-FSVI (minimal coreset [2]) | NA [3] | NA [3] | 89.59% ± 0.30 | 51.44% ± 1.22 |

[2] "Minimal coresets" are constructed by randomly selecting one data point per class for a given task.

[3] Since S-FSVI already performs well without a coreset, the minimal coreset option is not useful.
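The minimal coreset construction in footnote [2] can be sketched as follows. This is an illustration only: it assumes the labels of one task are available as a plain Python list, and the function name `minimal_coreset_indices` and its `seed` parameter are hypothetical rather than part of this codebase.

```python
import random
from collections import defaultdict


def minimal_coreset_indices(labels, seed=0):
    """Pick one random data-point index per class for a single task."""
    rng = random.Random(seed)
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)
    # One uniformly sampled index from each class present in the task.
    return sorted(rng.choice(indices) for indices in indices_by_class.values())


# Example: a two-class task (digits 0 and 1) with six data points.
task_labels = [0, 1, 0, 1, 1, 0]
coreset = minimal_coreset_indices(task_labels)
```

The returned list has one entry per class, so a two-class task contributes exactly two points to the coreset.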

Split CIFAR

| Method | Split CIFAR (MH) |
|---|---|
| S-FSVI [4] | 77.57% ± 0.84 |

Sequential Omniglot

| Method | Sequential Omniglot (MH) |
|---|---|
| S-FSVI [4] | 83.29% ± 1.2 |

[4] To speed up training and reduce the memory requirements, only the variance parameters in the final layer of the network are learned variationally and the linearization is computed on the final layer only.
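To give some intuition for why restricting the variational treatment to the final layer is cheap, the sketch below computes the mean and variance of the network outputs when only the final-layer weights are random. It assumes a bias-free linear final layer with independent Gaussian weights, in which case the outputs are exactly Gaussian and no Jacobian through earlier layers is needed. The function name and argument layout are hypothetical, and this is not a description of how the repository implements the approximation.

```python
def final_layer_predictive(features, weight_mean, weight_var):
    """Mean and variance of f_k = sum_j W[k][j] * features[j].

    features: penultimate-layer activations for one input (list of floats).
    weight_mean, weight_var: per-output rows of means and variances for
        independent Gaussian final-layer weights.
    """
    means, variances = [], []
    for mean_row, var_row in zip(weight_mean, weight_var):
        means.append(sum(m * h for m, h in zip(mean_row, features)))
        # Independent weights: variances add, each scaled by the squared feature.
        variances.append(sum(v * h * h for v, h in zip(var_row, features)))
    return means, variances


# Example with two features and a single output unit.
out_mean, out_var = final_layer_predictive(
    features=[1.0, 2.0],
    weight_mean=[[1.0, 0.5]],
    weight_var=[[0.1, 0.2]],
)
```

The cost is linear in the number of final-layer weights, whereas linearizing with respect to all parameters requires Jacobians through every layer.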

2D Visualization

This notebook demonstrates continual learning via S-FSVI on a sequence of five binary-classification tasks in a 2D input space.

Figure 2. Predictive distributions of a model trained via S-FSVI on tasks 1-5.

Adding new methods or tasks

  • To implement a new method, create a file method_cl_methodname.py in /benchmarking. For reference, see /benchmarking/method_cl_template.py and /benchmarking/method_cl_fsvi.py.
  • To implement a new dataloader, add a new method to benchmarking/data_loaders.
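As a starting point for a new dataloader, the sketch below shows one framework-agnostic way to turn a labelled dataset into a sequence of class-split tasks, in the style of Split MNIST. The function name `split_into_tasks`, its arguments, and the dictionary layout it returns are hypothetical; they are not the interface used in benchmarking/data_loaders, so adapt them to the existing methods there.

```python
def split_into_tasks(inputs, labels, classes_per_task=2):
    """Group a dataset into consecutive tasks of `classes_per_task` classes each."""
    classes = sorted(set(labels))
    tasks = []
    for start in range(0, len(classes), classes_per_task):
        task_classes = classes[start:start + classes_per_task]
        # Keep only the examples whose label belongs to this task.
        pairs = [(x, y) for x, y in zip(inputs, labels) if y in task_classes]
        tasks.append({
            "classes": task_classes,
            "inputs": [x for x, _ in pairs],
            "labels": [y for _, y in pairs],
        })
    return tasks


# Example: four classes become two tasks of two classes each.
toy_inputs = ["a", "b", "c", "d", "e", "f"]
toy_labels = [0, 1, 2, 3, 0, 2]
toy_tasks = split_into_tasks(toy_inputs, toy_labels)
```

Plain lists are used here so the sketch does not depend on any array or deep learning library.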

Citation

@InProceedings{rudner2022continual,
      author = {Tim G. J. Rudner and Freddie Bickford Smith and Qixuan Feng and Yee Whye Teh and Yarin Gal},
      title = {{C}ontinual {L}earning via {S}equential {F}unction-{S}pace {V}ariational {I}nference},
      booktitle = {Proceedings of the 39th International Conference on Machine Learning},
      year = {2022},
      series = {Proceedings of Machine Learning Research},
      publisher = {PMLR},
}

Please cite our paper if you use this code in your own work.
