Is this a new feature, an improvement, or a change to existing functionality?
New Feature
How would you describe the priority of this feature request
Medium
Please provide a clear description of the problem you would like to solve.
I use Modulus `DistributedManager` with SLURM. Right now, `DistributedManager` sets the `local_rank` from the SLURM-assigned local process ID on the node (this line):

```python
local_rank = int(os.environ.get("SLURM_LOCALID"))
```
This line then sets the device based on the `local_rank`:

```python
manager._device = torch.device(
    f"cuda:{manager.local_rank}" if torch.cuda.is_available() else "cpu"
)
```
Notably, this line breaks if `SLURM_LOCALID` is greater than or equal to `torch.cuda.device_count()`.
In my use case, however, I need to use the SBATCH `--gpu-bind=map_gpu:0,1,2,3` flag on a node with 4 GPUs. With 4 processes per node and 4 GPUs per node, each process sees only one device, named `cuda:0`, even though that name refers to a different physical GPU in each process. (This forum post explains why I need to use this flag.)
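To illustrate what the binding does, here is a pure-Python sketch (no GPUs required). The environment variable names are the real SLURM/CUDA ones; the `visible_device` helper and the GPU map are hypothetical, just mimicking how each bound task ends up seeing a single device named `cuda:0`:

```python
import os

def visible_device(slurm_localid: int, gpu_map=(0, 1, 2, 3)):
    """Sketch of --gpu-bind=map_gpu:0,1,2,3 on a 4-GPU node: each local
    task is pinned to one physical GPU, so inside every task
    CUDA_VISIBLE_DEVICES holds a single entry and torch would report
    exactly one device, always named "cuda:0"."""
    physical_gpu = gpu_map[slurm_localid]            # GPU this task is bound to
    os.environ["CUDA_VISIBLE_DEVICES"] = str(physical_gpu)
    device_count = 1                                 # what torch.cuda.device_count() would report
    logical_name = "cuda:0"                          # the same logical name in every task
    return logical_name, physical_gpu, device_count

for localid in range(4):
    print(localid, *visible_device(localid))
```

So all four tasks address `cuda:0`, but the name resolves to four different physical GPUs.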
There may be other use cases where the number of local processes launched through SLURM does not equal the number of accessible GPUs (e.g. running FourCastNet with 4 GPUs and 1 process per GPU, but analyzing the output with more processes).
My request would be to add a flag to `DistributedManager` through which I could specify that the behavior below is desired for SLURM as well:

```python
manager._local_rank = rank % torch.cuda.device_count()
```

This ensures that `torch.device` is never called with a device index that can't be accessed.
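The modulo mapping can be sanity-checked without any GPUs. A minimal sketch (`propose_local_rank` is a hypothetical name for illustration, not Modulus API):

```python
def propose_local_rank(rank: int, device_count: int) -> int:
    """Map a global rank onto a valid device index, wrapping around
    when there are more processes than visible devices."""
    return rank % device_count

# With --gpu-bind, every process sees exactly 1 device -> always index 0.
print([propose_local_rank(r, 1) for r in range(4)])   # expect [0, 0, 0, 0]

# 8 analysis processes over 4 visible GPUs -> indices wrap around 0..3.
print([propose_local_rank(r, 4) for r in range(8)])   # expect [0, 1, 2, 3, 0, 1, 2, 3]
```

In both cases every process ends up with a device index that actually exists, which is exactly what the current SLURM path does not guarantee.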
Describe any alternatives you have considered
Without such a flag, `DistributedManager.initialize()` raises an error because `torch.device` is used to access a device that is not available. I could make an equivalent for `DistributedManager`, or I could create a subclass of `DistributedManager` that overrides the `initialize_slurm` method. Let me know if that would be the preferred solution, and I can continue with my fix on my local end.
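For reference, the subclass workaround would look roughly like this. This is a sketch against a stand-in base class, since I can't paste the real Modulus internals here; the method and attribute names mirror the ones above, but everything else is an assumption:

```python
import os

class DistributedManager:
    """Stand-in for modulus' DistributedManager, reduced to the one
    method relevant here."""
    _local_rank = 0

    @classmethod
    def initialize_slurm(cls, port):
        # The real implementation reads SLURM_LOCALID (among other env
        # vars) and later builds torch.device(f"cuda:{local_rank}").
        cls._local_rank = int(os.environ.get("SLURM_LOCALID", "0"))

class WrappedSlurmManager(DistributedManager):
    @classmethod
    def initialize_slurm(cls, port, device_count=1):
        super().initialize_slurm(port)
        # Override: wrap the SLURM-assigned local rank onto the devices
        # that are actually visible, so the device index always exists.
        cls._local_rank = cls._local_rank % device_count
```

With `--gpu-bind` and one visible device per task, `SLURM_LOCALID=3` would then map to local rank `0` instead of the invalid `cuda:3`.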