
This project forked from jaxgaussianprocesses/gpjax


A didactic Gaussian process package for researchers in Jax.

Home Page: https://gpjax.readthedocs.io/en/latest/

License: Apache License 2.0

Languages: Python 98.83%, Makefile 1.17%


GPJax


Quickstart | Install guide | Documentation

GPJax aims to provide a low-level interface to Gaussian process models. Code is written entirely in Jax to enhance readability, and structured to allow researchers to easily extend the code to suit their own needs. When defining a GP prior in GPJax, the user need only specify a mean and kernel function. A GP posterior can then be realised by computing the product of our prior with a likelihood function, mirroring Bayes' rule. The idea behind this is that the code should be as close as possible to the maths we would write on paper when working with GP models.

Supported methods and interfaces

Examples

Guides for customisation

Simple example

After importing the necessary dependencies, we'll first simulate some data.

import gpjax
from gpjax import Dataset
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(123)

# Simulate 50 sorted inputs on [-3, 3] and noisy observations of a sinusoid.
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(50,)).sort().reshape(-1, 1)
y = jnp.sin(x) + jr.normal(key, shape=x.shape) * 0.05
training = Dataset(X=x, y=y)

As can be seen, the latent function of interest here is a sinusoid. However, it has been perturbed by zero-mean Gaussian noise with a standard deviation of 0.05. We can use a Gaussian process model to try to recover this latent function.

from gpjax.kernels import RBF
from gpjax.gps import Prior

f = Prior(kernel=RBF())

In the presence of a likelihood function, which we'll here assume to be Gaussian, we can construct a posterior distribution over the latent function as the product of our prior and the likelihood; the model's hyperparameters can later be learned by optimising the marginal log-likelihood.

from gpjax.likelihoods import Gaussian

likelihood = Gaussian()
posterior = f * likelihood

Equipped with the Gaussian process posterior, we can now optimise the model's hyperparameters (note that we need not infer the latent function itself here, thanks to Gaussian conjugacy). To do this, we can either define our parameters by hand through a dictionary, or realise a set of default parameters through the initialise callable. For brevity we'll do the latter here, but see the regression notebook for a full discussion of parameter initialisation and transformation.

from gpjax.parameters import initialise, build_all_transforms
from gpjax.config import get_defaults

# Realise default parameter values, then map them onto an unconstrained space
# where gradient-based optimisation can proceed freely.
params = initialise(posterior)
configs = get_defaults()
constrainer, unconstrainer = build_all_transforms(params.keys(), configs)
params = unconstrainer(params)
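
Had we instead preferred the manual route mentioned above, a hand-written parameter dictionary could be supplied in place of the initialise call. The key names below are illustrative assumptions only; the authoritative keys are whatever initialise(posterior) returns, and the same unconstrainer transform would then be applied.

# Sketch of a hand-specified parameter dictionary. The key names are assumed
# for illustration; inspect initialise(posterior) for the keys GPJax expects.
manual_params = {
    "lengthscale": jnp.array([1.0]),
    "variance": jnp.array([1.0]),
    "obs_noise": jnp.array([1.0]),
}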

With initial values defined, we can now optimise the hyperparameters by carrying out gradient-based optimisation of the GP's marginal log-likelihood. We'll do this using one of Jax's built-in optimisers, namely Adam with a step size of 0.01, and JIT-compile the objective function to accelerate training. You'll notice that it is only now that we incorporate any data into our GP. This is desirable, as it mirrors how model building works in principle: we first build our prior model, then observe some data and use that data to form a posterior.

from gpjax.objectives import marginal_ll
from jax.experimental import optimizers

# Build the negative marginal log-likelihood objective and JIT-compile it.
mll = jax.jit(marginal_ll(posterior, transform=constrainer, negative=True))

opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
opt_state = opt_init(params)

def step(i, opt_state):
    # Differentiate the objective at the current parameters and take one Adam step.
    p = get_params(opt_state)
    g = jax.grad(mll)(p, training)
    return opt_update(i, g, opt_state)


for i in range(100):
    opt_state = step(i, opt_state)
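
If you would like to monitor progress, one option is to print the objective every few iterations. This is just a convenience sketch built from the mll and step objects defined above; it is not part of GPJax itself.

for i in range(100):
    opt_state = step(i, opt_state)
    if i % 10 == 0:
        # mll was built with negative=True, so smaller values are better.
        print(f"step {i}: negative mll = {float(mll(get_params(opt_state), training)):.3f}")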

Our parameters are now optimised. We can transform them back onto their original, constrained space and, using these learned values, query the GP at a set of test points.

from gpjax.predict import mean, variance

# Map the optimised (unconstrained) parameters back to their constrained space.
final_params = constrainer(get_params(opt_state))

xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)

predictive_mean = mean(posterior, final_params, training)(xtest)
predictive_variance = variance(posterior, final_params, training)(xtest)
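
To visualise the fit, one option is to plot the predictive mean together with an approximate 95% credible band derived from the predictive variance. The snippet below is a sketch using matplotlib rather than GPJax functionality, and it assumes the variances may come back either as marginal values or as a full covariance matrix.

import matplotlib.pyplot as plt

# Reduce a full covariance matrix to marginal variances if necessary
# (this shape handling is an assumption about the returned object).
pred_var = predictive_variance
if pred_var.ndim == 2 and pred_var.shape[0] == pred_var.shape[1]:
    pred_var = jnp.diag(pred_var)
mu = predictive_mean.squeeze()
std = jnp.sqrt(pred_var.squeeze())

plt.plot(x.squeeze(), y.squeeze(), "o", label="observations")
plt.plot(xtest.squeeze(), mu, label="predictive mean")
plt.fill_between(xtest.squeeze(), mu - 1.96 * std, mu + 1.96 * std,
                 alpha=0.3, label="approximate 95% credible interval")
plt.legend()
plt.show()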

Installation

Stable version

To install the latest stable version of gpjax, run:

pip install gpjax

Development version

To install the latest, possibly unstable, version, follow the steps below. It is by no means compulsory, but we do advise that you carry out all of the following inside a virtual environment.
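
For example, a virtual environment can be created and activated with Python's built-in venv module before running the commands below (the environment name here is arbitrary):

python -m venv gpjax-env
source gpjax-env/bin/activate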

git clone https://github.com/thomaspinder/GPJax.git
cd GPJax 
python setup.py develop

It is then recommended that you check your installation by running the supplied unit tests:

python -m pytest tests/

Note that installing the latest version of GPJax on Apple M1 devices is currently unstable.

Contributors

thomaspinder, jejjohnson
