HyperTorch

Lightweight, flexible, research-oriented package to compute hypergradients in PyTorch.

What is a hypergradient?

Given the following bi-level problem

$$\min_{\lambda} f(\lambda) = E(w(\lambda), \lambda) \quad \text{s.t.} \quad w(\lambda) = \Phi(w(\lambda), \lambda),$$

we call hypergradient the following quantity:

$$\nabla f(\lambda) = \nabla_\lambda E(w(\lambda), \lambda) + \partial_\lambda \Phi(w(\lambda), \lambda)^\top \left(I - \partial_w \Phi(w(\lambda), \lambda)^\top\right)^{-1} \nabla_w E(w(\lambda), \lambda).$$

Where:

  • E is called the outer objective (e.g. the validation loss);
  • Φ is called the fixed point map (e.g. a gradient descent step or the state update function in a recurrent model);
  • finding the solution of the fixed point equation w = Φ(w, λ) is referred to as the inner problem. This can be solved by repeatedly applying the fixed point map or using a different inner algorithm (a concrete instance is given below).
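
For instance (an illustrative special case, not the only possible choice of Φ), taking Φ to be one gradient descent step with step size α on an inner training loss L_train makes the fixed points of Φ exactly the stationary points of L_train:

$$\Phi(w, \lambda) = w - \alpha \nabla_w L_{\mathrm{train}}(w, \lambda), \qquad w = \Phi(w, \lambda) \iff \nabla_w L_{\mathrm{train}}(w, \lambda) = 0.$$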

Quickstart

hyperparameter optimization

See this notebook, where we show how to compute the hypergradient to optimize the regularization parameters of a simple logistic regression model.

meta-learning

examples/iMAML.py shows an implementation of the method described in the paper Meta-learning with implicit gradients. The code uses higher to get stateless versions of torch nn.Module-s and torchmeta for meta-dataset loading and minibatching.

This notebook shows how to train a simple equilibrium network with "RNN-style" dynamics.

MORE EXAMPLES COMING SOON

Use cases

Hypergradients are useful to perform

  • gradient-based hyperparameter optimization
  • meta-learning
  • training models that use an internal state (some types of RNNs and GNNs, Deep Equilibrium Networks, ...)

Install

Requires Python 3 and PyTorch >= 1.4.

git clone git@github.com:prolearner/hypergrad.git
cd hypergrad
pip install .

python setup.py install would also work.

Implemented methods

The main methods for computing hypergradients are in the module hypergrad/hypergradients.py.

All methods require as input:

  • a list of tensors representing the inner variables (models' weights);
  • another list of tensors for the outer variables (hyperparameters/meta-learner parameters);
  • a callable differentiable outer objective;
  • a callable that represents the differentiable update mapping (not required by reverse_unroll), for example an SGD step. A sketch of these inputs is given below.
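
As an illustration, here is a minimal sketch of these inputs for L2-regularized linear regression (the data, names, and step size are hypothetical, not part of the package):

import torch

# Toy data (purely illustrative).
X_tr, y_tr = torch.randn(50, 5), torch.randn(50)
X_val, y_val = torch.randn(20, 5), torch.randn(20)

# Inner variables (model weights) and outer variables (hyperparameters).
params = [torch.zeros(5, requires_grad=True)]
hparams = [torch.tensor(0.1, requires_grad=True)]

# Differentiable update mapping: one SGD step on an L2-regularized training loss.
def fp_map(params, hparams):
    w, lam = params[0], hparams[0]
    loss = ((X_tr @ w - y_tr) ** 2).mean() + lam * (w ** 2).sum()
    g = torch.autograd.grad(loss, w, create_graph=True)[0]
    return [w - 0.1 * g]

# Differentiable outer objective: the validation loss at the inner solution.
def outer_objective(params, hparams):
    w = params[0]
    return ((X_val @ w - y_val) ** 2).mean()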

Iterative differentiation methods:

These methods differentiate through the update dynamics used to solve the inner problem. This makes it possible to optimize parameters of the inner solver, such as the learning rate and momentum.

Methods in this class are:

  • reverse_unroll: computes the approximate hypergradient by unrolling the entire computational graph of the update dynamics used to solve the inner problem. The method is essentially a wrapper around standard backpropagation. IMPORTANT NOTE: the weights must be non-leaf tensors obtained through the application of "PyTorch differentiable" update dynamics (do not use built-in optimizers!); see the sketch after this list. NOTE 2: this method is memory hungry!
  • reverse: computes the hypergradient as above but uses less memory. It uses the trajectory information and recomputes all other necessary intermediate variables in the backward pass. It requires the list of past weights and the list of callable update mappings applied during the inner optimization.
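
To make the non-leaf requirement of reverse_unroll concrete, a hand-rolled differentiable inner loop might look as follows (a sketch reusing the toy setup above; note that the update is written out explicitly instead of calling a built-in optimizer):

w = torch.zeros(5, requires_grad=True)   # w_0 is the only leaf weight tensor
for _ in range(10):                      # unrolled, differentiable inner dynamics
    loss = ((X_tr @ w - y_tr) ** 2).mean() + hparams[0] * (w ** 2).sum()
    g = torch.autograd.grad(loss, w, create_graph=True)[0]
    w = w - 0.1 * g                      # w is now a non-leaf tensor

# Backpropagating the outer objective through the whole unrolled graph yields
# the approximate hypergradient in hparams[0].grad; keeping the entire
# trajectory alive is what makes this approach memory hungry.
outer_objective([w], hparams).backward()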

Approximate Implicit Differentiation methods:

These methods approximate the hypergradient equation directly by:

  • Using an approximate solution to the inner problem instead of the true one.
  • Computing an approximate solution to the linear system (I-J)x* = b, where J and b are respectively the transpose of the Jacobian of the fixed point map and the gradient of the outer objective, both taken w.r.t. the inner variable and evaluated at the approximate solution of the inner problem.

Since computing and storing J is usually infeasible, these methods exploit torch.autograd to compute the Jacobian-vector product Jx efficiently. Additionally, they do not require storing the trajectory of the inner solver, which can give them a large memory advantage over iterative differentiation. On the other hand, they are not suited to optimizing parameters of the inner solver such as the learning rate.
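
For illustration, a single product Jx can be obtained with torch.autograd.grad without ever materializing J (a self-contained toy sketch; the map Φ below is made up):

import torch

w = torch.randn(5, requires_grad=True)
lam = torch.tensor(0.1, requires_grad=True)

phi = w - 0.1 * (2 * w + lam)   # toy differentiable fixed point map Phi(w, lam)
x = torch.randn(5)              # vector to multiply by

# autograd computes the vector-Jacobian product x^T (dPhi/dw), which equals
# (dPhi/dw)^T x = Jx, because J is defined as the transposed Jacobian.
Jx = torch.autograd.grad(phi, w, grad_outputs=x, retain_graph=True)[0]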

Methods in this class are:

  • fixed_point: approximately solves the linear system by repeatedly applying the map T(x) = Jx + b. NOTE: this method converges only when the fixed point map, and consequently the map T, is a contraction.
  • CG: approximately solves the linear system with the conjugate gradient method. IMPORTANT NOTE: I-J must be symmetric and positive definite for this to work!
  • CG_normal_eq: as above, but applies conjugate gradient to the normal equations (i.e. it solves A^TAx = A^Tb with A = I-J), which also works when I-J is not symmetric and positive definite. NOTE: the cost per iteration can be much higher than that of the other methods. A usage sketch follows this list.
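
Putting things together with the toy sketch above, a hedged usage example (the keyword names below are assumptions based on the required inputs listed earlier; check hypergrad/hypergradients.py for the exact signatures):

import hypergrad as hg

# Approximately solve the inner problem by fixed-point iteration, detaching at
# every step so that no trajectory needs to be stored.
for _ in range(100):
    params = [p.detach().requires_grad_(True) for p in fp_map(params, hparams)]

# Approximate the hypergradient with K applications of the linear-system solver;
# it is accumulated in hparams[0].grad, ready for any torch.optim optimizer.
hg.fixed_point(params, hparams, K=20, fp_map=fp_map, outer_loss=outer_objective)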

Cite

If you use this code, please cite our paper:

@inproceedings{grazzi2020iteration,
  title={On the Iteration Complexity of Hypergradient Computation},
  author={Grazzi, Riccardo and Franceschi, Luca and Pontil, Massimiliano and Salzo, Saverio},
  booktitle={Thirty-seventh International Conference on Machine Learning (ICML)},
  year={2020}
}
