Tutorial 6: Multiple GPUs

This tutorial describes how to run Jaxion simulations on distributed GPUs. An example is provided in examples/soliton_binary_merger.

First, initialize JAX for distributed GPU use, before running any other JAX computation, by calling:

import jax

jax.distributed.initialize()

On some machines, you may need to pass additional arguments to the initialize() function. On most SLURM clusters you will not need to specify anything, because JAX reads the required settings from the SLURM environment.
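
For machines where the settings are not picked up automatically, the arguments can be passed explicitly. The sketch below is only an illustration: it assumes the launcher exports three environment variables, and the names COORDINATOR_ADDRESS, NUM_PROCESSES, and PROCESS_ID, as well as the helper distributed_init_kwargs, are hypothetical. The keys it returns (coordinator_address, num_processes, process_id) are keyword arguments of jax.distributed.initialize().

```python
import os


def distributed_init_kwargs(env=None):
    """Collect explicit arguments for jax.distributed.initialize().

    The environment variable names are hypothetical; substitute whatever
    your launcher actually provides.
    """
    env = os.environ if env is None else env
    return {
        # "host:port" of the process that acts as coordinator (process 0)
        "coordinator_address": env["COORDINATOR_ADDRESS"],
        # total number of processes taking part in the run
        "num_processes": int(env["NUM_PROCESSES"]),
        # rank of this process, from 0 to num_processes - 1
        "process_id": int(env["PROCESS_ID"]),
    }
```

A script could then call jax.distributed.initialize(**distributed_init_kwargs()) in place of the bare initialize() call.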

If you’d like your Python script to print info, the print statements should be guarded by:

if jax.process_index() == 0:
    print("Using distributed GPU mode")

to prevent multiple processes from printing simultaneously.
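
If a script prints in many places, the guard can be wrapped in a small helper. This is a minimal sketch, not part of Jaxion: the name print_once and its process_index argument are hypothetical. The argument exists only so the helper can be tried without a distributed run; by default it asks JAX for the index of the current process.

```python
def print_once(*args, process_index=None, **kwargs):
    """Print only on process 0; stay silent on every other process."""
    if process_index is None:
        # Imported here so the helper can be exercised without JAX
        # when process_index is passed explicitly.
        import jax

        process_index = jax.process_index()
    if process_index == 0:
        print(*args, **kwargs)
```

Calling print_once("Using distributed GPU mode") then behaves like the guarded print above.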

Sharding

We need to set up sharding for the simulation state arrays. Sharding splits the arrays across multiple devices for distributed computation. This can be done as follows:

from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# Create mesh and sharding for distributed computation
n_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((1, n_devices))
mesh = Mesh(devices, axis_names=("x", "y"))
sharding = NamedSharding(mesh, PartitionSpec("x", "y"))

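In the example above, we create a virtual device mesh with one row and n_devices columns. PartitionSpec("x", "y") maps the first array dimension to the mesh axis “x” (size 1, so it is not split) and the second array dimension to the mesh axis “y”. Arrays (2D and 3D) are therefore split along their second dimension across the devices; the third dimension of a 3D array is left unsplit.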
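
To see what this layout means in practice, the sketch below asks the sharding for the per-device shard shape of a 2D and a 3D array. It repeats the mesh setup so that it runs on its own (including on a single device). The grid size n is a hypothetical value chosen to divide evenly among the devices.

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Same mesh and sharding as in the tutorial
n_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((1, n_devices))
mesh = Mesh(devices, axis_names=("x", "y"))
sharding = NamedSharding(mesh, PartitionSpec("x", "y"))

# Hypothetical grid size, a multiple of the device count
n = 16 * n_devices
shape_2d = (n, n)
shape_3d = (n, n, n)

# shard_shape() reports the shape of the piece held by each device
shard_2d = sharding.shard_shape(shape_2d)
shard_3d = sharding.shard_shape(shape_3d)

if jax.process_index() == 0:
    print("2D:", shape_2d, "->", shard_2d)
    print("3D:", shape_3d, "->", shard_3d)
```

Only the second dimension shrinks by a factor of n_devices; the first and third dimensions keep their full length on every device.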

When creating the simulation object, we need to pass the sharding to it:

sim = jaxion.Simulation(params, sharding=sharding)

And that’s it! Jaxion will now run the simulation on multiple GPUs.

If you grab arrays from the simulation, such as the grid (sim.grid) or spectral grid (sim.kgrid), these will be sharded arrays too.
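
Any JAX array can be asked how it is laid out, which is a quick way to confirm that sharding is in effect. The helper below is a sketch with a hypothetical name (describe_sharding). It is demonstrated on a plain array rather than on sim.grid, since the exact type of that attribute is not shown here; the assumption is that it would be applied to a single jax.Array.

```python
import jax.numpy as jnp


def describe_sharding(arr):
    """Summarize how a jax.Array is laid out across devices."""
    return {
        # shape of the full, logical array
        "global_shape": arr.shape,
        # shapes of the pieces stored on devices owned by this process
        "local_shard_shapes": [s.data.shape for s in arr.addressable_shards],
        # True if this process can see every shard of the array
        "fully_addressable": arr.is_fully_addressable,
        # the sharding object attached to the array
        "sharding": arr.sharding,
    }


# Stand-in array; in a real run this would be an array taken from the simulation
demo = jnp.zeros((4, 4))
info = describe_sharding(demo)
```

In a multi-process run, fully_addressable is False for arrays sharded across all GPUs, because each process holds only its own shards.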

SLURM Example

An example SLURM submission script for running our example on the Flatiron Rusty cluster is provided below:

#!/usr/bin/bash
#SBATCH --job-name=soliton_binary_merger
#SBATCH --output=slurm-%j.out
#SBATCH --error=slurm-%j.err
#SBATCH --partition=gpu
#SBATCH --constraint=h100
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=80G
#SBATCH --time=00-01:00

module purge
module load python/3.12.9

export PYTHONUNBUFFERED=TRUE

source $VENVDIR/jaxion-venv/bin/activate

srun --gpu-bind=none --cpu-bind=cores python soliton_binary_merger.py --res=16 --distributed
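
The srun line passes two flags to the script, --res and --distributed. As a rough illustration of how a script might handle them, here is a sketch using argparse. It is hypothetical and not taken from soliton_binary_merger.py, which may be organized differently: the default value, the help strings, and the function names parse_args and main are all assumptions.

```python
import argparse


def parse_args(argv=None):
    """Parse the two flags used on the srun line (hypothetical handling)."""
    parser = argparse.ArgumentParser(description="Launcher sketch")
    # The meaning of --res is defined by the example script; here it is
    # simply read as an integer.
    parser.add_argument("--res", type=int, default=1, help="resolution setting")
    parser.add_argument(
        "--distributed", action="store_true", help="enable multi-GPU mode"
    )
    return parser.parse_args(argv)


def main(argv=None):
    args = parse_args(argv)
    if args.distributed:
        # Only set up the distributed runtime when requested, so the same
        # script also runs on a single GPU without srun.
        import jax

        jax.distributed.initialize()
    return args
```

Making distributed mode opt-in keeps one script usable both interactively and under SLURM.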