ucl-bug / jwave
A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs
License: GNU Lesser General Public License v3.0
For scalar fields, if the extra dimension is missing from the field parameters, add it automatically by looking at the domain.
It is amazing to see that people are starting to help with this project. Even if only for a bug report, it is important to acknowledge when people offer their free time to improve a project.
One simple way would be to add the all-contributors bot and follow the specification described here: allcontributors.org
Others
Implement the B/A parameter for modeling non-linearity in acoustic media
See eq. (2.7) of the k-Wave manual
It is not obvious how to model it for the Helmholtz equation, since the separation of variables assumed when deriving the Helmholtz equation from the wave equation no longer holds, so it is better to skip it for this milestone.
Describe the bug
Running the Helmholtz problem tutorial produces the following error: NotFoundLookupError: For function "helmholtz", signature Signature(jaxdf.discretization.FourierSeries, jwave.geometry.MediumType[<class 'float'>, <class 'float'>, <class 'float'>], builtins.float) could not be resolved.
It occurs at the line where params = helmholtz.default_params(src, medium, omega) is defined.
To Reproduce
Steps to reproduce the behavior:
Expected behavior
I was expecting j-Wave to solve the helmholtz equation :)
Desktop (please complete the following information):
Additional context
N/A
Need to add some more text and explanation.
There's now support for progress bars in jax scan and while loops!
See https://github.com/jeremiecoullon/jax-tqdm
It should be possible to add progress bars to simulations, which will greatly improve the experience for large problems.
The progress bars should be disabled by default, as they are probably inconvenient when jwave is used for machine learning applications on relatively small problems.
The current implementation of get_field for RealFourierSeries uses the standard fft and casts the result to real. It should use the real FFT instead, to improve performance.
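For real-valued samples the spectrum is Hermitian-symmetric, so the real FFT only stores the N // 2 + 1 non-negative frequencies and gives the same result as the complex round-trip. This is a minimal NumPy sketch, not the get_field code itself; jax.numpy.fft exposes the same fft/ifft/rfft/irfft functions.

```python
import numpy as np


def spectral_roundtrip_complex(x):
    # Full complex spectrum of length N, then the imaginary part is discarded.
    return np.real(np.fft.ifft(np.fft.fft(x)))


def spectral_roundtrip_real(x):
    # Only the N // 2 + 1 non-negative frequencies are computed and stored.
    spectrum = np.fft.rfft(x)
    return np.fft.irfft(spectrum, n=x.shape[-1])


x = np.random.default_rng(0).standard_normal(64)
print(np.fft.fft(x).shape, np.fft.rfft(x).shape)  # (64,) (33,)
```

Any spectral operation applied in between (filtering, interpolation, differentiation) can then work on roughly half as many coefficients.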
Accurate time varying source in k-space pseudo...
The user should receive an error when they try to do something like
sources = Sources(
positions=((32,),(32,),(32,)),
signals=jnp.stack([s1, s1]),
dt=time_axis.dt,
domain=domain,
)
since the number of signals does not match the number of source positions.
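A minimal sketch of such a check, assuming (as in the snippet above) that positions holds one tuple of indices per spatial axis and that signals has shape (num_sources, num_timesteps). SourcesSketch is a hypothetical stand-in, not the jwave Sources class.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SourcesSketch:
    # Hypothetical stand-in for jwave's Sources.
    positions: tuple  # one tuple of grid indices per spatial axis
    signals: np.ndarray  # assumed shape: (num_sources, num_timesteps)
    dt: float

    def __post_init__(self):
        lengths = {len(axis) for axis in self.positions}
        if len(lengths) != 1:
            raise ValueError(
                "All coordinate tuples in `positions` must have the same "
                f"length, got lengths {sorted(lengths)}."
            )
        num_positions = lengths.pop()
        num_signals = np.shape(self.signals)[0]
        if num_signals != num_positions:
            raise ValueError(
                f"Got {num_signals} signals for {num_positions} source "
                "positions; `signals` needs one row per source position."
            )
```

With this check, the example above (one position, two stacked signals) fails at construction time instead of later inside the simulation.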
Hi, can I ask about some bugs here?
When executing this script, it returns
"plum.function.NotFoundLookupError: For function "gradient", signature Signature(jaxdf.discretization.FiniteDifferences, builtins.list) could not be resolved."
It seems only "FourierSeries" is used for solving helmholtz for now.
For the theory on spectral methods, see the k-Wave manual and the related paper.
Currently the Helmholtz equation assumes Stokes' attenuation, which corresponds to a power-law exponent y = 2. It would be nice to use a fractional Laplacian in the time-harmonic case too.
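On a periodic grid a fractional Laplacian is diagonal in Fourier space: (-∇²)^s multiplies each coefficient by |k|^(2s). This is a 1D NumPy sketch of that building block only; how the exponent would be tied to y in jwave's time-harmonic absorption term is left open here.

```python
import numpy as np


def fractional_laplacian_1d(u, dx, s):
    """Apply (-d^2/dx^2)^s to a periodic 1D grid function via the FFT.

    The symbol is |k|^(2 s); s = 1 recovers the negated second derivative.
    Assumes s >= 0, so the k = 0 mode is well defined.
    """
    n = u.shape[-1]
    k = 2 * np.pi * np.fft.fftfreq(n, d=dx)  # angular wavenumbers
    symbol = np.abs(k) ** (2 * s)
    return np.real(np.fft.ifft(symbol * np.fft.fft(u)))
```

For u = sin(3x) on [0, 2π), s = 1 returns 9 sin(3x) and s = 0.5 returns 3 sin(3x).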
We need tests for the Rayleigh integral (harmonic formulation) in jwave.acoustics.time_harmonic.rayleigh_integral
It could be tested against angularSpectrum of k-Wave, or with a convergence plot against a known solution (e.g. propagation of a single-frequency Gaussian beam).
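As an independent reference for such tests, a basic angular spectrum propagator for a monochromatic field on a plane fits in a few lines. This is a generic NumPy sketch, not a port of k-Wave's angularSpectrum; it assumes an exp(-iωt) time convention and propagation distances z >= 0.

```python
import numpy as np


def angular_spectrum_propagate(field, dx, wavelength, z):
    """Propagate a monochromatic field sampled on an (N, M) plane by z >= 0.

    Each plane-wave component exp(i (kx x + ky y)) is multiplied by
    exp(i kz z), with kz = sqrt(k^2 - kx^2 - ky^2). Evanescent components
    get an imaginary kz and therefore decay with z.
    """
    k = 2 * np.pi / wavelength
    kx = 2 * np.pi * np.fft.fftfreq(field.shape[0], d=dx)
    ky = 2 * np.pi * np.fft.fftfreq(field.shape[1], d=dx)
    KX, KY = np.meshgrid(kx, ky, indexing="ij")
    kz = np.sqrt((k**2 - KX**2 - KY**2).astype(complex))
    transfer = np.exp(1j * kz * z)
    return np.fft.ifft2(np.fft.fft2(field) * transfer)
```

A uniform field (a normally incident plane wave) should only pick up the phase exp(i k z), which makes a simple sanity check before comparing against the Rayleigh integral.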
It would be great to optionally allow the user to provide the slowness map instead of the speed of sound in Medium.
This will require checking that the two are not both defined, as well as rewriting the appropriate functions of the solvers and propagators.
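One way to keep the solvers unchanged is to validate the inputs once and route every access through accessor methods. MediumSketch and its get_sound_speed / get_slowness methods are hypothetical names for illustration, not the jwave Medium API.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class MediumSketch:
    # Hypothetical stand-in for jwave's Medium: exactly one of the two is set.
    sound_speed: Optional[np.ndarray] = None
    slowness: Optional[np.ndarray] = None

    def __post_init__(self):
        if (self.sound_speed is None) == (self.slowness is None):
            raise ValueError("Specify exactly one of `sound_speed` or `slowness`.")

    def get_sound_speed(self):
        # Solvers would call this instead of reading `sound_speed` directly.
        if self.sound_speed is not None:
            return np.asarray(self.sound_speed)
        return 1.0 / np.asarray(self.slowness)

    def get_slowness(self):
        if self.slowness is not None:
            return np.asarray(self.slowness)
        return 1.0 / np.asarray(self.sound_speed)
```

Storing whichever quantity the user passed (rather than converting eagerly) keeps gradients flowing with respect to the parameter the user actually optimizes.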
The algorithm is described here: https://www.autodiff.org/ad16/Oral/Siskind_Checkpointing.pdf
Useful for time-domain simulations.
It would be good to have a table (or functions) to extract the acoustic parameters of biological media that are relevant for biomedical ultrasound.
See for example:
Which is a fancy way to say that would be nice to show that the Helmholtz equation works with
We have discussed a few times the possibility of compiling a jitted simulator in advance, to save time on the first execution.
It seems that this is now possible on CPU and GPU: google/jax#13736
Just leaving this here as a reminder.
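For reference, JAX's ahead-of-time API lowers and compiles a jitted function for fixed input shapes before the first call. In this sketch, step is a hypothetical placeholder for a simulator; whether compiled executables can also be serialized and reused across sessions depends on the JAX version and is what the linked issue discusses.

```python
import jax
import jax.numpy as jnp


def step(u):
    # Hypothetical stand-in for a simulator: any jittable function works.
    return jnp.sin(u) * 2.0


u0 = jnp.linspace(0.0, 1.0, 128)

# Trace and lower for the shape/dtype of u0, then compile ahead of time.
# jax.ShapeDtypeStruct(u0.shape, u0.dtype) could be passed instead of u0.
lowered = jax.jit(step).lower(u0)
compiled = lowered.compile()

# Calling the compiled executable does not trigger a new compilation,
# but it only accepts inputs with the same shape and dtype as u0.
u1 = compiled(u0)
```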
The package needs a utility function for setting optimal PML sizes, probably as a method of the Medium class. It should be analogous to the one provided by k-Wave, but discretization dependent.
The current time integrator is only "optimal" for FourierSeries fields, thanks to the k-space operator correction. For finite differences, it would probably be better to let the user experiment with different kinds of integrators.
This is a good reason for using diffrax as the main ODE integrator library for time-domain solvers.
Another good reason would be to expose to the user the advanced checkpointing methods enabled by diffrax.
An initial attempt to code the modified semi-implicit Euler integrator used by the k-space equations is here: https://github.com/astanziola/diffrax/blob/main/diffrax/solver/semi_implicit_euler.py
It works pretty well: timings are roughly the same, and checkpointing works.
The API needs to be changed so that the user can choose their preferred integrator.
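To make the structure of a semi-implicit Euler step concrete, this is a textbook version for 1D linear acoustics on a periodic grid: velocity is updated from the current pressure, then pressure from the new velocity. It is not the code in the linked diffrax fork; it uses plain FFT derivatives with no k-space correction and no staggered grid.

```python
import numpy as np


def semi_implicit_euler(u, p, dt, num_steps, rho0, c0, dx):
    """Integrate du/dt = -(1/rho0) dp/dx, dp/dt = -rho0 c0^2 du/dx.

    Semi-implicit (symplectic) Euler: u is advanced with the old p,
    then p is advanced with the *new* u.
    """
    k = 2 * np.pi * np.fft.fftfreq(u.shape[0], d=dx)

    def ddx(f):
        # Spectral first derivative; the real part drops the Nyquist term.
        return np.real(np.fft.ifft(1j * k * np.fft.fft(f)))

    for _ in range(num_steps):
        u = u - dt / rho0 * ddx(p)
        p = p - dt * rho0 * c0**2 * ddx(u)
    return u, p
```

Unlike explicit Euler, this ordering keeps the acoustic energy bounded over long runs for a stable time step, which is one reason it is a natural default for wave problems.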
Describe the bug
It does not make sense to call the Helmholtz operator with a real source field, as it is defined on complex fields.
One option would be to convert the field to complex, but this could be confusing, so it is probably better to just raise an error when one attempts to do so.
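A minimal sketch of the guard, assuming the field parameters are array-like with a dtype. check_complex_field is a hypothetical helper name, not an existing jwave function.

```python
import numpy as np


def check_complex_field(params, name="source"):
    # Hypothetical guard: the Helmholtz operator is defined on complex fields.
    if not np.iscomplexobj(params):
        raise TypeError(
            f"The Helmholtz operator is defined on complex fields, but `{name}` "
            f"has real dtype {np.asarray(params).dtype}. Convert it explicitly "
            "(e.g. `params.astype(np.complex64)`) if this is intended."
        )
    return params
```

Raising instead of silently casting keeps the conversion visible to the user, as suggested above.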
The size of the PML is fundamental for FourierSeries simulations.
Consider implementing a utility function analogous to getOptimalPMLSize.
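As far as we understand it, the k-Wave function searches a range of PML sizes and picks, per axis, the one that makes the total grid size have the smallest largest prime factor, which keeps the FFTs fast. This sketch assumes the PML is added on both sides of each axis and a default search range of (10, 40); since in j-Wave the PML lies inside the domain, a real implementation would need to adapt this.

```python
def largest_prime_factor(n: int) -> int:
    factor, largest = 2, 1
    while factor * factor <= n:
        while n % factor == 0:
            largest, n = factor, n // factor
        factor += 1
    return max(largest, n) if n > 1 else largest


def optimal_pml_sizes(grid_size, pml_range=(10, 40)):
    """Per axis, choose p in pml_range (inclusive) minimizing the largest prime
    factor of N + 2 p; ties go to the smallest p. Assumed, k-Wave-like rule."""
    lo, hi = pml_range
    return tuple(
        min(range(lo, hi + 1), key=lambda p: (largest_prime_factor(n + 2 * p), p))
        for n in grid_size
    )
```

For a 100-point axis this picks p = 14, giving a padded size of 128 = 2^7.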
From https://discordapp.com/channels/1116725189972078639/1151518909254610984 by Wael - MIT:
I successfully ran the initial value problem you show in the documentation where jwave.png is loaded and used as sources. However, I've been running into issues trying to recreate the results from the helmholtz problem test. The specific issue comes from this line: params = helmholtz.default_params(src, medium, omega=omega), which is always returning None for params. Would you mind helping with this? I see that default_params uses the method _bound_init_params from jaxdf.core, but for some reason the params passed back are always None. Thank you!
There are community-driven builds of jax and jaxlib with GPU support on conda-forge.
Relevant links:
I managed to get a working installation of jax with only conda installed (no prior CUDA or cuDNN installation) using
$ conda create -n condajax python=3.9
$ conda activate condajax
$ conda install jaxlib jax --channel conda-forge
It would be interesting to see whether this works under Windows, to avoid using WSL or building jax from scratch.
Although the Convergent Born Series, in its current form, does not allow for a heterogeneous density, it is still a powerful technique, especially for media with low sound-speed contrast.
In the absence of density heterogeneity, the Helmholtz operator is self-adjoint (up to a change of sign for the absorption term). This means that it should be possible to override the vjp operation of jax with one that uses the CBS.
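The reason this works: the pullback of x = A⁻¹b with respect to b is b̄ = A⁻ᵀx̄ (JAX's vjp uses the transpose, without conjugation). If A is complex symmetric (A = Aᵀ, which is what "self-adjoint up to the sign of the absorption" amounts to), then A⁻ᵀ = A⁻¹ and the backward pass can reuse the forward solver. In this NumPy sketch, np.linalg.solve stands in for CBS; in JAX, solve_vjp would be the backward rule of a jax.custom_vjp.

```python
import numpy as np


def solve(A, b):
    # Stand-in for an iterative Helmholtz solver such as the CBS.
    return np.linalg.solve(A, b)


def solve_vjp(A, cotangent):
    """Pullback of b -> A^{-1} b, assuming A is complex symmetric (A == A.T).

    Since A^{-T} = A^{-1}, the forward solver is reused unchanged.
    """
    if not np.allclose(A, A.T):
        raise ValueError("This shortcut assumes a complex-symmetric operator.")
    return solve(A, cotangent)
```

The identity being exploited is x̄ᵀ(A⁻¹δb) = (A⁻ᵀx̄)ᵀδb, so one extra solve with the same operator gives the gradient.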
Potential things to look for:
The new lineax library has several solvers for linear systems, including GMRES, which would be worth exploring to see if there's any benefit compared to the jax GMRES.
Furthermore, since at some point we would like to test the WaveHoltz method, which kindly provides a positive definite and potentially symmetric operator, we might want to use different solvers such as CG, so setting up this framework now seems beneficial.
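For comparison purposes, this is textbook conjugate gradient for a real symmetric positive definite operator given as a function, which is the interface a matrix-free WaveHoltz operator would naturally have. It is a generic NumPy sketch; jax.scipy.sparse.linalg also provides cg and gmres with a similar operator-as-function interface.

```python
import numpy as np


def conjugate_gradient(apply_A, b, tol=1e-10, max_iters=1000):
    """Solve A x = b for a real symmetric positive definite operator apply_A."""
    x = np.zeros_like(b, dtype=float)
    b_norm = np.linalg.norm(b)
    if b_norm == 0.0:
        return x
    r = b - apply_A(x)
    p = r.copy()
    rs_old = r @ r
    for _ in range(max_iters):
        Ap = apply_A(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) <= tol * b_norm:  # relative residual test
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x
```

CG relies on symmetry and positive definiteness, which is why it only becomes an option once an operator like WaveHoltz's is available; the plain Helmholtz operator calls for GMRES-type solvers.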
Relevant k-Wave function
It would be good to have a logger, for example to tell the user the maximum supported frequency, the PML size, and similar information.
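A sketch using the standard logging module. The maximum supported frequency is computed from the Nyquist limit of two grid points per minimum wavelength, f_max = c_min / (2 dx), which is what pseudo-spectral codes typically report; the logger name and helper functions are hypothetical.

```python
import logging

logger = logging.getLogger("jwave_sketch")  # hypothetical logger name


def max_supported_frequency(c_min: float, dx: float) -> float:
    # Two grid points per minimum wavelength: f_max = c_min / (2 dx).
    return c_min / (2.0 * dx)


def log_simulation_info(c_min: float, dx: float, pml_size: int) -> float:
    f_max = max_supported_frequency(c_min, dx)
    logger.info("Maximum supported frequency: %.3e Hz", f_max)
    logger.info("PML size: %d grid points", pml_size)
    return f_max
```

Messages only appear once the user opts in, e.g. with logging.basicConfig(level=logging.INFO), so the default behaviour stays quiet.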
The mathematical model using a k-space pseudo-spectral solver is described in this paper
The implementation could be tested against those examples from k-Wave
Describe the bug
Hi guys, I am trying to reproduce the Full Wave Inversion tutorial example. When batching the single simulation across sensor indices with jax.vmap, I am getting NaNs, even though running the sensor simulations individually is stable.
I'm running on a Colab with GPU.
To Reproduce
Colab to reproduce:
https://colab.research.google.com/drive/1U3ttIMkn4lZnlu4SMNf4Y0IcuB7Sv9MH?usp=sharing
Expected behavior
Batching vs. looping over single sensor simulations should produce the same result.
Desktop (please complete the following information):
Hi! I thought it would be really nice to introduce off-grid sensors into j-Wave, as they have done in k-Wave (http://www.k-wave.org/forum/topic/alpha-version-of-kwavearray-off-grid-sources).
I've implemented a basic version of this myself using the band-limited interpolant, which you can see in this gist: https://gist.github.com/tomelse/f9ba7508b75f44f34ebdbf25d5f5b0a3.
I'm happy to put a pull request in myself, I just wanted to get some feedback on how best to integrate it into the j-Wave code.
The results look really nice, removing all the staircasing that is visible in the on-grid version, as you can see in an example simulation above (x axis = time, y axis = detector). It's a little slower than the original j-Wave implementation, but not prohibitively so (maybe that's to be expected?).
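The core of band-limited sensor interpolation is evaluating the Fourier series of the grid samples at an arbitrary position. This 1D NumPy sketch is not the code in the linked gist; in 2D/3D the same evaluation can be applied separably along each axis.

```python
import numpy as np


def bandlimited_interp_1d(values, dx, x_query):
    """Evaluate the trigonometric interpolant of periodic samples
    values[j] = f(j * dx) at arbitrary (possibly off-grid) positions."""
    n = values.shape[0]
    coeffs = np.fft.fft(values) / n
    k = 2 * np.pi * np.fft.fftfreq(n, d=dx)
    phases = np.exp(1j * np.outer(x_query, k))  # shape (num_queries, n)
    # For real data, taking the real part treats the Nyquist term as a cosine.
    return np.real(phases @ coeffs)
```

For a signal whose content lies below the Nyquist frequency, the off-grid values are recovered exactly (to rounding error), which is why the staircasing disappears.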
Line 356 in b293d95
It is tricky to configure mypy and flake8 for this project, for several reasons:
mypy sees a redefinition of unused functions. This is a problem inherited from plum. See the discussion in #104
This will probably require extending the Transducer class, for example by introducing some Transmit dataclasses that define the forward simulation and are also used by the beamforming algorithm.
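A sketch of what such a dataclass could hold for a focused transmit: per-element delays and apodization, shared by the forward simulation and the beamformer. The Transmit fields and focused_transmit helper are hypothetical; the delay rule is the standard geometric one, so that all wavefronts reach the focus at the same time.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class Transmit:
    # Hypothetical: firing delay (s) and apodization weight per element.
    delays: np.ndarray
    apodization: np.ndarray


def focused_transmit(element_positions, focus, sound_speed):
    """Delays that make every element's wavefront arrive at `focus` together.

    The element farthest from the focus fires first (zero delay).
    """
    distances = np.linalg.norm(element_positions - focus, axis=1)
    delays = (distances.max() - distances) / sound_speed
    return Transmit(delays=delays, apodization=np.ones(len(distances)))
```

Because the beamformer needs exactly the same delays to compute receive times, storing them in one object avoids the two code paths drifting apart.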
The documentation should state more clearly the default PML value, and the fact that the PML lies inside the domain.
There are some edge cases for which Helmholtz simulations in 3D give incorrect gradients when checkpoint=True is used in the helmholtz_solver for FourierSeries fields.
This is likely to be a jax issue, since results were correct up to jax 0.3.20.
It has been raised upstream: google/jax#14302
For the moment, if memory requirements allow, setting checkpoint=False should get around the problem.
It may make sense to add a warning when checkpoint=True is used, and to set its default value to False, especially if this is going to take a long time to be solved on the jax side.
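A minimal sketch of the proposed warning, with the default flipped to False. helmholtz_solver_sketch is a hypothetical wrapper; the actual solver call is left out.

```python
import warnings


def helmholtz_solver_sketch(params, checkpoint=False):
    """Hypothetical wrapper that warns when checkpointing is enabled."""
    if checkpoint:
        warnings.warn(
            "checkpoint=True may produce incorrect gradients for some 3D "
            "Helmholtz problems with recent jax versions (see "
            "google/jax#14302); use checkpoint=False if memory allows.",
            UserWarning,
            stacklevel=2,
        )
    # The real solver would be called here; this sketch returns its input.
    return params
```

stacklevel=2 makes the warning point at the user's call site rather than at the wrapper.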
Implementing the WaveHoltz hybrid solver
Describe the bug
Helmholtz notebook throws a leaked tracer error
To Reproduce
Run the helmholtz_problem.ipynb notebook
Expected behavior
No error
The Helmholtz operator needs to evaluate the following differential operator internally
with the caveat that the differential operators need to be modified to account for the PML
At the moment, this is done by modifying the equation as
(see this).
However, in the case of FourierSeries fields, the partial derivatives are not implemented correctly for even derivatives: in particular, the Nyquist frequency is not handled correctly. See Algorithm 1 here for more information on this.
Any other kind of implementation seems to work worse. One could use the identity and implement the RHS using Algorithm 4, which corresponds to the heterog_laplacian operator of jaxdf. This also seems to reduce accuracy, although the accuracy of heterog_laplacian has not been tested yet.
This calls for further investigation.
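To illustrate the Nyquist issue in isolation: for an even number of points, the Nyquist coefficient should be zeroed for odd-order derivatives but kept (multiplied by (ik)^order) for even-order ones. If a second derivative is built by applying a first derivative twice, the Nyquist component is lost. This NumPy sketch shows only that effect; it does not reproduce the PML-modified operator or the jaxdf implementation.

```python
import numpy as np


def spectral_derivative(u, dx, order):
    """order-th derivative of a real periodic grid function via the FFT.

    For even N, the Nyquist coefficient is zeroed for odd orders and kept
    for even orders, matching the band-limited interpolant.
    """
    n = u.shape[0]
    k = 2 * np.pi * np.fft.fftfreq(n, d=dx)
    multiplier = (1j * k) ** order
    if n % 2 == 0 and order % 2 == 1:
        multiplier[n // 2] = 0.0
    return np.real(np.fft.ifft(multiplier * np.fft.fft(u)))


n = 16
dx = 2 * np.pi / n
u = (-1.0) ** np.arange(n)  # pure Nyquist mode, cos(8 x) on the grid
d2_direct = spectral_derivative(u, dx, 2)  # -64 * u
d2_composed = spectral_derivative(spectral_derivative(u, dx, 1), dx, 1)  # zeros
```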
Relevant k-Wave function: http://www.k-wave.org/documentation/acousticFieldPropagator.php
For large domains with a lot of damping, the gradient calculation for the Helmholtz solver may be inaccurate in float32. It should be possible to ensure that all operations are done in float64.
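JAX uses 32-bit types by default; double precision is enabled globally through the jax_enable_x64 flag (or the JAX_ENABLE_X64 environment variable), which should be set at startup. How jwave would expose this to the user is not decided here; this sketch only shows the JAX side.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types; set this before creating any arrays.
jax.config.update("jax_enable_x64", True)

# Default floating-point arrays are now float64, complex ones complex128.
x = jnp.ones(4)
z = jnp.ones(4, dtype=jnp.complex128)
```

Arrays created with an explicit 32-bit dtype (for example from float32 NumPy inputs) keep that dtype, so the medium and source parameters would also need to be created in double precision.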
We need convergence plots, especially for FiniteDifferences (since FourierSeries is well tested against k-Wave): both self-convergence and convergence against analytical solutions.
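The quantity usually read off such plots is the observed order p, from e ≈ C hᵖ: p = log(e₁/e₂) / log(h₁/h₂) between successive refinements. For self-convergence, with no analytical solution, the errors can be replaced by differences between solutions on successive grids. This sketch checks a second-order central difference of sin(x) on a periodic grid; it does not use jaxdf's FiniteDifferences.

```python
import numpy as np


def observed_order(errors, spacings):
    """Observed convergence order between successive refinements."""
    errors = np.asarray(errors, dtype=float)
    spacings = np.asarray(spacings, dtype=float)
    return np.log(errors[:-1] / errors[1:]) / np.log(spacings[:-1] / spacings[1:])


def central_difference_error(n):
    # Max error of the central difference for d/dx sin(x) on [0, 2*pi).
    dx = 2 * np.pi / n
    x = np.arange(n) * dx
    u = np.sin(x)
    du = (np.roll(u, -1) - np.roll(u, 1)) / (2 * dx)
    return np.max(np.abs(du - np.cos(x))), dx


errors, spacings = zip(*[central_difference_error(n) for n in (16, 32, 64, 128)])
orders = observed_order(errors, spacings)  # each entry close to 2
```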