differentiableuniverseinitiative / jaxdecomp
JAX bindings for the NVIDIA cuDecomp library
License: MIT License
File "/mnt/home/flanusse/repo/jaxDecomp/scripts/demo.py", line 58, in <module>
recarray = slice_unpad(exchanged_reduced, padding_width, pdims)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: custom_partitioner: TypeError: pad operand and padding_value must be same dtype: got complex64 and float32.
Because implementing every feature is not necessarily worthwhile unless there is a need for it, here are the restrictions of the current version of the code. All of these restrictions can be lifted if you need them to be; don't hesitate to comment on this issue if there is something you would like to be able to do!
General
Transpose operations
FFTs
Currently, the CUDA version is set to 12.2 in the CMake configuration.
jaxDecomp (and cuDecomp) can be compiled with CUDA 11.8 (there is no CUDA 12-specific code).
JAX 0.4.26 and above no longer supports CUDA 11, but some machines do not have the latest drivers, so some users have to use JAX 0.4.25.
I propose to allow users to choose which CUDA version to compile jaxDecomp with, like so:
By default 12.2
pip install jaxdecomp
or
pip install jaxdecomp[cuda11]
pip install jaxdecomp[cuda12]
But obviously we don't download the NVIDIA wheels; we still expect the user to have the modules loaded.
error message:
CMake Error at CMakeLists.txt:5 (find_package):
By not providing "FindNVHPC.cmake" in CMAKE_MODULE_PATH this project has
asked CMake to find a package configuration file provided by "NVHPC", but
CMake did not find one.
Could not find a package configuration file provided by "NVHPC" with any of
the following names:
NVHPCConfig.cmake
nvhpc-config.cmake
Add the installation prefix of "NVHPC" to CMAKE_PREFIX_PATH or set
"NVHPC_DIR" to a directory containing one of the above files. If "NVHPC"
provides a separate development package or SDK, be sure it has been
installed.
-- Configuring incomplete, errors occurred!
Comparing the 3D FFT computed by jaxdecomp with one computed manually in JAX, I realized that the result of fft3d does not match the non-distributed version.
This could be due to a transposition of the pfft3d result, which is a fairly common convention that saves 2 all-to-all communications in a forward-backward step. However, depending on the partitioning scheme, I get results whose dimensions are in different orders.
I have modified the FFT test to actually detect this problem in the fix_fft branch in #12.
@ASKabalan can you take a look?
If we don't provide any other information to the user regarding the order of dimensions in the FFT, the user expects the following to be true:
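import jax
import numpy as np
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P
from numpy.testing import assert_allclose

import jaxdecomp

# `key` is assumed to be a jax.random.PRNGKey defined earlier (one per process)
pdims = (2, 2)
mesh_shape = (4, 4, 4)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('z', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('z', 'y'))
# Each process holds one slab: the first two axes are split across the device mesh
local_mesh_shape = [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]]
z = jax.make_array_from_single_device_arrays(shape=mesh_shape,
                                             sharding=sharding,
                                             arrays=[jax.random.normal(key, local_mesh_shape)])
with mesh:
    kfield_dist = jaxdecomp.fft.pfft3d(z)

# Gather the distributed arrays onto every process before comparing
kfield_dist = multihost_utils.process_allgather(kfield_dist, tiled=True)
kfield = np.fft.fftn(multihost_utils.process_allgather(z, tiled=True))
# This should be true to within numerical accuracy
assert_allclose(kfield_dist, kfield)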
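When the assertion above fails, it can help to check whether the distributed result is simply the reference FFT with its axes permuted. The following is a minimal NumPy-only sketch of such a check. The helper name `find_axis_permutation` and the stand-in "transposed" array are hypothetical and not part of jaxdecomp; the sketch uses a mesh with three distinct axis lengths so that each permutation has a distinct shape.

```python
import itertools

import numpy as np


def find_axis_permutation(candidate, reference, rtol=1e-5, atol=1e-8):
    """Return the permutation p with candidate == reference.transpose(p), else None.

    Hypothetical debugging helper, not a jaxdecomp function.
    """
    for perm in itertools.permutations(range(reference.ndim)):
        permuted = reference.transpose(perm)
        if permuted.shape != candidate.shape:
            continue
        if np.allclose(candidate, permuted, rtol=rtol, atol=atol):
            return perm
    return None


rng = np.random.default_rng(0)
field = rng.standard_normal((4, 6, 8))
reference = np.fft.fftn(field)

# Stand-in for a distributed FFT that returns its output with rotated axes
# (an assumption for illustration; a real pfft3d result would go here).
transposed_output = reference.transpose(1, 2, 0)

print(find_axis_permutation(reference, reference))          # (0, 1, 2)
print(find_axis_permutation(transposed_output, reference))  # (1, 2, 0)
```

In practice, `candidate` would be the gathered `kfield_dist` and `reference` the gathered `kfield`; a result of `None` suggests the mismatch is not a pure transposition.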
I noticed the following potential issue here:
jaxDecomp/src/grid_descriptor_mgr.cc
Line 35 in 95415b1
This might also be why the GPU binding is not working on Jean Zay.
The test_ttf code demonstrates distributing data to multiple processes (GPUs), but how do I gather the processed data (fft3d) and merge it back into one array? Thanks!