
ucl-bug / jwave


A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs

License: GNU Lesser General Public License v3.0

Python 80.56% Makefile 1.64% Shell 1.42% PowerShell 1.67% MATLAB 5.61% Jupyter Notebook 9.11%
jax simulation acoustics differentiable-simulations physics-informed-neural-networks scientific-machine-learning gpu-acceleration kwave ultrasound wave-equation

jwave's Issues

Add `all-contributors` bot

It is amazing to see that people are starting to help with this project. Even if only for a bug report, it is important to acknowledge when people offer their free time to improve a project.

One simple way would be to add the all-contributors bot and follow the specification described at allcontributors.org

Further tests

  • Test Helmholtz
  • Test BLI sources
  • Rectangular domains
  • 3D
  • Test point sensors
  • Test BLI sensors
  • Test TimeHarmonicSource
  • Check for duplicated functions
  • Add tests for differentiability
  • Self consistency test between time varying and time harmonic (scaling, frequencies, etc...)
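For the "Add tests for differentiability" item, one common pattern is a directional finite-difference check: compare the automatic-differentiation derivative along a random direction with a central difference. The sketch below is a minimal example assuming a toy periodic spectral Laplacian and a hypothetical loss `toy_loss`; neither is a jwave function, and a real test would call a jwave solver instead.

```python
import jax
import jax.numpy as jnp

# Finite-difference checks need double precision to be meaningful.
jax.config.update("jax_enable_x64", True)


def spectral_laplacian(u, dx):
    """FFT-based second derivative of a real field on a periodic 1D grid."""
    k = 2 * jnp.pi * jnp.fft.fftfreq(u.shape[0], d=dx)
    return jnp.real(jnp.fft.ifft(-(k**2) * jnp.fft.fft(u)))


def toy_loss(c, u, dx):
    """Hypothetical scalar loss that depends smoothly on a sound-speed map c."""
    return jnp.sum((c**2 * spectral_laplacian(u, dx)) ** 2)


def directional_grad_check(f, x, key, eps=1e-6):
    """Return (AD, central-difference) derivatives of f at x along a random unit direction."""
    v = jax.random.normal(key, x.shape, dtype=x.dtype)
    v = v / jnp.linalg.norm(v)
    ad = jnp.vdot(jax.grad(f)(x), v)
    fd = (f(x + eps * v) - f(x - eps * v)) / (2 * eps)
    return ad, fd


n = 64
dx = 1.0 / n
x = jnp.arange(n) * dx
u = jnp.sin(2 * jnp.pi * 3 * x)
c = 1.0 + 0.1 * jnp.cos(2 * jnp.pi * x)
ad, fd = directional_grad_check(lambda c_: toy_loss(c_, u, dx), c, jax.random.PRNGKey(0))
```

Checking a single random direction keeps the test cheap even for large parameter maps; repeating it with a few keys gives more coverage.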

Others

  • Mirror notebooks to Python files
  • Transform notebooks into tests before release
  • Regression tests

Non-linearity

Implement the B/A parameter for modeling non-linearity in acoustic media

See eq. (2.7) of the k-Wave manual

It is not obvious how to model it for the Helmholtz equation, since the separation of variables assumed when deriving the Helmholtz equation from the wave equation no longer holds, so it is better to skip it for this milestone.
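In time-domain models, B/A usually enters through the second-order (quadratic) isentropic equation of state, $p = c_0^2\left(\rho' + \frac{B}{2A}\frac{\rho'^2}{\rho_0}\right)$, where $\rho'$ is the acoustic density. The sketch below evaluates only that relation; it is an assumption that this is the nonlinear term of eq. (2.7) in the k-Wave manual, which also contains other terms (e.g. absorption) not reproduced here. For reference, B/A is roughly 5 for water.

```python
def pressure_from_density(rho_ac, c0, rho0, b_over_a):
    """Second-order isentropic equation of state (works on floats or JAX arrays).

    p = c0^2 * (rho' + B/(2A) * rho'^2 / rho0); b_over_a = 0 recovers the linear case.
    """
    return c0**2 * (rho_ac + 0.5 * b_over_a * rho_ac**2 / rho0)


p_linear = pressure_from_density(1.0, 1500.0, 1000.0, 0.0)
p_nonlinear = pressure_from_density(1.0, 1500.0, 1000.0, 5.0)
```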

Helmholtz Problem tutorial NotFoundLookupError

Describe the bug
Running the Helmholtz problem tutorial produces the following error at the line params = helmholtz.default_params(src, medium, omega):

NotFoundLookupError: For function "helmholtz", signature Signature(jaxdf.discretization.FourierSeries, jwave.geometry.MediumType[<class 'float'>, <class 'float'>, <class 'float'>], builtins.float) could not be resolved.

To Reproduce
Steps to reproduce the behavior:

  1. Running the first five cells of the notebook produces the error.

Expected behavior
I was expecting j-Wave to solve the Helmholtz equation :)

Desktop:

  • OS: macOS
  • Version: Ventura 13.5.2

Additional context
N/A

Add progress bars

There is now support for progress bars in jax scan and while loops!
See https://github.com/jeremiecoullon/jax-tqdm

It should be possible to add progress bars to simulations, which will greatly improve the experience for large problems.

The progress bars should be disabled by default, as they are probably inconvenient when jwave is used for machine learning applications on relatively small problems.
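A minimal sketch of an opt-in progress report inside a `lax.scan` time loop, using `jax.debug.callback` rather than the jax-tqdm API (which is not reproduced here). `run_time_loop` and its placeholder update are hypothetical; because `progress` is a Python bool, no callback is traced at all when it is False, matching the "off by default" requirement.

```python
import jax
import jax.numpy as jnp


def run_time_loop(u0, num_steps, progress=False, every=100, report=print):
    """Toy time loop with an optional host-side progress report."""

    def report_step(i):
        # Runs on the host. Filtering here keeps the traced code simple,
        # at the cost of one host callback per step when progress is on.
        i = int(i)
        if i % every == 0:
            report(i)

    def step(u, i):
        if progress:
            jax.debug.callback(report_step, i)
        return 0.5 * u, None  # placeholder update standing in for a solver step

    u_final, _ = jax.lax.scan(step, u0, jnp.arange(num_steps))
    return u_final


seen = []
out = run_time_loop(jnp.ones(4), 10, progress=True, every=5, report=seen.append)
jax.effects_barrier()  # wait until all host callbacks have run
quiet = run_time_loop(jnp.ones(4), 10)  # default: no callbacks traced
```

jax-tqdm avoids the per-step host round trip by wrapping the callback in a conditional, which matters for long simulations.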

Speed up Fourier Interpolation with real FFT

The current implementation of get_field for RealFourierSeries uses the standard fft and casts the result to real.

It should instead use the real FFT to improve performance.
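The sketch below compares both approaches for evaluating the Fourier interpolant of a real 1D field at arbitrary points; it is a standalone illustration, not the `get_field` method of `RealFourierSeries`. With `rfft` only the N//2+1 non-negative coefficients are computed, and the negative frequencies are accounted for by doubling every coefficient except DC and (for even N) Nyquist.

```python
import jax.numpy as jnp


def interp_fft(u, x, length):
    """Full complex FFT, then cast the result to real."""
    n = u.shape[0]
    c = jnp.fft.fft(u)
    k = jnp.fft.fftfreq(n, d=length / n)  # cycles per unit length
    phase = jnp.exp(2j * jnp.pi * k[None, :] * x[:, None])
    return jnp.real(phase @ c) / n


def interp_rfft(u, x, length):
    """Real FFT: half the coefficients, with weights restoring the negative frequencies."""
    n = u.shape[0]
    c = jnp.fft.rfft(u)
    k = jnp.fft.rfftfreq(n, d=length / n)
    w = jnp.full(c.shape, 2.0).at[0].set(1.0)  # DC counted once
    if n % 2 == 0:
        w = w.at[-1].set(1.0)  # Nyquist counted once for even n
    phase = jnp.exp(2j * jnp.pi * k[None, :] * x[:, None])
    return jnp.real(phase @ (w * c)) / n


n, length = 16, 2.0
xg = jnp.arange(n) * length / n
u = jnp.sin(2 * jnp.pi * 3 * xg / length) + 0.5 * jnp.cos(2 * jnp.pi * 5 * xg / length)
xq = jnp.array([0.1234, 0.77, 1.9])
```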

Check the number of signal and number of sources positions

The user should receive an error when they try to do something like

sources = Sources(
  positions=((32,),(32,),(32,)),
  signals=jnp.stack([s1, s1]),
  dt=time_axis.dt,
  domain=domain,
)

since the number of signals is not the same as the number of source positions.
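A minimal sketch of such a check, assuming (as in the example above) that `positions` holds one tuple of grid indices per spatial axis, so the number of sources is the length of each tuple. `validate_sources` is a hypothetical helper, not part of the current `Sources` class; it could be called from the constructor.

```python
import jax.numpy as jnp


def validate_sources(positions, signals):
    """Hypothetical check: exactly one signal row per source position."""
    if len(positions) == 0:
        raise ValueError("`positions` must contain one tuple per spatial axis")
    lengths = {len(axis) for axis in positions}
    if len(lengths) != 1:
        raise ValueError(f"all coordinate tuples must have the same length, got {sorted(lengths)}")
    num_sources = lengths.pop()
    num_signals = signals.shape[0]
    if num_signals != num_sources:
        raise ValueError(f"got {num_signals} signals for {num_sources} source positions")


s1 = jnp.zeros(100)
validate_sources(((32, 40), (32, 40), (32, 40)), jnp.stack([s1, s1]))  # two sources, two signals: OK
```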

test_kwave_helmholtz_fd.py

Hi, can I ask about some bugs here?
When executing this script, it returns
"plum.function.NotFoundLookupError: For function "gradient", signature Signature(jaxdf.discretization.FiniteDifferences, builtins.list) could not be resolved."
It seems that only FourierSeries is currently supported for solving the Helmholtz equation.

Add tests for the rayleigh_integral operator

We need tests for the Rayleigh integral (harmonic formulation) in jwave.acoustics.time_harmonic.rayleigh_integral

It could be tested against angularSpectrum of k-Wave, or with a convergence plot against a known solution (e.g. propagation of a single-frequency Gaussian beam).
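As an alternative analytical reference (a swapped-in test case, not the Gaussian beam mentioned above), the on-axis field of a uniform circular disc has a closed form: $\int_{\text{disc}} \frac{e^{ikR}}{R}\,dS = \frac{2\pi}{ik}\left(e^{ik\sqrt{z^2+a^2}} - e^{ikz}\right)$. The sketch checks a plain midpoint-rule quadrature of the Rayleigh kernel against it; it does not call jwave's `rayleigh_integral`, and the physical prefactor (e.g. $-i\omega\rho_0/2\pi$) is omitted because it depends on the sign convention. Replacing `rayleigh_sum` with the jwave operator would turn this into a convergence test.

```python
import jax.numpy as jnp


def rayleigh_sum(src_xy, weights, k, target):
    """Midpoint-rule approximation of the integral of e^{ikR}/R over a source in z = 0."""
    dx = target[0] - src_xy[:, 0]
    dy = target[1] - src_xy[:, 1]
    r = jnp.sqrt(dx**2 + dy**2 + target[2] ** 2)
    return jnp.sum(weights * jnp.exp(1j * k * r) / r)


def piston_on_axis(k, a, z):
    """Closed form of the same integral for a uniform disc of radius a, on axis."""
    return 2 * jnp.pi / (1j * k) * (jnp.exp(1j * k * jnp.sqrt(z**2 + a**2)) - jnp.exp(1j * k * z))


def disc_quadrature(a, n):
    """Cell-centred n x n grid over [-a, a]^2, clipped to the disc (staircased edge)."""
    h = 2 * a / n
    centres = -a + h / 2 + h * jnp.arange(n)
    X, Y = jnp.meshgrid(centres, centres, indexing="ij")
    inside = (X**2 + Y**2) <= a**2
    pts = jnp.stack([X.ravel(), Y.ravel()], axis=1)
    w = jnp.where(inside.ravel(), h**2, 0.0)
    return pts, w


k, a, z = 2 * jnp.pi, 2.0, 10.0  # wavelength 1, arbitrary units
exact = piston_on_axis(k, a, z)
for n in (25, 50, 100, 200):
    pts, w = disc_quadrature(a, n)
    approx = rayleigh_sum(pts, w, k, jnp.array([0.0, 0.0, z]))
    print(n, float(jnp.abs(approx - exact) / jnp.abs(exact)))
```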

Slowness as medium variable

It would be great to optionally allow the user to give the slowness map instead of the speed of sound in Medium.
This will require checking that the two are not both defined, as well as rewriting the relevant functions of the solvers and propagators.
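A minimal sketch of the mutual-exclusion check, using a hypothetical `MediumSketch` dataclass rather than the real `Medium` class. Converting slowness to sound speed in one accessor means existing solvers would not need rewriting at first, and gradients with respect to the slowness still flow through `1 / slowness`.

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass
class MediumSketch:
    """Hypothetical medium accepting either sound speed or slowness, not both."""

    sound_speed: Optional[jnp.ndarray] = None
    slowness: Optional[jnp.ndarray] = None

    def __post_init__(self):
        if (self.sound_speed is None) == (self.slowness is None):
            raise ValueError("specify exactly one of `sound_speed` or `slowness`")

    def get_sound_speed(self):
        # Solvers written in terms of c can keep working unchanged.
        if self.sound_speed is not None:
            return self.sound_speed
        return 1.0 / self.slowness


m = MediumSketch(slowness=jnp.array([1 / 1500.0, 1 / 1480.0]))
```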

Example with Laplace Helmholtz equation

This is a fancy way of saying that it would be nice to show that the Helmholtz equation works with $\omega \in \mathbb{C}$, as long as $\text{Im}(\omega) \in \mathbb{R}^+$.
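A 1D periodic toy problem shows why a complex frequency helps: the spectral solve of $u'' + k^2 u = -f$ divides by $\xi^2 - k^2$, which vanishes for real $k$ whenever $k$ coincides with a grid wavenumber (here $\omega = 2\pi \cdot 20$ would hit exactly), but never when $\text{Im}(k) \neq 0$. The solution is then localized around the source. This is a standalone sketch, not jwave's solver; which sign of $\text{Im}(\omega)$ corresponds to damping depends on the time convention, which is not stated here.

```python
import jax.numpy as jnp


def solve_helmholtz_1d_periodic(f, dx, omega, c=1.0):
    """Spectral solve of u'' + (omega/c)^2 u = -f on a periodic 1D grid.

    With complex omega, xi^2 - k^2 never vanishes, so the division is safe.
    """
    k = omega / c
    xi = 2 * jnp.pi * jnp.fft.fftfreq(f.shape[0], d=dx)
    return jnp.fft.ifft(jnp.fft.fft(f) / (xi**2 - k**2))


n = 256
dx = 1.0 / n
x = jnp.arange(n) * dx
f = jnp.exp(-((x - 0.5) ** 2) / (2 * 0.01**2))  # narrow Gaussian source at x = 0.5
omega = 2 * jnp.pi * 20 + 1j * 2 * jnp.pi * 2  # Im(omega) > 0
u = solve_helmholtz_1d_periodic(f, dx, omega)
```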

Pre-compilation

A few times we have discussed being able to compile a jitted simulator in advance, to save time on the first execution.

It seems that it is now possible to do it on CPU and GPU: google/jax#13736

Just leaving this here as a reminder.
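Within a single session, JAX's ahead-of-time API (`jax.jit(f).lower(...).compile()`) already lets compilation happen from shapes and dtypes alone, before any data exists. Reusing compiled code across processes would additionally need something like JAX's persistent compilation cache; whether google/jax#13736 is about that is not summarized here. `toy_simulator` is a placeholder for a jwave simulation.

```python
import jax
import jax.numpy as jnp


def toy_simulator(u0):
    """Placeholder for a jitted simulator: 100 steps of a simple update."""

    def step(u, _):
        return 0.99 * u + 0.01 * jnp.roll(u, 1), None

    u, _ = jax.lax.scan(step, u0, None, length=100)
    return u


# Lower and compile from shapes/dtypes only, before any real data exists.
spec = jax.ShapeDtypeStruct((1024,), jnp.float32)
compiled = jax.jit(toy_simulator).lower(spec).compile()

u0 = jnp.ones(1024, dtype=jnp.float32)
out = compiled(u0)  # no tracing or compilation happens on this call
```

The compiled object only accepts inputs with exactly the shapes and dtypes it was lowered for.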

Optimal PML for various methods

The package needs a support function for setting optimal PML sizes, probably as a method of the Medium class. It should be analogous to the one provided by k-Wave, but discretization-dependent.
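One ingredient, assumed here to mirror the heuristic behind k-Wave's `getOptimalPMLSize`, is choosing the PML thickness so that the padded grid size has small prime factors, which keeps FFTs fast. The sketch below covers only that FFT-size criterion; the discretization-dependent part (e.g. a different thickness or absorption profile for finite differences) is not addressed, and the default `pml_range` is hypothetical. If the PML sits inside the domain, the same criterion would apply to the total grid size instead.

```python
def largest_prime_factor(n):
    """Largest prime factor of an integer n >= 2 (trial division)."""
    largest, p = 1, 2
    while p * p <= n:
        while n % p == 0:
            largest, n = p, n // p
        p += 1
    return max(largest, n) if n > 1 else largest


def fft_friendly_pml_size(n, pml_range=(10, 40)):
    """Pick the PML thickness whose padded size n + 2*pml has the smallest
    largest prime factor (ties broken by the thinnest PML)."""
    lo, hi = pml_range
    return min(range(lo, hi + 1), key=lambda pml: (largest_prime_factor(n + 2 * pml), pml))


pml = fft_friendly_pml_size(100)  # 100 + 2 * 14 = 128 = 2**7
```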

Add `diffrax` time integrators

The current time integrator is only "optimal" for FourierSeries fields thanks to the k-space operator correction. For finite differences, it would probably be better to allow the user to experiment with different kinds of integrators.

This is a good reason for using diffrax as the main ODE integrator library for time domain solvers.

Another good reason would be to expose to the user the advanced checkpointing methods enabled by diffrax.

An initial attempt at coding the modified semi-implicit Euler integrator used by the k-space equations is here: https://github.com/astanziola/diffrax/blob/main/diffrax/solver/semi_implicit_euler.py
It works pretty well, timings are roughly the same, and checkpointing works.

The API needs to be changed so that the user can choose their preferred integrator.
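A minimal plain-JAX sketch of a pluggable integrator interface: any `step_fn(state, dt, dx) -> state` can be passed to a generic time loop, with optional `jax.checkpoint` on each step. It uses a semi-implicit Euler step for 1D first-order acoustics with periodic central differences and no k-space correction; it is not diffrax's API and not the solver in the linked fork.

```python
import jax
import jax.numpy as jnp


def ddx(f, dx):
    """Periodic second-order central difference (an antisymmetric operator)."""
    return (jnp.roll(f, -1) - jnp.roll(f, 1)) / (2 * dx)


def semi_implicit_euler(state, dt, dx, rho=1.0, c=1.0):
    """Update velocity with the old pressure, then pressure with the new velocity."""
    u, p = state
    u = u - dt / rho * ddx(p, dx)
    p = p - dt * rho * c**2 * ddx(u, dx)
    return u, p


def integrate(step_fn, state, num_steps, dt, dx, checkpoint=False):
    """Generic time loop: any step_fn(state, dt, dx) -> state can be plugged in."""

    def body(s, _):
        return step_fn(s, dt, dx), None

    if checkpoint:
        body = jax.checkpoint(body)  # recompute step internals on the backward pass
    state, _ = jax.lax.scan(body, state, None, length=num_steps)
    return state


n = 128
dx = 1.0 / n
dt = 0.2 * dx
x = jnp.arange(n) * dx
p0 = jnp.exp(-((x - 0.5) ** 2) / (2 * 0.05**2))
u0 = jnp.zeros_like(p0)


def energy(s):
    # Acoustic energy for rho = c = 1.
    return 0.5 * jnp.sum(s[0] ** 2 + s[1] ** 2) * dx


u1, p1 = integrate(semi_implicit_euler, (u0, p0), 500, dt, dx)
```

Because the step is symplectic, the discrete energy stays bounded rather than drifting, which is one reason this family of integrators suits wave problems.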

Raise error when `helmholtz_solver` is called with a real field

Describe the bug
It does not make sense to call the Helmholtz operator with a real source field, as it is defined on complex fields.

One option would be to convert the field to complex, but this could be confusing, so it is probably better to just raise an error when one attempts to do so.
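A minimal sketch of the proposed guard, operating on the raw array of field values; `check_complex_source` is a hypothetical helper, not an existing jwave function.

```python
import jax.numpy as jnp


def check_complex_source(source):
    """Hypothetical guard for a Helmholtz solver entry point."""
    if not jnp.iscomplexobj(source):
        raise TypeError(
            f"Helmholtz solvers expect a complex source field, got dtype {source.dtype}; "
            "cast it explicitly, e.g. source.astype(jnp.complex64), if that is intended"
        )
    return source


ok_field = check_complex_source(jnp.zeros((8, 8), dtype=jnp.complex64))
```

Raising instead of casting keeps the conversion visible in user code, which is the point made above.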

Helmholtz operators are missing initializers

From https://discordapp.com/channels/1116725189972078639/1151518909254610984 by Wael - MIT:

I successfully ran the initial value problem you show in the documentation where jwave.png is loaded and used as sources. However, I've been running into issues trying to recreate the results from the helmholtz problem test. The specific issue comes from this line: params = helmholtz.default_params(src, medium, omega=omega), which is always returning None for params. Would you mind helping with this? I see that default_params uses the method _bound_init_params from jaxdf.core, but for some reason the params passed back are always None. Thank you!

Support `conda` feedstock jax wheels

There are community-driven wheels for jax and jaxlib with GPU support via conda-forge.

Relevant links:

I managed to get a working installation of jax with only conda installed (no prior CUDA or cuDNN installation) using

$ conda create -n condajax python=3.9
$ conda activate condajax
$ conda install jaxlib jax --channel conda-forge

It would be interesting to see whether this works under Windows, to avoid using WSL or building jax from scratch.
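A quick way to check which backend such an installation actually picked up (e.g. whether the GPU is visible) is to query JAX directly and run a small computation:

```python
import jax
import jax.numpy as jnp

# Which backend did the installation pick up? ("cpu", "gpu", ...)
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)

# A small computation to make sure the backend actually runs.
x = jnp.ones((100, 100))
result = float((x @ x).sum())
print(result)
```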

CBS adjoint

Although, in its current form, the Convergent Born Series does not allow for a heterogeneous density, it is still a powerful technique, especially for media with low sound-speed contrast.

In the absence of density heterogeneity, the Helmholtz operator is self-adjoint (up to a change of sign for the absorption term). This means that it should be possible to override the vjp operation from jax with one that uses the CBS.

Potential things to look for:

  • I suspect this will not allow for backpropagation through the solver operations, so there should be a clear warning for the user if one tries to find such gradients
  • Make sure to calculate gradients for all remaining terms, including source terms
  • It would be very nice to be able to mix and match the forward and adjoint solver. For example, solving for the forward map using Waveholtz and the backward map with CBS.
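The sketch below shows the mechanism with `jax.custom_vjp`: the forward and backward passes call separately supplied solvers, so a forward Waveholtz solve could be paired with a backward CBS solve. Here `jnp.linalg.solve` stands in for both; neither CBS nor Waveholtz is implemented. Only the gradient with respect to the source `b` is defined; gradients with respect to the medium would need the extra terms mentioned above. Note that JAX pulls cotangents back with the plain transpose (no conjugation), so the adjoint system is $A^T\lambda = g$; for a complex-symmetric operator ($A^T = A$) the forward solver can be reused unchanged.

```python
import jax
import jax.numpy as jnp


def make_linear_solve(A, forward_solve, adjoint_solve):
    """Return b -> A^{-1} b whose VJP calls a (possibly different) adjoint solver."""

    @jax.custom_vjp
    def solve(b):
        return forward_solve(A, b)

    def solve_fwd(b):
        return forward_solve(A, b), None  # no residuals needed for d/db

    def solve_bwd(_, g):
        # Plain transpose, matching JAX's cotangent convention for complex numbers.
        return (adjoint_solve(A.T, g),)

    solve.defvjp(solve_fwd, solve_bwd)
    return solve


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n = 8
A = jax.random.normal(k1, (n, n)) + 1j * jax.random.normal(k2, (n, n)) + n * jnp.eye(n)
b = jax.random.normal(k3, (n,)) + 1j * jax.random.normal(k4, (n,))

solve = make_linear_solve(A, jnp.linalg.solve, jnp.linalg.solve)
g = jax.grad(lambda b_: jnp.sum(jnp.abs(solve(b_)) ** 2))(b)
g_ref = jax.grad(lambda b_: jnp.sum(jnp.abs(jnp.linalg.solve(A, b_)) ** 2))(b)
```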

Add `lineax` solvers

The new lineax library has a bunch of solvers for linear systems, including GMRES, that would be worth exploring to see if there's any benefit compared to jax GMRES.

Furthermore, since at some point we would like to test out the WaveHoltz method, which kindly provides a positive definite and potentially symmetric operator 😄, we might want to use different solvers such as CG, so setting up this framework now seems beneficial.
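As a baseline for such a framework, the sketch dispatches a matrix-free operator to either CG or GMRES from `jax.scipy.sparse.linalg`, the jax solvers mentioned above; lineax's own API is not shown. The shifted negative Laplacian is a stand-in SPD operator, not a Helmholtz or WaveHoltz operator, and the `SOLVERS` registry is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg, gmres


def shifted_neg_laplacian(u, shift=1.0):
    """Matrix-free SPD operator: -u'' + shift * u on a periodic grid (dx = 1)."""
    return -(jnp.roll(u, 1) - 2 * u + jnp.roll(u, -1)) + shift * u


SOLVERS = {
    "cg": lambda op, b: cg(op, b, tol=1e-6, maxiter=500)[0],
    "gmres": lambda op, b: gmres(op, b, tol=1e-6, restart=30, maxiter=50)[0],
}


def solve(op, b, method="gmres"):
    """Dispatch to a solver by name; CG is only valid for SPD operators."""
    return SOLVERS[method](op, b)


b = jnp.sin(2 * jnp.pi * jnp.arange(64) / 64) + 0.1
x_cg = solve(shifted_neg_laplacian, b, "cg")
x_gm = solve(shifted_neg_laplacian, b, "gmres")
```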

Implement logger

It would be good to have a logger, for example to let the user know the maximum supported frequency, the PML size and similar information.
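A minimal sketch with Python's standard `logging` module. The logger name and `report_grid_limits` are hypothetical, and the maximum-frequency formula assumes two grid points per minimum wavelength (the Nyquist limit relevant to Fourier spectral methods); finite differences typically need more points per wavelength, so the formula would be discretization-dependent.

```python
import logging

logger = logging.getLogger("jwave")  # hypothetical logger name
logger.addHandler(logging.NullHandler())  # silent unless the user configures logging


def report_grid_limits(dx, c_min, pml_size):
    """Log grid-derived limits; max frequency assumes 2 points per minimum wavelength."""
    f_max = c_min / (2 * dx)
    logger.info("Maximum supported frequency: %.3e Hz", f_max)
    logger.info("PML size: %d grid points", pml_size)
    return f_max


f_max = report_grid_limits(dx=1e-4, c_min=1500.0, pml_size=20)
```

Users would then opt in with, e.g., `logging.basicConfig(level=logging.INFO)`.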

NaNs when batching simulations in tutorial example.

Describe the bug
Hi guys, I am trying to reproduce the Full Wave Inversion tutorial example. When batching the single simulation across sensor indices with jax.vmap, I am getting NaNs, even though running those sensor simulations individually runs stably.

I'm running on a Colab with GPU.

To Reproduce
Colab to reproduce:

https://colab.research.google.com/drive/1U3ttIMkn4lZnlu4SMNf4Y0IcuB7Sv9MH?usp=sharing

Expected behavior
Batching vs. looping over single sensor simulations should produce the same result.

Desktop:

  • Running on Google Colab with GPU.
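A small helper that makes the expected behavior testable: run the same function batched with `jax.vmap` and one input at a time, and report NaNs and the largest discrepancy. It does not diagnose the tutorial's actual cause; enabling JAX's `jax_debug_nans` option (shown commented out) makes JAX raise at the first operation producing a NaN, which can localize the problem, at a significant speed cost.

```python
import jax
import jax.numpy as jnp

# Debugging only: raise at the first operation that produces a NaN.
# jax.config.update("jax_debug_nans", True)


def compare_vmap_vs_loop(f, batch):
    """Run f over the leading axis of `batch`, batched and one by one."""
    batched = jax.vmap(f)(batch)
    looped = jnp.stack([f(x) for x in batch])
    return {
        "nan_in_batched": bool(jnp.isnan(batched).any()),
        "nan_in_looped": bool(jnp.isnan(looped).any()),
        "max_abs_diff": float(jnp.nanmax(jnp.abs(batched - looped))),
    }


report = compare_vmap_vs_loop(lambda x: 2 * jnp.sin(x), jnp.linspace(0.0, 1.0, 12).reshape(4, 3))
```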

Implementing off-grid sensors

Hi! I thought it would be really nice to introduce off-grid sensors into j-Wave, as they have done in k-Wave (http://www.k-wave.org/forum/topic/alpha-version-of-kwavearray-off-grid-sources).

I've implemented a basic version of this myself using the band-limited interpolant, which you can see in this gist: https://gist.github.com/tomelse/f9ba7508b75f44f34ebdbf25d5f5b0a3.

I'm happy to put a pull request in myself, I just wanted to get some feedback on how best to integrate it into the j-Wave code.

[Figure: bli_comparison]

The results look really nice, removing all the staircasing that you can see in the on-grid version, which you can see in the example simulation above (x axis = time, y axis = detector). It's a little bit slower than the original j-Wave implementation but not prohibitively so (maybe that's to be expected?).
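A minimal sketch of one way to structure off-grid sensors, not the implementation in the linked gist: precompute periodic band-limited interpolation weights (a Dirichlet kernel, rather than the truncated sinc used by k-Wave) once per sensor and axis, then read the field with a tensor-product contraction. It assumes an even number of grid points starting at 0. With dense weights, each reading touches every grid point instead of one, which is consistent with some slowdown compared to on-grid sensors.

```python
import jax.numpy as jnp


def bli_weights_1d(x, n, dx):
    """Periodic band-limited interpolation weights for one coordinate x.

    Assumes an even number of grid points n, located at 0, dx, ..., (n-1)*dx.
    """
    length = n * dx
    k = jnp.arange(n // 2 + 1)  # non-negative mode numbers
    alpha = jnp.full(k.shape, 2.0).at[0].set(1.0).at[-1].set(1.0)  # DC and Nyquist once
    d = x - jnp.arange(n) * dx
    return (alpha[:, None] * jnp.cos(2 * jnp.pi * k[:, None] * d[None, :] / length)).sum(0) / n


def off_grid_sensor_2d(field, point, dx):
    """Read a 2D field at an off-grid point via a tensor product of 1D weights."""
    wx = bli_weights_1d(point[0], field.shape[0], dx)
    wy = bli_weights_1d(point[1], field.shape[1], dx)
    return wx @ field @ wy


n, dx = 32, 0.1
length = n * dx
xs = jnp.arange(n) * dx
X, Y = jnp.meshgrid(xs, xs, indexing="ij")
field = jnp.sin(2 * jnp.pi * 3 * X / length) * jnp.cos(2 * jnp.pi * 2 * Y / length)
reading = off_grid_sensor_2d(field, (1.234, 2.71), dx)
```

In a time loop, `wx` and `wy` would be computed once per sensor and reused at every step.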

Add proper configurations for flake8 and mypy

It is tricky to configure mypy and flake8 for this project, for several reasons:

  • Operator overloading means that mypy sees a redefinition of unused functions. This is a problem inherited from plum.
  • Use black and flake8? I like the 2-space indent, but it doesn't conform to PEP 8, so perhaps it is time to let it go.

Make an example of B-Mode image in 2D

See the discussion in #104

This will probably require extending the Transducer class, for example by introducing some Transmit dataclasses that are used to define the forward simulation and are also used by the beamforming algorithm.
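A minimal sketch of the idea: a hypothetical `PlaneWaveTransmit` dataclass provides the transmit delay, and the same object is consumed by a delay-and-sum beamformer. It is not jwave's `Transducer` API; the synthetic channel data are Gaussian pulses from a single point scatterer rather than a simulation, there is no apodization, and `|sum|` replaces a proper envelope (Hilbert transform) and log compression.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class PlaneWaveTransmit:
    """Hypothetical Transmit description shared by simulation and beamforming."""

    angle: float  # steering angle in radians, 0 = straight down (+z)

    def tx_delay(self, x, z, c):
        # Arrival time of a plane wave launched at t = 0 through the origin.
        return (x * jnp.sin(self.angle) + z * jnp.cos(self.angle)) / c


def delay_and_sum(rf, t_axis, elem_x, transmit, xs, zs, c):
    """Delay-and-sum image; rf has shape (num_elements, num_samples), elements at z = 0."""
    X, Z = jnp.meshgrid(xs, zs, indexing="ij")
    t_tx = transmit.tx_delay(X, Z, c)

    def one_element(signal, xe):
        t = t_tx + jnp.sqrt((X - xe) ** 2 + Z**2) / c
        return jnp.interp(t.ravel(), t_axis, signal).reshape(X.shape)

    return jnp.abs(jax.vmap(one_element)(rf, elem_x).sum(axis=0))


# Synthetic data: one point scatterer at (0, 20 mm), 32 elements, 0.3 mm pitch.
c, fs = 1540.0, 100e6
t_axis = jnp.arange(4000) / fs
elem_x = (jnp.arange(32) - 15.5) * 0.3e-3
tx = PlaneWaveTransmit(angle=0.0)
x0, z0, sigma = 0.0, 20e-3, 0.1e-6
arrival = tx.tx_delay(x0, z0, c) + jnp.sqrt((elem_x - x0) ** 2 + z0**2) / c
rf = jnp.exp(-((t_axis[None, :] - arrival[:, None]) ** 2) / (2 * sigma**2))

xs = jnp.linspace(-5e-3, 5e-3, 41)
zs = jnp.linspace(15e-3, 25e-3, 41)
img = delay_and_sum(rf, t_axis, elem_x, tx, xs, zs, c)
ix, iz = jnp.unravel_index(jnp.argmax(img), img.shape)
```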

Document default PML

The documentation should state the default PML size more clearly, as well as the fact that the PML is inside the domain.

Incorrect gradients for 3D Helmholtz simulations using checkpointing

There are some edge cases for which Helmholtz simulations in 3D give incorrect gradients when checkpoint=True is used in the helmholtz_solver for FourierSeries fields.

This is likely to be a jax issue, since results were correct up to jax 0.3.20.
It has been raised upstream: google/jax#14302

For the moment, if memory requirements allow, setting checkpoint=False should get around the problem.

It may make sense to add a warning when checkpoint=True is used and to set its default value to False, especially if this is going to take a long time to be solved on the jax side.
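A minimal sketch of that suggestion with Python's `warnings` module; `resolve_checkpoint` is a hypothetical helper, and restricting the warning to 3D reflects the edge cases described above (it could equally warn for any dimension).

```python
import warnings


def resolve_checkpoint(checkpoint=False, ndim=2):
    """Hypothetical helper: default to no checkpointing and warn on the known-bad case."""
    if checkpoint and ndim == 3:
        warnings.warn(
            "checkpoint=True may give incorrect gradients for 3D FourierSeries "
            "Helmholtz simulations (see google/jax#14302); use checkpoint=False "
            "if memory allows.",
            RuntimeWarning,
            stacklevel=2,
        )
    return checkpoint
```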

Optimal FourierSeries implementation of Helmholtz operator

The Helmholtz operator needs to evaluate the following differential operator internally

$$ \nabla^2u + \frac{1}{\rho}(\nabla \rho \cdot \nabla u) $$

with the caveat that the differential operators need to be modified to account for the PML

At the moment, this is done by modifying the equation as

$$ \hat \nabla^2u + \frac{1}{\rho}(\nabla \rho \cdot \nabla u), \qquad \hat \nabla^2 = \sum_{\xi \in \{x,y,z\}} \frac{1}{\gamma_{\xi}} \frac{\partial}{\partial \xi}\frac{1}{\gamma_{\xi}} \frac{\partial}{\partial \xi} $$

(see this).

However, in the case of FourierSeries fields, the partial derivatives are not implemented correctly for even-order derivatives: in particular, the Nyquist frequency is not handled correctly. See Algorithm 1 here for more information in this regard.
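A 1D illustration of the Nyquist problem: for an even number of points, a first spectral derivative of a real signal should zero the Nyquist coefficient, while a direct second derivative keeps it (multiplied by $-k_N^2$). Composing two first derivatives, which is what the PML-modified form $\frac{1}{\gamma}\partial\frac{1}{\gamma}\partial$ requires, therefore loses the Nyquist content entirely. This sketch only demonstrates the discrepancy; it does not implement a fix.

```python
import jax.numpy as jnp


def spectral_derivative(u, dx, order):
    """Spectral derivative of a real periodic signal of the given order."""
    n = u.shape[0]
    k = 2 * jnp.pi * jnp.fft.fftfreq(n, d=dx)
    factor = (1j * k) ** order
    if n % 2 == 0 and order % 2 == 1:
        # Odd-order derivatives: zero the Nyquist coefficient to keep the result real.
        factor = factor.at[n // 2].set(0.0)
    return jnp.real(jnp.fft.ifft(factor * jnp.fft.fft(u)))


n = 16
idx = jnp.arange(n)
nyquist_mode = jnp.where(idx % 2 == 0, 1.0, -1.0)  # (-1)^n, the highest resolvable mode
direct = spectral_derivative(nyquist_mode, 1.0, 2)  # keeps -(pi/dx)^2 * u
composed = spectral_derivative(spectral_derivative(nyquist_mode, 1.0, 1), 1.0, 1)  # all zeros
```

For smooth signals without Nyquist content, both paths agree.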

Every other implementation tried so far seems to work worse. One could use the identity

$$ \nabla^2u + \frac{1}{\rho}(\nabla \rho \cdot \nabla u) = \frac{1}{\rho}\nabla \cdot (\rho \nabla u) $$

and implement the RHS using Algorithm 4, which corresponds to the heterog_laplacian operator of jaxdf. This also seems to reduce accuracy, although the accuracy of heterog_laplacian has not been tested yet.

This calls for further investigation.

Allow float64 for helmholtz solver

For large domains with strong damping, the gradient calculation for the Helmholtz solver may be inaccurate in float32. It should be possible to ensure that all operations are done in float64.
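On the JAX side, double precision is opt-in through the `jax_enable_x64` flag; without it, requests for 64-bit dtypes are silently truncated to 32-bit. Whether every internal jwave operation then stays in float64 (e.g. no hard-coded float32 constants) is not verified here, and is what the issue asks to ensure.

```python
import jax
import jax.numpy as jnp

# Best set at startup: arrays created before this call keep their 32-bit dtypes.
jax.config.update("jax_enable_x64", True)

# With x64 enabled, complex fields can be built in double precision explicitly.
field = jnp.zeros((64, 64), dtype=jnp.complex128)
omega = jnp.asarray(2 * jnp.pi * 1e6, dtype=jnp.float64)
print(field.dtype, omega.dtype)
```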

Convergence plots

We need convergence plots, especially for FiniteDifferences (since FourierSeries is well tested against k-Wave).

These should include both self-convergence and convergence against analytical solutions.
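The number such plots usually summarize is the observed order of convergence, $p = \log(e_1/e_2)/\log(h_1/h_2)$ between successive refinements. The sketch computes it for a toy second-order periodic central difference against the analytical derivative of $\sin x$; for self-convergence, the errors would instead be measured against the finest-grid solution.

```python
import jax.numpy as jnp


def central_diff_periodic(u, h):
    """Second-order central difference on a periodic grid."""
    return (jnp.roll(u, -1) - jnp.roll(u, 1)) / (2 * h)


def observed_order(errors, spacings):
    """Observed order between successive refinements: log(e1/e2) / log(h1/h2)."""
    e = jnp.asarray(errors)
    h = jnp.asarray(spacings)
    return jnp.log(e[:-1] / e[1:]) / jnp.log(h[:-1] / h[1:])


# Convergence against an analytical solution: d/dx sin(x) = cos(x) on [0, 2*pi).
errors, spacings = [], []
for n in (16, 32, 64):
    h = 2 * jnp.pi / n
    x = jnp.arange(n) * h
    errors.append(float(jnp.max(jnp.abs(central_diff_periodic(jnp.sin(x), h) - jnp.cos(x)))))
    spacings.append(h)

orders = observed_order(errors, spacings)  # close to 2 for this scheme
```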
