Tutorial 6: Multiple GPUs
This tutorial describes how to run Jaxion simulations on distributed GPUs. An example is provided in examples/soliton_binary_merger.
First, ensure that JAX is initialized for distributed GPU use, before running any other JAX computations, by calling:
import jax
jax.distributed.initialize()
On some machines, you may need to specify additional parameters to the initialize() function.
For most SLURM clusters, you will not need to specify anything.
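When automatic detection does not work, jax.distributed.initialize() accepts explicit arguments such as coordinator_address, num_processes, and process_id. The sketch below shows one possible way to supply them, assuming the values are provided through environment variables. The variable names COORDINATOR_ADDRESS, NUM_PROCESSES, and PROCESS_ID and both helper functions are hypothetical; they are not defined by JAX or Jaxion.

```python
import os


def build_init_kwargs(env):
    """Collect explicit jax.distributed.initialize() arguments from a mapping.

    The keys read here (COORDINATOR_ADDRESS, NUM_PROCESSES, PROCESS_ID) are
    hypothetical names chosen for this sketch, not a JAX convention.
    """
    kwargs = {}
    if "COORDINATOR_ADDRESS" in env:
        # "host:port" of the coordinating process (process 0)
        kwargs["coordinator_address"] = env["COORDINATOR_ADDRESS"]
    if "NUM_PROCESSES" in env:
        # Total number of processes taking part in the run
        kwargs["num_processes"] = int(env["NUM_PROCESSES"])
    if "PROCESS_ID" in env:
        # Index of this process, from 0 to num_processes - 1
        kwargs["process_id"] = int(env["PROCESS_ID"])
    return kwargs


def initialize_distributed(env=os.environ):
    """Initialize JAX; with no variables set this is a plain initialize()."""
    import jax  # imported lazily so build_init_kwargs works without JAX

    jax.distributed.initialize(**build_init_kwargs(env))
```

If none of the variables are set, build_init_kwargs returns an empty dictionary and the call reduces to the argument-free form shown above.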
If you’d like your Python script to print info, it should be guarded by:
if jax.process_index() == 0:
    print("Using distributed GPU mode")
to prevent multiple processes from printing simultaneously.
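If a script prints in many places, the guard can be wrapped in a small helper. The following is a minimal sketch; the function name print_on_root and its process_index keyword are hypothetical and are not part of JAX or Jaxion. It relies only on jax.process_index(), as in the snippet above.

```python
def print_on_root(*args, process_index=None, **kwargs):
    """Print only on process 0; return True if the message was printed.

    process_index can be passed explicitly (for example in tests);
    otherwise it is looked up with jax.process_index().
    """
    if process_index is None:
        import jax  # imported lazily so the helper can be tested without JAX

        process_index = jax.process_index()
    if process_index != 0:
        return False
    print(*args, **kwargs)
    return True
```

A call such as print_on_root("Using distributed GPU mode") then replaces the explicit if statement.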
Slurm Example
An example SLURM submission script for running our example on the Flatiron Rusty cluster is provided below:
#!/usr/bin/bash
#SBATCH --job-name=soliton_binary_merger
#SBATCH --output=slurm-%j.out
#SBATCH --error=slurm-%j.err
#SBATCH --partition=gpu
#SBATCH --constraint=h100
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=80G
#SBATCH --time=00-01:00
module purge
module load python/3.12.9
export PYTHONUNBUFFERED=TRUE
source "$VENVDIR/jaxion-venv/bin/activate"
srun --gpu-bind=none --cpu-bind=cores python soliton_binary_merger.py --res=16 --distributed
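The srun line passes --res=16 and --distributed to the script. As an illustration of how a script could consume such flags, here is a minimal sketch using argparse. It is not the actual contents of soliton_binary_merger.py: the function names, the default value of --res, and the help texts are assumptions made for this example.

```python
import argparse


def parse_args(argv=None):
    """Parse the two flags used on the srun line (hypothetical sketch)."""
    parser = argparse.ArgumentParser(description="Hypothetical driver script")
    parser.add_argument("--res", type=int, default=1,
                        help="resolution setting (default assumed here)")
    parser.add_argument("--distributed", action="store_true",
                        help="initialize JAX for multi-process GPU runs")
    return parser.parse_args(argv)


def main(argv=None):
    args = parse_args(argv)

    import jax  # imported lazily so parse_args can be used without JAX

    if args.distributed:
        # Must happen before any other JAX computation
        jax.distributed.initialize()

    if jax.process_index() == 0:
        print(f"res={args.res}, processes={jax.process_count()}, "
              f"devices={jax.device_count()}")
```

Calling main() at the end of the script would run it. Making distributed initialization opt-in lets the same script run unchanged on a single GPU or CPU when the flag is omitted.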