asem000 / finitediffx

Finite difference tools in JAX

Home Page: https://finitediffx.readthedocs.io/en/latest/index.html

License: Apache License 2.0


finitediffx's Introduction

Differentiable finite difference tools in JAX

🛠️ Installation

pip install FiniteDiffX

Install development version

pip install git+https://github.com/ASEM000/FiniteDiffX

If you find it useful, consider giving it a star! 🌟


⏩ Quick Example

import jax.numpy as jnp
import finitediffx as fdx

# let's first define a vector-valued function F: R^3 -> R^3
# F = [F1, F2, F3]
# F1 = x^2 + y^3
# F2 = x^4 + y^3
# F3 = 0
# F = [x**2 + y**3, x**4 + y**3, 0]

x, y, z = [jnp.linspace(0, 1, 100)] * 3
dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]
X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")
F1 = X**2 + Y**3
F2 = X**4 + Y**3
F3 = jnp.zeros_like(F1)
F = jnp.stack([F1, F2, F3], axis=0)

# ∇.F : the divergence of F
divF = fdx.divergence(
    F,
    step_size=(dx, dy, dz),
    keepdims=False,
    accuracy=6,
    method="central",
)
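As a sanity check on what a finite difference divergence computes, the sketch below rebuilds the same field with plain NumPy and compares a second-order estimate against the analytic divergence 2x + 3y^2. It is an illustration under stated assumptions, not finitediffx's implementation: it uses `np.gradient` rather than `fdx.divergence`, a smaller grid (`n = 50`), and second-order rather than sixth-order accuracy.

```python
import numpy as np

# Assumed grid size, smaller than the example above to keep this fast.
n = 50
x = y = z = np.linspace(0.0, 1.0, n)
h = x[1] - x[0]  # uniform spacing, identical along every axis
X, Y, Z = np.meshgrid(x, y, z, indexing="ij")

# Same vector field as above: F = [x**2 + y**3, x**4 + y**3, 0]
F1 = X**2 + Y**3
F2 = X**4 + Y**3
F3 = np.zeros_like(F1)

# Divergence = dF1/dx + dF2/dy + dF3/dz, each term by finite differences.
# np.gradient uses second-order central differences in the interior and,
# with edge_order=2, second-order one-sided differences at the boundary.
div_fd = (
    np.gradient(F1, h, axis=0, edge_order=2)
    + np.gradient(F2, h, axis=1, edge_order=2)
    + np.gradient(F3, h, axis=2, edge_order=2)
)

# Analytic divergence for comparison: 2x + 3y^2
div_exact = 2 * X + 3 * Y**2
max_err = float(np.max(np.abs(div_fd - div_exact)))
print(f"max abs error: {max_err:.2e}")
```

The error here comes from the y^3 term and shrinks as h^2; a higher `accuracy` setting, as in the call above, trades a wider stencil for a smaller truncation error.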

Finite difference counterparts of jax.grad and jax.value_and_grad (such as fdx.fgrad), for use where JAX has no differentiation rule implemented or on non-traceable NumPy code:

import jax
from jax import numpy as jnp
import numpy as onp  # Not jax-traceable
import finitediffx as fdx
import functools as ft
from jax.experimental import enable_x64

with enable_x64():

    @fdx.fgrad
    @fdx.fgrad
    def np_rosenbrock2_fdx_style_1(x, y):
        """Compute the Rosenbrock function for two variables in numpy."""
        return onp.power(1-x, 2) + 100*onp.power(y-onp.power(x, 2), 2)

    @ft.partial(fdx.fgrad, derivative=2)
    def np_rosenbrock2_fdx_style_2(x, y):
        """Compute the Rosenbrock function for two variables in numpy."""
        return onp.power(1-x, 2) + 100*onp.power(y-onp.power(x, 2), 2)

    @jax.grad
    @jax.grad
    def jnp_rosenbrock2(x, y):
        """Compute the Rosenbrock function for two variables in jax.numpy."""
        return jnp.power(1-x, 2) + 100*jnp.power(y-jnp.power(x, 2), 2)

    # second derivative with respect to x, evaluated at (x, y) = (1, 2)
    print(np_rosenbrock2_fdx_style_1(1.,2.))
    print(np_rosenbrock2_fdx_style_2(1.,2.))
    print(jnp_rosenbrock2(1., 2.))
# 402.0000951997936
# 402.0000000002219
# 402.0
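To make the idea of a finite difference counterpart to a gradient transform concrete, here is a minimal pure-Python sketch. The helper `fd_grad` and its `argnum` and `step` parameters are hypothetical names introduced for illustration; this is not how finitediffx implements `fgrad`. It uses a second-order central difference and is nested twice to approximate the same second derivative as above.

```python
def fd_grad(func, argnum=0, step=1e-3):
    """Hypothetical helper (not part of finitediffx): return a function that
    approximates d(func)/d(args[argnum]) with a central difference."""

    def grad_func(*args):
        lo = list(args)
        hi = list(args)
        lo[argnum] -= step  # evaluate at x - h
        hi[argnum] += step  # evaluate at x + h
        return (func(*hi) - func(*lo)) / (2 * step)

    return grad_func


def rosenbrock(x, y):
    """Rosenbrock function for two variables, plain Python floats."""
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


# First derivative w.r.t. x at (1, 2); the analytic value is -400.
first = fd_grad(rosenbrock)(1.0, 2.0)

# Nesting the wrapper approximates the second derivative w.r.t. x;
# the analytic value at (1, 2) is 2 - 400*y + 1200*x**2 = 402.
second = fd_grad(fd_grad(rosenbrock))(1.0, 2.0)

print(first, second)
```

Because the wrapper only calls the function, it works on code that cannot be traced, at the cost of a truncation error that depends on the step size.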
