f-dangel / curvlinops

scipy linear operators for the Hessian, Fisher/GGN, and more in PyTorch

Home Page: https://curvlinops.readthedocs.io/en/latest/

License: MIT License

Languages: Python 99.25%, Makefile 0.75%
Topics: fisher, hessian, linalg, linear-operators, pytorch, scipy, ggn, linearoperator

curvlinops's Introduction

scipy linear operators of deep learning matrices in PyTorch

Requires Python 3.8+.

This library implements scipy.sparse.linalg.LinearOperators for deep learning matrices, such as

  • the Hessian
  • the Fisher/generalized Gauss-Newton (GGN)
  • the Monte-Carlo approximated Fisher
  • the Fisher/GGN's KFAC approximation (Kronecker-Factored Approximate Curvature)
  • the uncentered gradient covariance (aka empirical Fisher)
  • the output-parameter Jacobian of a neural net and its transpose

Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU. The library supports defining these matrices not only on a mini-batch, but also on entire data sets (looping over batches during a matvec operation).

You can plug these linear operators into scipy routines while the heavy lifting (matrix-vector multiplies) is carried out in PyTorch on a GPU. My favorite example of such a routine is scipy.sparse.linalg.eigsh, which lets you compute a subset of eigenpairs; a minimal sketch follows.
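A self-contained toy sketch (model, data, and shapes are purely illustrative; the constructor pattern follows the documentation):

import torch
from scipy.sparse.linalg import eigsh

from curvlinops import HessianLinearOperator

# Toy regression setup; names and shapes are purely illustrative.
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 1))
loss_func = torch.nn.MSELoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch.randn(8, 10), torch.randn(8, 1))]  # iterable of (X, y) batches

H = HessianLinearOperator(model, loss_func, params, data)

# scipy drives the iteration; each matvec runs in PyTorch (on GPU if the model is there).
top_eigenvalues, _ = eigsh(H, k=3, which="LM")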

The library also provides linear operator transformations, like taking the inverse (inverse matrix-vector product via conjugate gradients) or slicing out sub-matrices.
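For instance, reusing H from the sketch above, an inverse-matrix-vector product can be computed with scipy's conjugate gradient solver (assuming the operator is positive definite, e.g. a GGN rather than an indefinite Hessian):

import numpy as np
from scipy.sparse.linalg import cg

v = np.random.rand(H.shape[1])
x, info = cg(H, v)  # solves H x = v; info == 0 signals convergence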

Finally, it offers functionality to probe properties of the represented matrices, like their spectral density, trace, or diagonal.

Installation

pip install curvlinops-for-pytorch

Examples

Basic and advanced usage examples are available in the documentation.

Future ideas

Other features that could be supported in the future are tracked as feature requests in the issues below.

Logo image credits

curvlinops's People

Contributors

f-dangel, ltatzel, runame, wiseodd


curvlinops's Issues

Verify that the order of the mini-batches is deterministic where it matters

Some linear operators, like FisherMCLinearOperator, rely on deterministic mini-batches, i.e. on shuffle=False for the data loader in use. Otherwise, the following error is raised, which does not point to the cause of the problem:

RuntimeError: Check for deterministic matvec failed.

To avoid this, we could check that shuffle=False upon construction.
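A minimal sketch of such a check (helper name hypothetical):

from torch.utils.data import DataLoader, RandomSampler

def check_deterministic_batches(loader: DataLoader):
    # DataLoader(..., shuffle=True) internally uses a RandomSampler.
    if isinstance(loader.sampler, RandomSampler):
        raise ValueError(
            "Data loader shuffles, so mini-batch order is non-deterministic."
            " Construct it with shuffle=False."
        )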

Improve efficiency of MC Fisher

Currently, we loop over the data points and MC samples, which is optimal in terms of FLOPs (batch_size * mc_samples backpropagations) but parallelizes poorly. We could use functorch's vmap, but the library currently relies mostly on autograd.grad, and I would like to keep it that way because of certain benefits (e.g. recycling the computation graph when multiplying onto multiple vectors, block-diagonal approximations).

Another approach I propose is to phrase multiplication with the MC Fisher as a GGN-vector product, with a loss function whose Hessian corresponds to the summed outer product of sampled gradients (see the sketch below). This uses more FLOPs (batch_size * C backpropagations) but does not require a loop over batch_size and should therefore often yield better run time. One downside of this approach is that it costs as much as multiplying with the exact GGN, which weakens the motivation for an MC-sampled version (less accurate, same cost).
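For intuition, a hypothetical sketch of such a pseudo-loss: its Hessian w.r.t. the model output f is the averaged outer product of the sampled output gradients, so a GGN-vector product with this loss equals an MC-Fisher-vector product.

import torch

def mc_fisher_pseudo_loss(f: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # f: model outputs, shape (batch_size, C)
    # s: sampled 'would-be' output gradients, shape (mc_samples, batch_size, C)
    # The Hessian of this loss w.r.t. f_n is 1/M * sum_m s_mn s_mn^T.
    delta = f - f.detach()  # zero-valued, but differentiable w.r.t. f
    inner = torch.einsum("mnc,nc->mn", s, delta)
    return 0.5 * (inner**2).sum() / s.shape[0]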

Linear operator for KFAC

Implement a linear operator that multiplies with the Kronecker-factored curvature approximation of the Fisher (a hypothetical usage sketch follows the list).

  • The first version will only support networks with parameters in Linear and Conv2d layers
  • There should be an option to treat .weight and .bias jointly or separately.
  • There could be an option for KFAC-expand and KFAC-reduce
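A usage sketch, assuming the new operator follows the constructor pattern of the existing ones (name and signature as proposed, not an existing API):

import numpy as np
from curvlinops import KFACLinearOperator

# Reuses model, loss_func, params, data from the Hessian sketch in the introduction.
kfac = KFACLinearOperator(model, loss_func, params, data)
v = np.random.rand(kfac.shape[1])
Kv = kfac @ v  # multiply the KFAC approximation onto v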

Make less/no assumption about the data

Currently, curvlinops assumes that X and y are both torch.Tensor. This works for the classic deep learning setup, but with the rise of LLMs and other complicated models, one should not make that assumption.

For instance, the input of a Huggingface model is a UserDict:

from collections import UserDict

data = UserDict({
    'input_ids': torch.LongTensor(...),
    'attention_mask': torch.LongTensor(...),
    'labels': torch.LongTensor(...),
})

In this case, one can extract X and y via (example from laplace-torch)

if isinstance(data, (UserDict, dict)):  # to support Huggingface datasets
    X, y = data, data['labels'].to(self._device)

However, curvlinops strongly assumes X to be a tensor in several places.

I think the best way to circumvent this is to not assume anything about X (preferably also about y, e.g. for multi-output models); one possible direction is sketched below.
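One option (a hypothetical hook, not an existing curvlinops argument): let the user pass a function that splits a batch into (X, y), so the library never inspects the structure of X itself.

def default_batch_fn(batch):
    X, y = batch  # the classic (inputs, targets) tuple
    return X, y

def huggingface_batch_fn(batch):
    # batch is a dict/UserDict with 'input_ids', 'attention_mask', 'labels'
    return batch, batch["labels"]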

Add `state_dict` functionality to `KFACLinearOperator` and `KFACInverseLinearOperator`

Since KFACLinearOperator and KFACInverseLinearOperator both have state that is potentially expensive to compute, it is often convenient to store the (inverted) Kronecker factors to disk to save computation. To implement this, it makes sense to add a state_dict method to them, together with load_state_dict and a classmethod from_state_dict, e.g. as sketched below.
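A sketch of the proposed interface (method names as proposed above; exact signatures are assumptions):

import torch

from curvlinops import KFACLinearOperator

kfac = KFACLinearOperator(model, loss_func, params, data)
torch.save(kfac.state_dict(), "kfac_state.pt")  # persist the Kronecker factors

# Later: restore without recomputing.
state = torch.load("kfac_state.pt")
kfac = KFACLinearOperator.from_state_dict(state, model, loss_func, params, data)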

Feature request: add method to convert to matrix

To convert a linear operator to an explicit matrix, we need the following one-liner:

matrix = operator @ np.eye(operator.shape[-1])

SciPy also has the aslinearoperator function, which converts a matrix to an operator.

What do you think about adding a method that does the converse, i.e.

matrix = operator.asmatrix()

It would basically involve the one-liner above. Pros, cons?

Allow arbitrarily-ordered parameters in KFAC

At the moment, the params supplied to KFAC must be in the same order as the NN's parameters, i.e.

model = Sequential(Linear(...), Linear(...))

# supported
params_allowed = [model[0].weight, model[0].bias, model[1].weight, model[1].bias]

# not supported
params_forbidden = [model[1].bias, model[0].bias, model[0].weight, model[1].weight]

While parameters will often be supplied in the correct order, dealing with arbitrary orders should be supported; one way is to sort them back into canonical order, as sketched below.
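One possible fix, sketched with a hypothetical helper: recover the model's canonical ordering via the parameters' data pointers.

def canonical_order(model, params):
    # Position of each parameter in the model's own ordering.
    position = {p.data_ptr(): i for i, p in enumerate(model.parameters())}
    return sorted(params, key=lambda p: position[p.data_ptr()])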

Feature request: Diagonal estimation algorithms

Similar to Hutchinson's method for trace estimation, one can approximate the diagonal of a matrix from projections onto random vectors, see for instance Equation 16 or Equation 9.

There are no implementations of Hessian diagonal estimation in scipy, so it would be nice to offer such methods through this library's LinearOperator interface; a minimal sketch follows.
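A minimal sketch of the basic estimator (function name hypothetical): for Rademacher probes v, the elementwise product v * (A v) is an unbiased estimate of diag(A).

import numpy as np
from scipy.sparse.linalg import LinearOperator

def hutchinson_diagonal(A: LinearOperator, num_samples: int = 100, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    dim = A.shape[1]
    estimate = np.zeros(dim)
    for _ in range(num_samples):
        v = rng.choice([-1.0, 1.0], size=dim)  # Rademacher probe
        estimate += v * (A @ v)
    return estimate / num_samples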

[BUG | KFAC] Changing device invalidates parameter mapping

Whenever KFAC is instantiated on GPU and we call .to_device(device("cpu")), this invalidates the internal mapping from parameter .data_ptr()s to module names, which is needed to identify a parameter's position in the list format. Currently, this bug can pass silently because matmat does not check whether all parameter positions have been processed.

[BUG] Scaling of (KFAC) empirical/MC Fisher broken in some cases when using mean reduction for loss

The reduction_factor is assumed to be just the dataset size in the implementation of the empirical Fisher. This is incorrect for MSELoss, BCEWithLogitsLoss, and in some cases also for CrossEntropyLoss (when the model output has more than two dimensions). KFAC with fisher_type="empirical" also requires a similar change in the scaling.

The same applies to FisherMCLinearOperator, but not to KFAC with fisher_type="mc".

Add option for heuristic and exact damping to `KFACInverseLinearOperator`

There are two different ways to set the damping for the KFACInverseLinearOperator that we should implement:

  1. 'Heuristic' damping, introduced in Section 6.3 of the original K-FAC paper.
  2. 'Exact' damping, which can be implemented efficiently for Kronecker-factored matrices, e.g. see Equation (21) in Grosse et al., 2023 (a sketch follows below).

There is also an 'adaptive' damping scheme, as described in the original K-FAC paper, but I do not plan to implement this for now.
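For illustration, a sketch of exact damping for a single Kronecker block, under the assumed vec convention (A ⊗ B) vec(V) = vec(B V Aᵀ) with symmetric factors:

import torch

def exact_damped_inverse_matvec(A, B, damping, V):
    # Apply (A ⊗ B + damping * I)^{-1} to vec(V) via eigendecompositions.
    lam_A, Q_A = torch.linalg.eigh(A)  # A: (d_in, d_in)
    lam_B, Q_B = torch.linalg.eigh(B)  # B: (d_out, d_out)
    denom = torch.outer(lam_B, lam_A) + damping  # damped eigenvalues of A ⊗ B
    return Q_B @ ((Q_B.T @ V @ Q_A) / denom) @ Q_A.T  # V: (d_out, d_in)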

Support `BCEWithLogitsLoss`

Requested by @wiseodd for laplace-torch.

This consists of three parts:

  • Support sqrt_hessian for KFAC (type 2)
  • Support sample_grad_output for FisherMC
  • Support draw_label for KFAC (type 1); a sketch for the BCE case follows below
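For illustration, label sampling for BCEWithLogitsLoss amounts to drawing Bernoulli samples with success probability sigmoid(logits) (function name hypothetical):

import torch

def draw_label_bce_with_logits(logits: torch.Tensor) -> torch.Tensor:
    # 'Would-be' targets for the sampled Fisher: y ~ Bernoulli(sigmoid(logits))
    return torch.bernoulli(torch.sigmoid(logits))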

Feature request: Trace estimation algorithms

The trace is often used to summarize curvature matrices in second-order methods or for generalization metrics.

I could not find libraries that provide trace estimation methods for scipy.sparse.linalg.LinearOperators. The closest library is Nico's matfree, which offers Hutchinson trace estimation for JAX. pyhessian implements Hutchinson trace estimation in PyTorch, but does not use a LinearOperator interface and only considers the Hessian.

So it would be useful to offer trace estimation through a scipy-based linear operator interface in this library.

Possible algorithms are:

  • (Basic) Hutchinson trace estimation (see Section 4; a minimal sketch follows this list)
  • (Advanced) Hutch++ (paper, MATLAB implementation)
  • (Advanced) NA-Hutch++ (paper): I decided against implementing NA-Hutch++, since it does not offer memory savings over Hutch++. According to the paper, non-adaptive methods have practical benefits when used with batch multiplies of the linear operator. However, the linear operators offered by this library only support efficient matvecs (matmats are for-loops) and hence cannot leverage this benefit. Another point against implementing and maintaining this method is that, according to the meyer2020hutch paper, NA-Hutch++ "tends to perform slightly worse in our experiments."
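A minimal sketch of basic Hutchinson trace estimation on a scipy linear operator (function name hypothetical): for Rademacher probes v, E[vᵀ A v] = trace(A).

import numpy as np
from scipy.sparse.linalg import LinearOperator

def hutchinson_trace(A: LinearOperator, num_samples: int = 100, seed: int = 0) -> float:
    rng = np.random.default_rng(seed)
    dim = A.shape[1]
    probes = rng.choice([-1.0, 1.0], size=(num_samples, dim))  # Rademacher probes
    return float(np.mean([v @ (A @ v) for v in probes]))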
