jaxDecomp: JAX Library for 3D Domain Decomposition and Parallel FFTs

JAX bindings for NVIDIA's cuDecomp library (Romero et al. 2022), allowing for efficient multi-node parallel FFTs and halo exchanges performed directly via low-level NCCL/CUDA-aware MPI from your JAX code 🎉

Usage

Here is an example of how to use jaxDecomp to perform a 3D FFT on a 3D array distributed across multiple GPUs. This example also includes a halo exchange, a common operation in many scientific computing applications.

import jax
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P
import jaxdecomp

# Initialise the library, and optionally select a communication backend (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

# Initialize jax distributed to instruct each local jax process which GPU to use
jax.distributed.initialize()
rank = jax.process_index()

# Setup a processor mesh (its product must match the total number of processes)
pdims = (2, 4)
global_shape = [1024, 1024, 1024]

# Initialize the local slice of the global array
local_array = jax.random.normal(
    shape=[
        global_shape[0] // pdims[1], global_shape[1] // pdims[0],
        global_shape[2]
    ],
    key=jax.random.PRNGKey(rank))

# Remap the local slice into a globally distributed array
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
    local_array, mesh, P('z', 'y'))

# A simple jitted operation to apply to the array in Fourier space
@jax.jit
def modify_array(array):
    return 2 * array + 1

with mesh:
    # Forward FFT (note that the output is transposed)
    karray = jaxdecomp.fft.pfft3d(global_array)
    # Do some operation on your array
    karray = modify_array(karray)
    # Reverse FFT
    recarray = jaxdecomp.fft.pifft3d(karray).astype('float32')
    # Add halo regions to our array
    padding_width = ((32, 32), (32, 32), (32, 32))  # Has to be a tuple of tuples
    padded_array = jaxdecomp.slice_pad(recarray, padding_width, pdims)
    # Perform a halo exchange + reduce
    exchanged_reduced = jaxdecomp.halo_exchange(padded_array,
                                           halo_extents=(32,32,32),
                                           halo_periods=(True,True,True),
                                           reduce_halo=True)
    # Remove the halo regions
    recarray = jaxdecomp.slice_unpad(exchanged_reduced, padding_width, pdims)

    # Gather the results (only if it fits on CPU memory)
    gathered_array = multihost_utils.process_allgather(recarray, tiled=True)

# Finalize the library
jaxdecomp.finalize()
jax.distributed.shutdown()

Note: All these functions are jittable and have well-defined derivatives!
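
As a minimal sketch of what that enables (reusing pfft3d/pifft3d and the mesh/global_array from the example above; the Fourier-space filter and the loss are just illustrations, not part of the library):

import jax
import jax.numpy as jnp
import jaxdecomp

@jax.jit
def smoothed_power(field):
    # Forward FFT, apply an arbitrary filter in Fourier space, inverse FFT,
    # and reduce to a real scalar so we can differentiate through it.
    kfield = jaxdecomp.fft.pfft3d(field)
    kfield = 0.5 * kfield
    rec = jaxdecomp.fft.pifft3d(kfield)
    return jnp.mean(jnp.abs(rec) ** 2)

# Gradients flow through the distributed FFTs and transposes.
grad_fn = jax.jit(jax.grad(smoothed_power))

with mesh:
    loss = smoothed_power(global_array)
    grads = grad_fn(global_array)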

This script has to be run on 8 GPUs in total (the product of pdims), with something like

$ mpirun -n 8 python demo.py

On an HPC cluster like Jean Zay you should instead do

$ srun python demo.py

Check the slurm README and template for more information on how to run on Jean Zay.

Caveats

The code presented above should work, but there are a few caveats, mentioned in this issue. If you need functionality that is not currently implemented, feel free to mention it on that issue.

Install

Start by cloning this repository locally on your cluster:

$ git clone --recurse-submodules https://github.com/DifferentiableUniverseInitiative/jaxDecomp

Requirements

This install procedure assumes that the NVIDIA HPC SDK is available in your environment. You can install it from the NVIDIA website, or better yet, it may already be available as a module on your cluster.

Make sure all environment variables related to the SDK are properly set.

Building jaxDecomp

From this directory, install & build jaxDecomp via pip

$ pip install --user .

If CMake complains about not finding the NVHPC SDK, you can manually specify the location of the SDK's CMake files like so:

$ export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVCOMPILERS/$NVARCH/22.9/cmake
$ pip install --user .

Install Notes for Specific Machines

IDRIS Jean Zay HPE SGI 8600 supercomputer

As of April 2024, the following works.

You need to load the modules in exactly this order:

# Load NVHPC 23.9 because it has cuda 12.2
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda  openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
# Installing mpi4py
CFLAGS=-noswitcherror pip install mpi4py
# Installing jax
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake
pip install .

NERSC Perlmutter HPE Cray EX supercomputer

As of Nov. 2022, the following works:

module load PrgEnv-nvhpc python
export CRAY_ACCEL_TARGET=nvidia80
# Installing mpi4py
MPICC="cc -target-accel=nvidia80 -shared" CC=nvc CFLAGS="-noswitcherror" pip install --force --no-cache-dir --no-binary=mpi4py mpi4py
# Installing jax
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install .

Design

Here is what works now:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

import jaxdecomp
from jaxdecomp.fft import pfft3, ipfft3

# Initialize jax distributed to instruct each local jax process which GPU to use
jax.distributed.initialize()

pdims = (2, 4)
global_shape = (512, 512, 512)

local_array = jax.random.normal(shape=[global_shape[0] // pdims[0],
                                       global_shape[1] // pdims[1],
                                       global_shape[2]], key=jax.random.PRNGKey(0))

# Remap to a global array (this is a free call, no communications are happening)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
    local_array, mesh, P('z', 'y'))

with mesh:
    z = pfft3(global_array)

    # If we could inspect the distribution of z, we would see that it is
    # sliced in 2 along x and 4 along y

    # This could also be part of a jitted function, no problem
    z_rec = ipfft3(z)

# z and z_rec remain distributed at all times.

jaxdecomp.finalize()
jax.distributed.shutdown()

Backend configuration

We can set the default communication backend for cuDecomp operations either through a config module or through environment variables. This allows users to choose the communication backend at startup (although it can still be changed afterwards), making it possible to use CUDA-aware MPI or NVSHMEM as preferred.

Here is what it would look like:

jaxdecomp.config.update('transpose_comm_backend', 'NCCL')
# We could for instance time how long it takes to execute in this mode
%timeit pfft3(y)

# And then update the backend
jaxdecomp.config.update('transpose_comm_backend', 'MPI')
# And measure again
%timeit pfft3(y)
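
Outside of an IPython session (where %timeit is not available), a rough timing sketch could look like the following; it mirrors the snippet above, with y being the distributed array, and uses block_until_ready because JAX dispatches work asynchronously:

import time

import jaxdecomp
from jaxdecomp.fft import pfft3

def time_pfft3(x, n_iter=10):
    # Warm-up call so compilation time is not included in the measurement.
    pfft3(x).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_iter):
        pfft3(x).block_until_ready()
    return (time.perf_counter() - start) / n_iter

jaxdecomp.config.update('transpose_comm_backend', 'NCCL')
nccl_time = time_pfft3(y)

jaxdecomp.config.update('transpose_comm_backend', 'MPI')
mpi_time = time_pfft3(y)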

Autotune computational mesh

We can also make things fancier: since cuDecomp is able to autotune, we could use it to tell us the best way to partition the data given the available GPUs, with something like this:

automesh = jaxdecomp.autotune(shape=[512,512,512])
# This is a JAX Sharding spec object, optimized for the given GPUs
# and shape of the tensor
sharding = PositionalSharding(automesh)
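
Continuing that (still hypothetical) sketch, the resulting sharding could then be used with the standard JAX placement API:

import jax

# Lay out an existing array according to the autotuned sharding
# ('sharding' comes from the sketch above; 'field' is any array of the given shape).
field = jax.device_put(field, sharding)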

jaxdecomp's Issues

Manual assignment of GPU per rank

I noticed the following potential issue here:

CHECK_CUDA_EXIT(cudaSetDevice(local_rank));

This should not be necessary: the GPU we use is determined by the stream JAX provides us with. It also has the potential of differing from the device JAX decides to use for this rank.

This might also be why the GPU binding is not working on Jean Zay.
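
For illustration only (standard JAX API, not part of jaxDecomp): once jax.distributed.initialize() has been called, JAX has already chosen the local device(s) for each process, so the extension should rely on the device and stream attached to the buffers it receives rather than on a local_rank guess:

import jax

jax.distributed.initialize()

# JAX already decided which GPU(s) this process owns; a hard-coded
# cudaSetDevice(local_rank) in the extension can disagree with this mapping.
print(f"process {jax.process_index()} -> local devices: {jax.local_devices()}")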

Currently unsupported features and caveats (aka the big TODO list)

Because implementing all features is not necessarily worthwhile unless there is a need for them, here are the current restrictions of the code. All of these restrictions can be lifted if you need them; don't hesitate to comment on this issue if there is something you would like to be able to do!

General

  • Configuration mechanism to choose the communication API (by default CUDA-aware MPI) only works at initialization
  • No interface to run autotuning to figure out the best communication strategy
  • No interface with the JAX 0.4 Array API

Transpose operations

  • Only transpose operations that preserve the size of local slices are supported
  • Transpose ops do not have batching or gradient operations implemented
  • No support for letting XLA allocate the workspace size needed for the transpose
  • Double precision is silently not supported (it returns incorrect values!); see the sketch below
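
A small defensive sketch for the double-precision item (standard JAX only, not part of the jaxDecomp API): keep x64 disabled and downcast any 64-bit inputs before calling the transpose ops:

import jax
import jax.numpy as jnp

# Keep JAX in its default single-precision mode.
jax.config.update('jax_enable_x64', False)

def to_single_precision(x):
    # Downcast 64-bit inputs (e.g. NumPy's default float64) so the transpose ops
    # never silently receive double precision.
    if x.dtype == jnp.float64:
        return x.astype(jnp.float32)
    if x.dtype == jnp.complex128:
        return x.astype(jnp.complex64)
    return x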

FFTs

  • Only complex FFTs are implemented as a single CUDA-level operation (see the sketch after this list)
  • Only FFT operations that preserve the size of local slices are supported
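
Since only complex FFTs are exposed, a real field can still be transformed by promoting it to a complex dtype first; a minimal sketch assuming the pfft3d API from the README example:

import jax.numpy as jnp
import jaxdecomp

def pfft3d_real(field):
    # Promote the real field to complex64, then run the distributed complex FFT.
    return jaxdecomp.fft.pfft3d(field.astype(jnp.complex64))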

FFTs are not working properly

Comparing the 3D FFT computed by jaxdecomp against a manual computation in jax, I realized that the result of pfft3d does not match the non-distributed version.
This could be due to a transposition of the pfft3d result, which is a fairly conventional choice (it saves two all-to-all communications in a forward-backward step), but depending on the partitioning scheme I get results in different orders.

I have modified the FFT test to actually detect this problem in the fix_fft branch in #12

@ASKabalan can you take a look?

If we don't provide any other information to the user regarding the order of dimensions in the FFT, the user expects the following to be true:

import jax
import numpy as np
from numpy.testing import assert_allclose
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

import jaxdecomp

pdims = (2, 2)
mesh_shape = (4, 4, 4)

devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('z', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('z', 'y'))

local_mesh_shape = [mesh_shape[0]//pdims[0], mesh_shape[1]//pdims[1], mesh_shape[2]]

# One local slice per process
key = jax.random.PRNGKey(jax.process_index())
z = jax.make_array_from_single_device_arrays(shape=mesh_shape,
                                             sharding=sharding,
                                             arrays=[jax.random.normal(key, local_mesh_shape)])

with mesh:
    kfield_dist = jaxdecomp.fft.pfft3d(z)

kfield_dist = multihost_utils.process_allgather(kfield_dist, tiled=True)

kfield = np.fft.fftn(multihost_utils.process_allgather(z, tiled=True))

# This should be true to within numerical accuracy
assert_allclose(kfield_dist, kfield)
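
One way to diagnose the ordering issue is to check which axis permutation (if any) makes the distributed and reference results agree; a sketch reusing kfield_dist and kfield from the snippet above:

import itertools
import numpy as np

for perm in itertools.permutations(range(3)):
    # Compare the distributed result against every transposed version of the reference.
    if np.allclose(kfield_dist, np.transpose(kfield, perm), rtol=1e-4, atol=1e-4):
        print('distributed result matches the reference transposed by', perm)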

Things to update for version 0.0.1

  • slice_unpad does not seem to be too happy with getting complex numbers as inputs (see the workaround sketch after this list):
  File "/mnt/home/flanusse/repo/jaxDecomp/scripts/demo.py", line 58, in <module>
    recarray = slice_unpad(exchanged_reduced, padding_width, pdims)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: custom_partitioner: TypeError: pad operand and padding_value must be same dtype: got complex64 and float32.
  • transpose operations are not yet updated to new API
  • Can we remove jaxdecomp.finalize() entirely?
  • #11
  • #13
  • Add a mechanism to run all the tests at once (currently, they can only be run one by one)
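
For the first item, a possible workaround (a sketch mirroring the .astype('float32') cast already used in the README demo; karray, padding_width and pdims as defined there) is to bring the inverse FFT output back to a real dtype before the pad/unpad round trip, so it matches the real-valued padding value:

import jaxdecomp

# pifft3d returns complex64; keep only the real part before padding/unpadding.
recarray = jaxdecomp.fft.pifft3d(karray).real.astype('float32')
padded_array = jaxdecomp.slice_pad(recarray, padding_width, pdims)
exchanged = jaxdecomp.halo_exchange(padded_array,
                                    halo_extents=(32, 32, 32),
                                    halo_periods=(True, True, True),
                                    reduce_halo=True)
recarray = jaxdecomp.slice_unpad(exchanged, padding_width, pdims)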

Easier way to choose CUDA versions for end users

Currently, the CUDA version is set to 12.2 in the CMake configuration.

jaxDecomp (and cuDecomp) can also be compiled with CUDA 11.8 (there is no CUDA 12-specific code).

JAX 0.4.26 and above no longer supports CUDA 11, but some machines do not have the latest drivers, so some users have to use JAX 0.4.25.

I propose to allow users to choose which CUDA version to compile jaxDecomp with, like so.

By default (CUDA 12.2):

pip install jaxdecomp 

or

pip install jaxdecomp[cuda11]
pip install jaxdecomp[cuda12]

But obviously we don't download the NVIDIA wheels; we still expect the user to have the relevant modules loaded.

[Installation error] Running setup.py fails with a command execution error

error message:

CMake Error at CMakeLists.txt:5 (find_package):
By not providing "FindNVHPC.cmake" in CMAKE_MODULE_PATH this project has
asked CMake to find a package configuration file provided by "NVHPC", but
CMake did not find one.

Could not find a package configuration file provided by "NVHPC" with any of
the following names:

  NVHPCConfig.cmake
  nvhpc-config.cmake

Add the installation prefix of "NVHPC" to CMAKE_PREFIX_PATH or set
"NVHPC_DIR" to a directory containing one of the above files. If "NVHPC"
provides a separate development package or SDK, be sure it has been
installed.

-- Configuring incomplete, errors occurred!
