
jaxdecomp's Issues

Things to update for version 0.0.1

  • slice_unpad does not seem to handle complex numbers as inputs:
  File "/mnt/home/flanusse/repo/jaxDecomp/scripts/demo.py", line 58, in <module>
    recarray = slice_unpad(exchanged_reduced, padding_width, pdims)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: custom_partitioner: TypeError: pad operand and padding_value must be same dtype: got complex64 and float32.
  • transpose operations are not yet updated to new API
  • Can we remove jaxdecomp.finalize() entirely?
  • #11
  • #13
  • Add a mechanism to run all the tests at once (currently, they can only be run one by one)
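
The dtype error in the first item is the one `jax.lax.pad` raises when the padding value does not share the operand's dtype. Here is a minimal sketch of one possible fix, assuming `slice_unpad` ends up calling `lax.pad` with a float32 zero. The helper `pad_same_dtype` is hypothetical and not part of jaxDecomp:

```python
import jax.numpy as jnp
from jax import lax


def pad_same_dtype(x, pad_width):
    """Pad (or crop, with negative widths) x using a zero of x's own dtype."""
    # Hypothetical helper: build the padding value from the operand dtype so
    # that a complex64 input gets a complex64 zero instead of a float32 one.
    padding_value = jnp.zeros((), dtype=x.dtype)
    # lax.pad expects one (low, high, interior) triple per dimension.
    config = [(low, high, 0) for (low, high) in pad_width]
    return lax.pad(x, padding_value, config)


x = jnp.ones((4, 4), dtype=jnp.complex64)
padded = pad_same_dtype(x, [(1, 1), (0, 0)])           # adds one row on each side
cropped = pad_same_dtype(padded, [(-1, -1), (0, 0)])   # negative widths remove them again
```

Deriving the padding value from `x.dtype` keeps the same code path valid for both real and complex inputs.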

Currently unsupported features and caveats (aka the big TODO list)

Because implementing every feature is not necessarily worthwhile unless there is a need for it, here are the restrictions of the code in its current version. All of them can be lifted if you need them to be, so don't hesitate to comment on this issue if there is something you would like to be able to do!

General

  • Configuration mechanism to choose the communication API (by default CUDA aware MPI) only works at initialization
  • No interface to run autotuning to figure out the best communication strategy
  • No interface with the JAX 0.4 Array API

Transpose operations

  • Only transpose operations that preserve the size of local slices are supported
  • Transpose ops do not have batching or gradient operations implemented
  • No support for letting XLA allocate the workspace size needed for the transpose
  • Double precision is silently not supported (returns crazy values!)
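
Until double precision is supported, one option for the last item is to fail loudly instead of returning wrong values. This is a minimal sketch, assuming float32 and complex64 are the only supported dtypes. `SUPPORTED_DTYPES` and `check_supported_dtype` are hypothetical names, not jaxDecomp API:

```python
import numpy as np

# Assumption: only single-precision real and complex arrays are supported.
SUPPORTED_DTYPES = (np.dtype(np.float32), np.dtype(np.complex64))


def check_supported_dtype(x):
    """Return x unchanged, or raise TypeError if its dtype is unsupported."""
    dtype = np.dtype(x.dtype)
    if dtype not in SUPPORTED_DTYPES:
        supported = ", ".join(d.name for d in SUPPORTED_DTYPES)
        raise TypeError(
            f"dtype {dtype.name} is not supported (supported: {supported})"
        )
    return x


ok = check_supported_dtype(np.zeros((2, 2), dtype=np.float32))
```

Calling such a check at the Python entry point of each transpose would turn the silent failure into an explicit error.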

FFTs

  • Only complex FFTs are implemented as a single CUDA-level operation
  • Only FFT operations that preserve the size of local slices are supported

Easier way to choose CUDA versions for end users

Currently, the CUDA version is set to 12.2 in the CMake configuration.

jaxDecomp (and cuDecomp) can be compiled with 11.8 (there is no CUDA 12 specific code).

JAX 0.4.26 and above no longer supports CUDA 11, but some machines do not have the latest drivers, so some users have to use JAX 0.4.25.

I propose to allow users to choose which version to compile jaxDecomp with, like so:

By default 12.2

pip install jaxdecomp 

or

pip install jaxdecomp[cuda11]
pip install jaxdecomp[cuda12]

But obviously we don't download the NVIDIA wheels; we still expect the user to have the modules loaded.
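
The mapping behind this proposal can be sketched as below. The extra names and versions come from this issue, while `resolve_cuda_version` is a hypothetical helper. How the build would learn which extra was requested is left open here:

```python
# Versions mentioned in this issue; the extra names mirror the proposal.
CUDA_VERSIONS = {"cuda11": "11.8", "cuda12": "12.2"}
DEFAULT_EXTRA = "cuda12"  # plain `pip install jaxdecomp` keeps 12.2


def resolve_cuda_version(extra=None):
    """Map an optional extra name to the CUDA version to build against."""
    name = DEFAULT_EXTRA if extra is None else extra
    try:
        return CUDA_VERSIONS[name]
    except KeyError:
        expected = ", ".join(sorted(CUDA_VERSIONS))
        raise ValueError(
            f"unknown CUDA extra {name!r}; expected one of: {expected}"
        ) from None


default_version = resolve_cuda_version()
legacy_version = resolve_cuda_version("cuda11")
```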

[Installation error] Command execution error when running setup.py

Error message:

CMake Error at CMakeLists.txt:5 (find_package):
By not providing "FindNVHPC.cmake" in CMAKE_MODULE_PATH this project has
asked CMake to find a package configuration file provided by "NVHPC", but
CMake did not find one.

Could not find a package configuration file provided by "NVHPC" with any of
the following names:

NVHPCConfig.cmake
nvhpc-config.cmake

Add the installation prefix of "NVHPC" to CMAKE_PREFIX_PATH or set
"NVHPC_DIR" to a directory containing one of the above files. If "NVHPC"
provides a separate development package or SDK, be sure it has been
installed.

-- Configuring incomplete, errors occurred!
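
The error asks for a directory containing one of the two config files. Here is a minimal sketch for locating such a directory, assuming the NVIDIA HPC SDK is installed somewhere under a known root. The helper `find_nvhpc_config_dirs` is hypothetical, and both the `NVHPC_ROOT` variable and the `/opt/nvidia/hpc_sdk` fallback are assumptions about the local installation:

```python
import os

# File names taken from the CMake error message.
CONFIG_NAMES = ("NVHPCConfig.cmake", "nvhpc-config.cmake")


def find_nvhpc_config_dirs(root):
    """Return the sorted directories under root that hold an NVHPC config file."""
    hits = []
    for dirpath, _dirnames, filenames in os.walk(root):
        if any(name in filenames for name in CONFIG_NAMES):
            hits.append(dirpath)
    return sorted(hits)


if __name__ == "__main__":
    # Assumed search root; adjust it to where the SDK lives on your machine.
    search_root = os.environ.get("NVHPC_ROOT", "/opt/nvidia/hpc_sdk")
    for path in find_nvhpc_config_dirs(search_root):
        print(path)
```

Any directory it prints is a candidate value for `NVHPC_DIR`, as the error message suggests.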

FFTs are not working properly

Comparing the 3D FFT computed by jaxdecomp with one computed manually in JAX, I realized that the result of fft3d does not match the non-distributed version.
This could be due to a transposition of the pfft3d result, which is a fairly common convention that saves 2 all-to-all communications in a forward-backward step. However, depending on the partitioning scheme, the result comes back with its axes in a different order.

I have modified the FFT test to actually detect this problem in the fix_fft branch in #12

@ASKabalan can you take a look?

If we don't provide any other information to the user regarding the order of dimensions in the FFT, the user expects the following to be true:

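import jax
import numpy as np
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P
from numpy.testing import assert_allclose

import jaxdecomp

# Assumes jax.distributed is already initialized with one GPU per process
key = jax.random.PRNGKey(0)

pdims = (2, 2)
mesh_shape = (4, 4, 4)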

devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('z', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('z', 'y'))

local_mesh_shape = [mesh_shape[0]//pdims[0], mesh_shape[1]//pdims[1], mesh_shape[2]]

z = jax.make_array_from_single_device_arrays(shape=mesh_shape,
                                             sharding=sharding,
                                             arrays=[jax.random.normal(key, local_mesh_shape)])

with mesh:
    kfield_dist = jaxdecomp.fft.pfft3d(z)

kfield_dist = multihost_utils.process_allgather(kfield_dist, tiled=True)

kfield = np.fft.fftn(multihost_utils.process_allgather(z, tiled=True))

# This should be true to within numerical accuracy
assert_allclose(kfield_dist, kfield)
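
The axis-order mismatch described in this issue can be illustrated on a single device with NumPy alone. The sketch below emulates a distributed FFT that returns its output with permuted axes, then searches for the permutation that restores the reference layout. The permutation `(1, 2, 0)` is an arbitrary assumption for illustration, not the layout jaxDecomp actually returns, and `matching_axis_order` is a hypothetical helper:

```python
from itertools import permutations

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4, 4)).astype(np.float32)

# Reference: the non-distributed 3D FFT.
reference = np.fft.fftn(x)

# Emulated "distributed" output: same values, axes permuted (assumed order).
transposed_out = np.transpose(reference, (1, 2, 0))


def matching_axis_order(candidate, reference):
    """Return the axis permutation mapping candidate onto reference, or None."""
    for perm in permutations(range(reference.ndim)):
        restored = np.transpose(candidate, perm)
        if restored.shape == reference.shape and np.allclose(restored, reference):
            return perm
    return None


perm = matching_axis_order(transposed_out, reference)
# The direct comparison fails, while the re-ordered one succeeds.
direct_match = np.allclose(transposed_out, reference)
restored_match = np.allclose(np.transpose(transposed_out, perm), reference)
```

A check like this in the test suite would show whether a mismatch is only an axis permutation or an actual numerical error.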

Manual assignment of GPU per rank

I noticed the following potential issue here:

CHECK_CUDA_EXIT(cudaSetDevice(local_rank));

This should not be necessary: the GPU we use is determined by the stream that JAX provides us with. The device selected here could also differ from the one JAX chose.

This might also be why the GPU binding is not working on Jean Zay.
