Quickstart

Contents

Quickstart

To quickly try out Jaxion, install the package via pip:

pip install jaxion

If pip needs to build jaxdecomp from source, install the build prerequisites listed on the Installation page first.

Then run a simple example simulation (examples/soliton_binary_merger/soliton_binary_merger.py):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import argparse
import os
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Collide two solitons with opposite phases

Philip Mocz (2025)

Usage:
  python soliton_binary_merger.py --res <resolution_multiplier>

  python soliton_binary_merger.py --distributed --emulate
"""


def set_up_simulation(resolution_multiplier, sharding):
    """Build a jaxion simulation of two solitons with opposite phases.

    Parameters
    ----------
    resolution_multiplier : int
        Stored in ``params["domain"]["resolution_multiplier"]``. It is also used
        to name the checkpoint directory ``./checkpoints<resolution_multiplier>/``.
    sharding : jax.sharding.NamedSharding or None
        Forwarded unchanged to ``jaxion.Simulation``. It is ``None`` when not
        running in distributed mode.

    Returns
    -------
    jaxion.Simulation
        The simulation with ``sim.state["psi"]`` set to the two-soliton initial
        condition. It has not been run yet.
    """
    # Parameters added/changed from default values
    params = {
        "domain": {
            "resolution_multiplier": resolution_multiplier,
        },
        "time": {
            "end": 1.0,
        },
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "save": True,
        },
    }

    # Initialize the simulation
    sim = jaxion.Simulation(params, sharding=sharding)

    # Set initial conditions (two solitons)
    m_22 = sim.params["quantum"]["m_22"]
    box_size = sim.params["domain"]["box_size"]
    r_soliton = 0.02 * box_size  # soliton core radius: 2% of the box size
    xx, yy, zz = sim.grid

    def rho_soliton(r, r_soliton, m_22):
        # Functional form of the Schive et al. (2014) soliton core profile.
        # NOTE(review): the 1.9e7 normalization presumably matches jaxion's
        # unit system. Confirm the units.
        return (
            1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (r / r_soliton) ** 2) ** 8
        )

    # Superpose two identical soliton densities centred at x = 0.35 and
    # x = 0.65 of the box, with y = z = box_size / 2.
    rho = 0.0
    for i in range(2):
        x_soliton = (0.5 + 0.3 * (i - 0.5)) * box_size
        y_soliton = 0.5 * box_size
        z_soliton = 0.5 * box_size
        r = jnp.sqrt(
            (xx - x_soliton) ** 2 + (yy - y_soliton) ** 2 + (zz - z_soliton) ** 2
        )
        rho += rho_soliton(r, r_soliton, m_22)

    # Give the two solitons opposite phases.
    # psi is positive for x > box_size / 2 and negative otherwise; the plane
    # x == box_size / 2 itself gets the negative sign.
    sign = jnp.where(xx > 0.5 * box_size, 1.0, -1.0)
    sim.state["psi"] = (jnp.sqrt(rho) + 0j) * sign

    return sim


def main():
    """Parse command-line options, set up optional sharding, and run the merger.

    Command-line options:
      --res N         integer resolution multiplier (default 1)
      --distributed   shard the simulation across all JAX devices
      --emulate       with --distributed, emulate several devices on one CPU
                      host instead of calling ``jax.distributed.initialize()``

    Returns:
      The ``jaxion.Simulation`` object after ``sim.run()`` has returned.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    parser.add_argument(
        "--distributed", action="store_true", help="run in distributed mode"
    )
    parser.add_argument(
        "--emulate", action="store_true", help="emulate distributed mode on CPU"
    )
    args = parser.parse_args()

    sharding = None
    if args.distributed:
        if args.emulate:
            # JAX reads these settings when it first initializes its backends.
            # They must therefore be applied before the first device query
            # below. This assumes nothing imported earlier has already
            # initialized JAX.
            flags = os.environ.get("XLA_FLAGS", "")
            if "--xla_force_host_platform_device_count" not in flags:
                # Create 8 virtual CPU devices to exercise sharding on a single
                # host. An explicit count already present in XLA_FLAGS is kept.
                flags += " --xla_force_host_platform_device_count=8"
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            os.environ["XLA_FLAGS"] = flags
            # Restrict JAX to the CPU backend so the virtual devices are used
            # even when a GPU-enabled jaxlib is installed.
            jax.config.update("jax_platforms", "cpu")
            if jax.process_index() == 0:
                print("Using emulated distributed CPU mode")
        else:
            jax.distributed.initialize()
            if jax.process_index() == 0:
                print("Using distributed GPU mode")
        # Build a 1 x n_devices mesh. Mesh axis "x" has size 1, so only the
        # second array dimension (mapped to "y") is split across devices.
        n_devices = jax.device_count()
        devices = mesh_utils.create_device_mesh((1, n_devices))
        mesh = Mesh(devices, axis_names=("x", "y"))
        sharding = NamedSharding(mesh, PartitionSpec("x", "y"))

    sim = set_up_simulation(args.res, sharding)
    sim.run()

    return sim


# Run the simulation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Run it with a resolution multiplier of 2 as follows:

python soliton_binary_merger.py --res=2

Running this should take under a minute and produce output (in checkpoints2/) that looks something like:

soliton binary merger

You can also run the example in the cloud: Colab notebook.

For more info

For info on how to install Jaxion with GPU support, see the Installation page.

For more examples of simulations, see the Examples page.