
deepmodeling / jax-fem


Differentiable Finite Element Method with JAX

License: GNU General Public License v3.0

Python 99.98% Shell 0.02%
differentiable-programming finite-element-methods jax topology-optimization

jax-fem's Introduction

A GPU-accelerated differentiable finite element analysis package based on JAX. It used to be part of JAX-AM, a suite of open-source Python packages for Additive Manufacturing (AM) research.

Finite Element Method (FEM)


FEM is a powerful tool, and JAX-FEM supports the following features:

  • 2D quadrilateral/triangle elements
  • 3D hexahedron/tetrahedron elements
  • First and second order elements
  • Dirichlet/Neumann/Robin boundary conditions
  • Linear and nonlinear analysis including
    • Heat equation
    • Linear elasticity
    • Hyperelasticity
    • Plasticity (macro and crystal plasticity)
  • Differentiable simulation for solving inverse/design problems without deriving sensitivities by hand, e.g.,
    • Topology optimization
    • Optimal thermal control
  • Integration with PETSc for solver choices
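Here is a minimal sketch in plain JAX of what "without deriving sensitivities by hand" means. It is not JAX-FEM's API. The model is a toy 1D bar of N_ELEM linear elements, clamped at the left end and loaded at the right end. A hypothetical per-element density rho scales each element's stiffness, and jax.grad differentiates straight through the assembly and the dense solve.

```python
import jax
import jax.numpy as jnp

N_ELEM = 4          # number of 1D linear elements (illustrative)
h = 1.0 / N_ELEM    # element length on the unit interval

def compliance(rho, load=1.0):
    """Compliance F . u of a clamped 1D bar whose element stiffness is scaled by rho."""
    n = N_ELEM + 1
    ke = jnp.array([[1.0, -1.0], [-1.0, 1.0]]) / h   # unit-modulus element stiffness
    K = jnp.zeros((n, n))
    for e in range(N_ELEM):
        K = K.at[e:e + 2, e:e + 2].add(rho[e] * ke)  # assemble element e
    F = jnp.zeros(n).at[-1].set(load)                # point load at the right end
    u = jnp.linalg.solve(K[1:, 1:], F[1:])           # clamp node 0 by dropping its row/column
    return jnp.dot(F[1:], u)

rho = jnp.ones(N_ELEM)
J = compliance(rho)
dJ_drho = jax.grad(compliance)(rho)   # sensitivities via automatic differentiation
```

For this series-spring bar the compliance is the sum of h / rho_e. Each gradient entry is therefore -h / rho_e^2, which gives -0.25 for four unit-density elements and makes the autodiff result easy to check by hand.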

Updates (Dec 11, 2023):

  • We now support multi-physics problems, in the sense that multiple variables can be solved for monolithically. For an example, run python -m applications.stokes.example
  • The weak form is now defined through volume integrals and surface integrals. Body forces, "mass kernels" and "Laplace kernels" are treated in a unified way through the volume integral, and Neumann and Robin boundary conditions in a unified way through the surface integral.

Thermal profile in direct energy deposition.

Linear static analysis of a bracket.

Crystal plasticity: grain structure (left) and stress-xx (right).

Stokes flow: velocity (left) and pressure (right).

Topology optimization with differentiable simulation.

Installation

Create a conda environment from the given environment.yml file and activate it:

conda env create -f environment.yml
conda activate jax-fem-env

Install JAX

  • See the JAX installation instructions. Depending on your hardware, you can install the CPU or GPU version of JAX. Both work, but the GPU version usually gives better performance.
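After installing, you can confirm which build JAX picked up by printing its default backend and visible devices. The exact strings printed depend on your hardware and JAX version.

```python
import jax

# Typically "cpu" for the CPU-only install and "gpu" when a GPU build found a device
print("backend:", jax.default_backend())
print("devices:", jax.devices())
```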

Then there are two options to continue:

Option 1

Clone the repository:

git clone https://github.com/deepmodeling/jax-fem.git
cd jax-fem

and install the package locally:

pip install -e .

Quick tests: You can check demos/ for a variety of FEM cases. For example, run

python -m demos.hyperelasticity.example

for hyperelasticity.

Also,

python -m tests.benchmarks

will execute a set of test cases.

Option 2

Install the package from the PyPI release directly:

pip install jax-fem

Quick tests: save the following script as example.py and run it with python example.py:

import jax
import jax.numpy as np
import os

from jax_fem.problem import Problem
from jax_fem.solver import solver
from jax_fem.utils import save_sol
from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh

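class Poisson(Problem):
    # Flux map: sigma(grad u) = grad u, i.e. the Laplace kernel
    def get_tensor_map(self):
        return lambda x: x

    # Volume (mass) term at each quadrature point: returns -f(x), f = 10 exp(-((x0 - 0.5)^2 + (x1 - 0.5)^2) / 0.02)
    def get_mass_map(self):
        def mass_map(u, x):
            val = -np.array([10*np.exp(-(np.power(x[0] - 0.5, 2) + np.power(x[1] - 0.5, 2)) / 0.02)])
            return val
        return mass_map

    # Surface term, one map per boundary in location_fns (bottom, top): returns -g(x), g = sin(5 x0)
    def get_surface_maps(self):
        def surface_map(u, x):
            return -np.array([np.sin(5.*x[0])])

        return [surface_map, surface_map]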

ele_type = 'QUAD4'
cell_type = get_meshio_cell_type(ele_type)
Lx, Ly = 1., 1.
meshio_mesh = rectangle_mesh(Nx=32, Ny=32, domain_x=Lx, domain_y=Ly)
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])

def left(point):
    return np.isclose(point[0], 0., atol=1e-5)

def right(point):
    return np.isclose(point[0], Lx, atol=1e-5)

def bottom(point):
    return np.isclose(point[1], 0., atol=1e-5)

def top(point):
    return np.isclose(point[1], Ly, atol=1e-5)

def dirichlet_val_left(point):
    return 0.

def dirichlet_val_right(point):
    return 0.

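# Dirichlet BCs: u = 0 on the left and right edges (vecs selects solution component 0)
location_fns = [left, right]
value_fns = [dirichlet_val_left, dirichlet_val_right]
vecs = [0, 0]
dirichlet_bc_info = [location_fns, vecs, value_fns]

# Boundaries on which the surface maps are integrated (Neumann BCs on bottom and top)
location_fns = [bottom, top]

# Scalar field (vec=1) in 2D; the problem is linear, so a single linear solve with PETSc is used
problem = Poisson(mesh=mesh, vec=1, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns)
sol = solver(problem, linear=True, use_petsc=True)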

# Write the solution to data/vtk/u.vtu next to this script (viewable in ParaView)
data_dir = os.path.join(os.path.dirname(__file__), 'data')
vtk_path = os.path.join(data_dir, 'vtk/u.vtu')
save_sol(problem.fes[0], sol[0], vtk_path)
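For reference, the Poisson script above can be read as the following weak form. This is an interpretation. It assumes the volume (mass) and surface maps enter the residual with a plus sign, which is what the minus signs in mass_map and surface_map suggest. Find u, with u = 0 on the left and right edges, such that for every test function v:

```latex
\int_\Omega \nabla u \cdot \nabla v \,\mathrm{d}\Omega
  + \int_\Omega m(u, \boldsymbol{x})\, v \,\mathrm{d}\Omega
  + \int_{\Gamma_N} t(u, \boldsymbol{x})\, v \,\mathrm{d}\Gamma = 0,
\qquad
m = -f, \quad f(\boldsymbol{x}) = 10\, \exp\!\left(-\frac{(x_0 - 0.5)^2 + (x_1 - 0.5)^2}{0.02}\right),
\qquad
t = -g, \quad g(\boldsymbol{x}) = \sin(5 x_0)
```

Here Gamma_N is the union of the bottom and top edges. Integrating the first term by parts shows that this is the weak form of -Δu = f in Ω with ∇u·n = g on Gamma_N.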

License

This project is licensed under the GNU General Public License v3 - see the LICENSE for details.

Citations

If you find this library useful in academic or industrial work, please consider 1) starring the project on GitHub and 2) citing the relevant papers:

@article{xue2023jax,
  title={JAX-FEM: A differentiable GPU-accelerated 3D finite element solver for automatic inverse design and mechanistic data science},
  author={Xue, Tianju and Liao, Shuheng and Gan, Zhengtao and Park, Chanwook and Xie, Xiaoyu and Liu, Wing Kam and Cao, Jian},
  journal={Computer Physics Communications},
  pages={108802},
  year={2023},
  publisher={Elsevier}
}

jax-fem's People

Contributors

itk22, qiwei-chen, snms95, tajtac, tianjuxue, xwpken


jax-fem's Issues

Docs / Tutorial Notebook

It seems like there is currently no proper docs / tutorial for this package. I'm interested in applying this to some problems in my research, but it'd be very helpful to have some sort of tutorial for extending this to new problems.

Cheers!

Question about stresses for Linear_Elasticity demo

Hi, I noticed a recent commit by @xwpken to the Linear_Elasticity demo that implements post-processing of the stresses in the beam. I'm confused about how to interpret the data computed at the end of the script (Sigma and sigma_average). I don't think they are saved to the VTU file, so I can't display them in ParaView.

Could anyone help me visualise them? Thank you!

Problem in the readme file.

Can you please change this in the readme (here)?

git clone git@github.com:tianjuxue/jax-fem.git

You could replace it with HTTPS cloning:

git clone https://github.com/tianjuxue/jax-fem.git

Otherwise, people without SSH access configured will get an access-denied error.

Question about the demo thermal_mechanical_full

Hello everyone,

I've been exploring this code with a keen interest in solving coupled problems. The code passed the benchmark tests, and I have been able to run demos.hyperelasticity.example without any issues. However, I've hit a stumbling block with the thermal_mechanical_full demo, which raises the following FileNotFoundError:

FileNotFoundError: [Errno 2] No such file or directory: '/home/myname/Desktop/jax-fem/demos/thermal_mechanical_full/input/numpy/points.npy'

I'm wondering if I'm missing something. Any guidance or suggestions would be greatly appreciated!

Save output files at user specified intervals

For hyperelasticity, where the load or displacement must be applied incrementally, and for rate- or history-dependent problems, users would like to save output files at specific intervals. The current implementation (1cc776c) makes this difficult because the solver handles the incremental loading internally, so the saving step only receives the final solution. Users should be able to specify how often outputs are saved to disk.
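Here is a minimal sketch of the requested behaviour, written as a plain Python driver loop rather than anything in JAX-FEM's current solver. solve_step and save_fn are hypothetical callables: the first stands in for one incremental solve, the second for writing a file (for example with save_sol). save_every is the user-chosen output frequency, and the final step is always saved.

```python
from typing import Any, Callable

def run_incremental(solve_step: Callable[[Any, float], Any],
                    save_fn: Callable[[int, Any], None],
                    num_steps: int,
                    save_every: int,
                    u0: Any = 0.0) -> Any:
    """Apply the load in num_steps increments and save every save_every steps."""
    u = u0
    for step in range(1, num_steps + 1):
        load_factor = step / num_steps
        u = solve_step(u, load_factor)                # hypothetical incremental solve
        if step % save_every == 0 or step == num_steps:
            save_fn(step, u)                          # hypothetical writer, e.g. one VTU per call
    return u
```

Passing the writer in as a callback keeps disk I/O out of the solver itself. The same loop then serves hyperelastic load ramps and history-dependent problems.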

Remove "from jax.config import config" from fe.py

Hello everyone,
Is it possible to remove "from jax.config import config" from the files?
I need to create a Dockerfile with JAX-FEM, and it does not work because of that import.
It works well in a conda environment without "from jax.config import config"; with it, I get an error caused by the JAX update.
Thanks a lot
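A likely fix is to reach the config object through the top-level jax package instead of importing the jax.config submodule, which recent JAX releases no longer allow. This assumes the import is only used to call config.update, for example to enable 64-bit floats.

```python
import jax
import jax.numpy as jnp

# Replaces: from jax.config import config; config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

x = jnp.zeros(3)
print(x.dtype)  # float64 once x64 mode is on
```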

How to do parallel solver

Hi, I'm trying to create a hybrid model that uses your FEM solver and a neural network.
To do that, I need to solve the same equations with different parameters (e.g. heat diffusion, where the diffusion coefficient and initial condition differ for each sample).
I can't use vmap because the solver uses SciPy and NumPy, which aren't compatible with it.
Do you think the solver could be adapted to handle batches or be passed to vmap?
Thanks in advance for any ideas!
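One possible workaround is to keep the whole assemble-and-solve path in jax.numpy, so that jax.vmap can batch it. This assumes the discretized system is small enough for a dense jnp.linalg.solve. The sketch below is a toy steady 1D heat problem, not JAX-FEM's solver: -(k u')' = 1 on (0, 1) with u = 0 at both ends, discretized with linear elements. The diffusion coefficient k is the batched parameter.

```python
import jax
import jax.numpy as jnp

N_ELEM = 8                 # uniform linear elements on (0, 1)
h = 1.0 / N_ELEM
n_int = N_ELEM - 1         # interior nodes (both end values are fixed to 0)

def solve_heat_1d(k, f=1.0):
    """Nodal temperatures at interior nodes for -(k u')' = f, u(0) = u(1) = 0."""
    main = 2.0 * jnp.ones(n_int)
    off = -1.0 * jnp.ones(n_int - 1)
    K = (k / h) * (jnp.diag(main) + jnp.diag(off, 1) + jnp.diag(off, -1))  # stiffness
    F = f * h * jnp.ones(n_int)                                            # consistent load
    return jnp.linalg.solve(K, F)

batched_solve = jax.vmap(solve_heat_1d)   # maps over the first argument, k
ks = jnp.array([1.0, 2.0, 4.0])
U = batched_solve(ks)                     # shape (3, n_int): one solution per k
```

The exact solution is u(x) = x(1 - x) / (2k), and linear elements reproduce it at the nodes, so the midpoint values should equal 0.125 / k. To batch another argument as well, such as a source term or initial condition, pass it explicitly and set vmap's in_axes.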

JAX-FEM Refactor

Below is a tasklist for the refactor of the JAX-FEM core code. The first stage of the refactor will focus on the core.py file. The proposed file structure is detailed on the refactor branch and this is where all the new changes should be merged.

Tasks
Stage 1:

  • Fill in the core_fem file - @itk22
  • Create a general object to represent kernels and integrate with the core code - @SNMS95 + @itk22 (support)
  • Create wrappers for linear solvers based on Lineax - @SNMS95
  • Checking all files and amendments - Gawel

Boundary conditions using mesh properties.

When defining boundary conditions for nodes, is there a way the interface could be extended to take in not only the node coordinates but also the node index? Something like this:

def top_nodes(point, index):
    top = jnp.isclose(point[2], upper_surface_height, atol=TOLERANCE)    # Nodes on the top surface of my mesh
    active = jnp.isin(index, active_nodes)    # Nodes in my predefined active list
    return top & active

I am implementing a layer-by-layer model with element activation. I can deactivate elements above a certain height based on their centroid, but when applying the boundary condition, I cannot select only the nodes connected to my active elements with the XYZ position alone.

Any suggestions welcome!
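Until such an interface exists, one workaround fits location functions that take only a coordinate (like left(point) in the README example). The idea is to close over the coordinates of the active nodes and test membership by position. make_active_top_fn, active_node_ids and z_top are hypothetical names. The check costs O(number of active nodes) per point, which is usually acceptable for boundary selection. It is written with jax.numpy in case the location function is traced or vmapped, which is an assumption here.

```python
import jax.numpy as jnp

def make_active_top_fn(points, active_node_ids, z_top, tol=1e-5):
    """Location function that selects top-surface nodes belonging to active elements."""
    active_pts = jnp.asarray(points)[jnp.asarray(active_node_ids)]  # (n_active, 3)

    def top_active(point):
        on_top = jnp.isclose(point[2], z_top, atol=tol)
        # True if this coordinate coincides with any active node's coordinate
        matches = jnp.all(jnp.isclose(active_pts, point, atol=tol), axis=1)
        return on_top & jnp.any(matches)

    return top_active

points = jnp.array([[0.0, 0.0, 1.0],   # top, active
                    [1.0, 0.0, 1.0],   # top, inactive
                    [0.0, 0.0, 0.0],   # active, not on top
                    [2.0, 0.0, 1.0]])  # top, inactive
top_active = make_active_top_fn(points, active_node_ids=[0, 2], z_top=1.0)
```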

Initial task list

  • General cleanup + Flake8 file -> Surya
  • Enable pip install in editable mode [Setup] -> Surya
  • Environment file -> Surya
  • All tests in one location -> Surya
  • Baseline performance - Memory, Speed & accuracy - benchmark -> Igor
  • GH workflow setup -> Igor

Create a uniform approach/template for post processing

Currently the post-processing steps are somewhat disorganized, and the templates for them aren't easy to find (example). It would help a lot to have a general template (or a utils submodule) available so that users can use or modify it as needed.

PETSc solver not working with JAX transforms

  • Currently, the solver (when using PETSc) cannot be combined with any of JAX's transforms
  • This prevents vmapping and higher-order derivatives
  • Solution:
    • Wrap PETSc with Lineax, or
    • Wrap it in a pure_callback + custom_jvp
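Here is a minimal sketch of the second option, with two substitutions. First, numpy.linalg.solve stands in for a PETSc solve; the host function is where a petsc4py call would go. Second, it uses jax.custom_vjp instead of custom_jvp: a pure_callback cannot be transposed, so a custom JVP alone would not give reverse-mode gradients. The backward rule uses the adjoint identity for x = A^{-1} b. The cotangent of b is lam = A^{-T} g, and that of A is -lam x^T. This supports first-order jax.grad only. How the callback behaves under vmap is controlled by pure_callback's own options, which differ between JAX versions.

```python
import jax
import jax.numpy as jnp
import numpy as np

def _host_solve(A, b):
    # Runs on the host with NumPy; a petsc4py KSP solve could replace this body.
    A, b = np.asarray(A), np.asarray(b)
    return np.linalg.solve(A, b).astype(b.dtype)

def _solve_callback(A, b):
    out_spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(_host_solve, out_spec, A, b)

@jax.custom_vjp
def external_solve(A, b):
    """Solve A x = b outside JAX while staying differentiable with jax.grad."""
    return _solve_callback(A, b)

def _external_solve_fwd(A, b):
    x = _solve_callback(A, b)
    return x, (A, x)

def _external_solve_bwd(res, g):
    A, x = res
    lam = _solve_callback(A.T, g)      # adjoint solve: A^T lam = g
    return -jnp.outer(lam, x), lam     # cotangents for (A, b)

external_solve.defvjp(_external_solve_fwd, _external_solve_bwd)
```

A quick check is to compare jax.grad of a loss built on external_solve with the same loss built on jnp.linalg.solve. The two should agree to floating-point tolerance.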

Multi-step loading + auto-diff

  • In the plasticity example, the load is applied in smaller steps. This also helps iterative solvers converge in other cases.
  • Solution:
    • Just as ad_wrapper performs implicit differentiation through a nonlinear solve, we should also allow implicit differentiation through load steps.
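Here is a minimal sketch of implicit differentiation through load steps. It is a toy scalar residual, not ad_wrapper or any JAX-FEM code, and residual, solve_step and solve_incrementally are hypothetical names. Each step is a Newton solve wrapped in jax.custom_vjp with an implicit-function-theorem backward rule, and the steps are chained with warm starts. The rule assigns zero cotangent to the initial guess. That assumption holds when the converged solution does not depend on the path, as in hyperelasticity. History-dependent plasticity would need internal variables carried through the adjoint, which this sketch does not cover.

```python
import jax
import jax.numpy as jnp

def residual(u, load):
    # Toy nonlinear equilibrium r(u, load) = u^3 + u - load (stands in for an FE residual)
    return u**3 + u - load

dr_du = jax.grad(residual, argnums=0)
dr_dload = jax.grad(residual, argnums=1)

def newton(u0, load, iters=30):
    u = u0
    for _ in range(iters):
        u = u - residual(u, load) / dr_du(u, load)
    return u

@jax.custom_vjp
def solve_step(u0, load):
    """One load step: Newton solve of r(u, load) = 0 starting from u0."""
    return newton(u0, load)

def _solve_step_fwd(u0, load):
    u = newton(u0, load)
    return u, (u, load)

def _solve_step_bwd(res, g):
    u, load = res
    # Implicit function theorem: du/dload = -(dr/du)^{-1} dr/dload.
    # The converged u does not depend on the initial guess, so its cotangent is zero.
    return jnp.zeros_like(u), -g * dr_dload(u, load) / dr_du(u, load)

solve_step.defvjp(_solve_step_fwd, _solve_step_bwd)

def solve_incrementally(load, num_steps=5):
    load = jnp.asarray(load)
    u = jnp.zeros_like(load)
    for k in range(1, num_steps + 1):
        u = solve_step(u, load * k / num_steps)   # each step warm-starts from the last
    return u
```

At load = 2 the solution is u = 1. The implicit derivative is then 1 / (3u^2 + 1) = 0.25, which is what jax.grad of solve_incrementally should return.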

issue with petsc4py

I am trying to install jax-fem on Windows. I am unable to install petsc4py using pip install, and I cannot install it with conda either: it says the package is not found in the conda-forge channels. Is JAX-FEM not meant to work on Windows? Please help, as this package would be extremely useful for my research.
