mtitze / njet

A lightweight automatic differentiation (AD) package that uses forward-mode AD to compute higher-order derivatives of functions of several variables.

License: Other

Python 100.00%

njet's People

Contributors

mtitze

Forkers

adtzlr

njet's Issues

Gradient w.r.t. an Array (Tensor)

Hi @mtitze,

Do you think this is the recommended way to use njet on many 3x3 tensors? As an example, I took the strain energy function of a Neo-Hookean solid.

```python
from njet.functions import log
from njet import derive
import numpy as np

# init `n` random 3x3 tensors
n = 2 ** 18
x = (np.eye(3) + np.random.rand(n, 3, 3) / 10).T

def grad(fun):
    "A decorator to evaluate the gradient."
    
    def inner(x, **kwargs):
        # init the output
        out = np.zeros((9, *x.shape[2:]))
        
        # obtain the gradient and loop over items
        for key, value in fun.grad(*flatten(x), **kwargs).items():
            out[key] = value
        
        # reshape the gradient
        return out.reshape(3, 3, -1)
    
    return inner

def hess(fun):
    "A decorator to evaluate the hessian."

    def inner(x, **kwargs):
        # init the output
        out = np.zeros((9, 9, *x.shape[2:]))
        
        # obtain the hessian and loop over items
        for key, value in fun.hess(*flatten(x), **kwargs).items():
            out[key] = value
        
        # reshape the hessian
        return out.reshape(3, 3, 3, 3, -1)
    
    return inner

def trace(x):
    return x[0][0] + x[1][1] + x[2][2]

def det(x):
    a = x[0][0] * x[1][1] * x[2][2]
    b = x[0][1] * x[1][2] * x[2][0]
    c = x[0][2] * x[1][0] * x[2][1]
    d = x[2][0] * x[1][1] * x[0][2]
    e = x[2][1] * x[1][2] * x[0][0]
    f = x[2][2] * x[1][0] * x[0][1]
    return a + b + c - d - e - f

flatten = lambda x: x.reshape(9, -1)
reshape = lambda x: [[x[0], x[1], x[2]], [x[3], x[4], x[5]], [x[6], x[7], x[8]]]
# reshape = lambda x: np.array(x, dtype=object).reshape(3, 3, -1) # this fails

def tensorjet(fun):
    "A decorator to reshape the input."
    def inner(*x, **kwargs):
        return fun(reshape(x), **kwargs)
    return inner

@tensorjet
def fun(C):
    "The Neo-Hookean material formulation (isotropic hyperelasticity)."
    return trace(C) - log(det(C))

dfun = derive(fun, order=2, n_args=9)
dfundx = grad(dfun)(x)
d2fundx2 = hess(dfun)(x)

# %timeit hess(dfun)(x)
# 2.08 s ± 27.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
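For this particular energy, W(C) = tr(C) - ln det(C), the derivatives also have a closed form that can serve as a NumPy-only cross-check of the njet results. By Jacobi's formula, dW/dC_ij = δ_ij - (C⁻¹)_ji and d²W/dC_ij dC_kl = (C⁻¹)_jk (C⁻¹)_li. The sketch below is not how njet computes anything; it assumes the same (3, 3, n) layout as the issue (batch axis last) and, like the code above, treats all 9 entries of C as independent variables. The function name `neo_hooke_grad_hess` is hypothetical.

```python
import numpy as np


def neo_hooke_grad_hess(x):
    """Closed-form derivatives of W(C) = tr(C) - ln det(C).

    x has shape (3, 3, n) with the batch axis last, as in the issue.
    Returns dW/dC with shape (3, 3, n) and d2W/dCdC with shape (3, 3, 3, 3, n).
    """
    C = np.moveaxis(x, -1, 0)          # (n, 3, 3): np.linalg.inv batches over leading axes
    Cinv = np.linalg.inv(C)            # batched inverse
    # dW/dC_ij = delta_ij - (C^-1)_ji
    grad = np.eye(3) - np.swapaxes(Cinv, -1, -2)
    # d2W/dC_ij dC_kl = (C^-1)_jk (C^-1)_li
    hess = np.einsum("njk,nli->nijkl", Cinv, Cinv)
    # move the batch axis back to the end
    return np.moveaxis(grad, 0, -1), np.moveaxis(hess, 0, -1)


x = (np.eye(3) + np.random.default_rng(0).random((4, 3, 3)) / 10).T
g, h = neo_hooke_grad_hess(x)
print(g.shape, h.shape)  # (3, 3, 4) (3, 3, 3, 3, 4)
```

If njet's output uses the same index layout, `np.allclose(dfundx, g)` and `np.allclose(d2fundx2, h)` on the same input should be a quick sanity check.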

Is there any chance to use a NumPy array instead of nested lists, i.e. to replace the reshape lambda with something like this?

```python
reshape = lambda x: np.array(x, dtype=object).reshape(3, 3, -1)  # (this fails)
```

I couldn't get it to work.

Thanks!
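One plausible explanation, not confirmed against njet: reshape(3, 3, -1) adds a trailing length-1 axis, so each x[i][j] becomes a one-element object array instead of a jet, and np.array(...) may also try to iterate into the jets if they look like sequences. The sketch below packs nine arbitrary objects into a (3, 3) object array with np.empty plus element-wise assignment, which keeps NumPy from unpacking them. `Sym` is a hypothetical stand-in for a jet (anything supporting +); whether njet's functions such as log accept entries taken from such an array is an assumption worth testing.

```python
import numpy as np


class Sym:
    """Hypothetical stand-in for an njet jet: any object supporting +."""

    def __init__(self, name):
        self.name = name

    def __add__(self, other):
        return Sym(f"({self.name}+{other.name})")

    def __repr__(self):
        return self.name


def to_object_matrix(args):
    """Pack 9 scalar-like objects into a (3, 3) object array.

    Allocating with np.empty and assigning element by element stores each
    object as-is, so NumPy never tries to iterate into it.
    """
    out = np.empty(len(args), dtype=object)
    for i, a in enumerate(args):
        out[i] = a
    return out.reshape(3, 3)  # no trailing -1 axis: C[i][j] is the object itself


def trace(x):
    # same indexing style as in the issue; works for nested lists and 2D arrays
    return x[0][0] + x[1][1] + x[2][2]


C = to_object_matrix([Sym(f"c{i}") for i in range(9)])
print(C.shape)   # (3, 3)
print(trace(C))  # ((c0+c4)+c8)
```

Because the entries stay bare objects, the issue's trace and det helpers should work on such an array without changes.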

Package contains .ipynb_checkpoints directory

Run git clean -fdx before packaging (it removes all untracked and ignored files and directories) :).

P.S. I had some scary moments debugging import errors in our CI pipelines because of a naming collision with our internal package, whose current version is 0.1.1 :-D
P.P.S. Congrats on the public release!
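Besides cleaning the working tree, the built artifacts themselves can be checked before upload. A minimal sketch, assuming the wheel and sdist end up as .whl and .tar.gz files (e.g. in a dist/ directory); the helper names `archive_members` and `checkpoint_entries` and the example file names are hypothetical.

```python
import tarfile
import zipfile
from pathlib import PurePosixPath


def archive_members(path):
    """List member names of a wheel (.whl/.zip) or an sdist (.tar.gz)."""
    if str(path).endswith((".whl", ".zip")):
        with zipfile.ZipFile(path) as zf:
            return zf.namelist()
    with tarfile.open(path) as tf:  # default mode detects gzip compression
        return tf.getnames()


def checkpoint_entries(names):
    """Return members that sit inside an .ipynb_checkpoints directory."""
    return [n for n in names if ".ipynb_checkpoints" in PurePosixPath(n).parts]


# Example usage on a built artifact (path is hypothetical):
# leftovers = checkpoint_entries(archive_members("dist/njet-x.y.z.tar.gz"))
# assert not leftovers, leftovers
```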