Examples#

The examples/ directory contains a collection of astrophysics simulations demonstrating various applications of the Jaxion library.

cosmological_box#

cosmological_box

See on GitHub: examples/cosmological_box#

README:

# Cosmological Box

Fuzzy Dark Matter cosmological box

Philip Mocz (2025)

Usage:

```console
python cosmological_box.py
```

Takes around 870 seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm000.png" alt="dm000" width="32%"/>
  <img src="dm060.png" alt="dm060" width="32%"/>
  <img src="dm070.png" alt="dm070" width="32%"/>
  <img src="dm080.png" alt="dm080" width="32%"/>
  <img src="dm090.png" alt="dm090" width="32%"/>
  <img src="dm100.png" alt="dm100" width="32%"/>
</div>


## References

[Mocz, P. et al.; Galaxy formation with BECDM - II. Cosmic filaments and first galaxies. MNRAS (2020)](https://ui.adsabs.harvard.edu/abs/2020MNRAS.494.2027M)

Script:

import jax.numpy as jnp
import jaxion
import h5py

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Fuzzy Dark Matter cosmological box

Philip Mocz (2025)

Usage:
  python cosmological_box.py
"""


def set_up_simulation():
    """Create the FDM cosmological box and load its precomputed initial conditions.

    Only settings that differ from jaxion's defaults are given: cosmology on,
    a 1000 h^-1 kpc box at base resolution 256, evolved from z=127 to z=7 with
    m_22 = 2.5 and Omega_m = 0.3, Omega_Lambda = 0.7.

    Returns:
        jaxion.Simulation whose ``state["psi"]`` is the wavefunction read from
        ``fdm_1mpc_256_m2.5e-22_z127_ic.hdf5`` (datasets ``psiRe``/``psiIm``),
        multiplied by 1e5.
    """
    params = {
        "physics": {"cosmology": True},
        "domain": {
            "box_size": 1000.0,  # in h^-1 kpc
            "resolution_base": 256,
        },
        "time": {
            "start": 127.0,  # start at z=127
            "end": 7.0,  # end at z=7
        },
        "output": {"path": "./checkpoints/"},
        "quantum": {"m_22": 2.5},  # m = 2.5e-22 eV
        "cosmology": {"omega_matter": 0.3, "omega_lambda": 0.7},
    }
    sim = jaxion.Simulation(params)

    ic_path = "fdm_1mpc_256_m2.5e-22_z127_ic.hdf5"
    with h5py.File(ic_path, "r") as ic_file:
        real_part = jnp.array(ic_file["psiRe"])
        imag_part = jnp.array(ic_file["psiIm"])

    # NOTE(review): the IC file was documented as "units of 1e10 Msun/kpc^2";
    # since |psi|^2 is a density this is presumably 1e10 Msun/kpc^3 -- TODO confirm.
    # Multiplying psi by 1e5 multiplies |psi|^2 by 1e10, removing that unit factor.
    sim.state["psi"] = (real_part + 1.0j * imag_part) * 1e5
    return sim


def main():
    """Build the cosmological box, run it to the final redshift, and return it."""
    simulation = set_up_simulation()
    simulation.run()
    return simulation


if __name__ == "__main__":
    main()

dynamical_friction#

dynamical_friction

See on GitHub: examples/dynamical_friction#

README:

# Dynamical Friction

Dynamical friction of Fuzzy Dark Matter soliton in an external potential

Philip Mocz (2025)

Usage:

```console
python dynamical_friction.py --res <resolution_multiplier>
```

Takes around 3 (res=1) seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm100.png" alt="dm100" width="45%"/>
</div>


## References

[Lancaster, L.; Giovanetti, C.; Mocz, P.; Kahn, Y.; Lisanti, M.; Spergel, D.N.; Dynamical friction in a Fuzzy Dark Matter universe. JCAP (2020)](https://ui.adsabs.harvard.edu/abs/2020JCAP...01..001L)

Script:

import jax.numpy as jnp
import argparse
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Dynamical friction of Fuzzy Dark Matter soliton in an external potential

Philip Mocz (2025)

Usage:
  python dynamical_friction.py --res <resolution_multiplier>
"""


def set_up_simulation(resolution_multiplier):
    """Uniform FDM background moving relative to a point-mass host potential.

    Args:
        resolution_multiplier: scales jaxion's base resolution and is also used
            in the checkpoint directory name.

    Returns:
        jaxion.Simulation with ``state["psi"]`` (uniform-density plane wave along
        x) and ``state["V_ext"]`` (softened point-mass potential at the box
        centre) initialised.
    """
    params = {
        "physics": {"external_potential": True},
        "domain": {"resolution_multiplier": resolution_multiplier},
        "time": {"end": 0.1},
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "plot_dynamic_range": 5.0,
        },
    }
    sim = jaxion.Simulation(params)

    G = jaxion.constants["gravitational_constant"]
    box_size = sim.params["domain"]["box_size"]
    xx, yy, zz = sim.grid

    # Background: constant density whose phase winds k_rel times across the
    # box in x, i.e. uniform bulk motion of the dark matter.
    density = 1.0e6  # dm density M_sun/kpc^3
    k_rel = 5.0  # wavenumber of relative motion
    sim.state["psi"] = jnp.sqrt(density) * jnp.exp(
        1.0j * (2.0 * jnp.pi / box_size) * k_rel * xx
    )

    # Host halo: point mass at the box centre.
    M_halo = 1e9  # M_sun
    print(f"M_halo: {M_halo:.2e} M_sun")

    center = 0.5 * box_size
    r = jnp.sqrt((xx - center) ** 2 + (yy - center) ** 2 + (zz - center) ** 2)
    # Half-cell softening keeps the potential finite at r = 0.
    sim.state["V_ext"] = -G * M_halo / (r + 0.5 * sim.dx)
    return sim


def main():
    """Parse ``--res``, run the dynamical-friction example, and print mean |psi|."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    options = cli.parse_args()

    simulation = set_up_simulation(options.res)
    simulation.run()
    print("mean |psi| =", jnp.mean(jnp.abs(simulation.state["psi"])))
    return simulation


if __name__ == "__main__":
    main()

heating_gas#

heating_gas

See on GitHub: examples/heating_gas#

README:

# Heating Gas

Heating Gas due to Fuzzy Dark Matter fluctuations

Philip Mocz (2025)

Usage:

```console
python heating_gas.py --res <resolution_multiplier>
```

Takes around 37 (res=1) seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm100.png" alt="dm100" width="45%"/>
  <img src="gas100.png" alt="gas100" width="45%"/>
</div>


## Analysis

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="power_spectrum.png" alt="power_spectrum" width="45%"/>
</div>

Script:

import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Heating Gas due to Fuzzy Dark Matter fluctuations

Philip Mocz (2025)

Usage:
  python heating_gas.py --res <resolution_multiplier>
"""


def _random_phase_wavefunction(k_sq, nx, sigma, m_per_hbar, mean_density):
    """Random-phase FDM wavefunction per Eq. (27) of arXiv:1801.03507.

    Each Fourier mode gets a uniformly random phase (assigned in order of
    increasing k^2, so the lowest-k modes take the first random draws) and an
    amplitude sqrt(exp(-k^2 / (2 sigma^2 (m/hbar)^2))). The inverse transform
    is rescaled so that mean(|psi|^2) == mean_density.
    Re-seeds NumPy's global RNG with 17.
    """
    np.random.seed(17)
    k_order = np.argsort(k_sq.flatten(), stable=True)
    phases = np.zeros((nx**3,), dtype=complex)
    phases[k_order] = np.exp(1.0j * 2.0 * np.pi * np.random.rand(nx**3))
    psi_k = jnp.array(phases.reshape(k_sq.shape))
    psi_k = psi_k * np.sqrt(np.exp(-k_sq / (2.0 * sigma**2 * m_per_hbar**2)))
    psi = jd.fft.pifft3d(psi_k)
    return psi * jnp.sqrt(mean_density / jnp.mean(jnp.abs(psi) ** 2))


def set_up_simulation(resolution_multiplier):
    """Random-phase FDM field plus uniform gas at rest, for the gas-heating test.

    Dark matter holds 85% and gas 15% of a mean matter density of
    1e7 Msun/kpc^3; the dark matter has velocity dispersion 40.

    Args:
        resolution_multiplier: scales jaxion's base resolution and is also used
            in the checkpoint directory name.

    Returns:
        jaxion.Simulation with ``psi``, ``rho``, ``vx``, ``vy`` and ``vz`` set.

    Raises:
        AssertionError: if the de Broglie or Jeans-length sanity checks fail.
    """
    params = {
        "physics": {"hydro": True},
        "domain": {
            "resolution_multiplier": resolution_multiplier,
            "box_size": 4.0,  # kpc
        },
        "time": {"end": 0.4},
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "plot_dynamic_range": 2.0,
        },
        "quantum": {"m_22": 1.0},  # axion mass in units of 10^-22 eV
        "hydro": {"sound_speed": 20.0},  # km/s
    }
    sim = jaxion.Simulation(params)

    # Mean density of all matter (dm + gas), Msun / kpc^3.
    rho_bar = 1.0e7

    # Gas share of the mass and its (isothermal) sound speed.
    frac_gas = 0.15
    rho_gas = frac_gas * rho_bar
    c_sound = sim.sound_speed  # km/s

    # Dark matter share and velocity dispersion.
    frac_dm = 1.0 - frac_gas
    sigma = 40.0

    m = sim.axion_mass
    hbar = jaxion.constants["reduced_planck_constant"]
    m_per_hbar = m / hbar

    kx, ky, kz = sim.kgrid
    k_sq = kx**2 + ky**2 + kz**2

    nx = sim.resolution
    box_size = sim.box_size
    G = jaxion.constants["gravitational_constant"]

    # More than one de Broglie wavelength must fit in the box ...
    de_broglie_wavelength = hbar / (m * sigma)
    assert box_size / de_broglie_wavelength > 1
    # ... and the box must be under half a gas Jeans length.
    jeans_length = c_sound * jnp.sqrt(jnp.pi / (G * rho_gas))
    assert box_size / jeans_length < 0.5

    psi = _random_phase_wavefunction(k_sq, nx, sigma, m_per_hbar, frac_dm * rho_bar)

    grid_shape = (nx, nx, nx)
    sim.state["psi"] = psi
    sim.state["rho"] = jnp.ones(grid_shape) * rho_gas
    sim.state["vx"] = jnp.zeros(grid_shape)
    sim.state["vy"] = jnp.zeros(grid_shape)
    sim.state["vz"] = jnp.zeros(grid_shape)
    return sim


def main():
    """Parse ``--res``, run the gas-heating example, print mean |psi| and |v| per axis."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    options = cli.parse_args()

    simulation = set_up_simulation(options.res)
    simulation.run()

    state = simulation.state
    print("mean |psi| =", jnp.mean(jnp.abs(state["psi"])))
    for component in ("vx", "vy", "vz"):
        print(f"mean |{component}| =", jnp.mean(jnp.abs(state[component])))
    return simulation


if __name__ == "__main__":
    main()

heating_stars#

heating_stars

See on GitHub: examples/heating_stars#

README:

# Heating Stars

Heating Stars due to Fuzzy Dark Matter fluctuations

Philip Mocz (2025)

Usage:

```console
python heating_stars.py --res <resolution_multiplier>
```

Takes around 11 (res=1) seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm100.png" alt="dm100" width="45%"/>
</div>


## References

[Church, B.; Mocz, P.; Ostriker, J.P.; Heating of Milky Way disc Stars by Dark Matter Fluctuations in Cold Dark Matter and Fuzzy Dark Matter Paradigms. MNRAS (2019)](https://ui.adsabs.harvard.edu/abs/2019MNRAS.485.2861C)

Script:

import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Heating Stars due to Fuzzy Dark Matter fluctuations

Philip Mocz (2025)

Usage:
  python heating_stars.py --res <resolution_multiplier>
"""


def _random_phase_wavefunction(k_sq, nx, sigma, m_per_hbar, mean_density):
    """Random-phase FDM wavefunction per Eq. (27) of arXiv:1801.03507.

    Each Fourier mode gets a uniformly random phase (assigned in order of
    increasing k^2, so the lowest-k modes take the first random draws) and an
    amplitude sqrt(exp(-k^2 / (2 sigma^2 (m/hbar)^2))). The inverse transform
    is rescaled so that mean(|psi|^2) == mean_density.
    Re-seeds NumPy's global RNG with 17.
    """
    np.random.seed(17)
    k_order = np.argsort(k_sq.flatten(), stable=True)
    phases = np.zeros((nx**3,), dtype=complex)
    phases[k_order] = np.exp(1.0j * 2.0 * np.pi * np.random.rand(nx**3))
    psi_k = jnp.array(phases.reshape(k_sq.shape))
    psi_k = psi_k * np.sqrt(np.exp(-k_sq / (2.0 * sigma**2 * m_per_hbar**2)))
    psi = jd.fft.pifft3d(psi_k)
    return psi * jnp.sqrt(mean_density / jnp.mean(jnp.abs(psi) ** 2))


def set_up_simulation(resolution_multiplier):
    """Random-phase FDM field plus uniformly scattered star particles.

    Dark matter holds 85% and stars 15% of a mean matter density of
    1e7 Msun/kpc^3 in a 4 kpc box; the stars are 400 equal-mass particles.

    Args:
        resolution_multiplier: scales jaxion's base resolution and is also used
            in the checkpoint directory name.

    Returns:
        jaxion.Simulation with ``psi``, ``pos`` and ``vel`` set.

    Raises:
        AssertionError: if the de Broglie or Jeans-length sanity checks fail.
    """
    # Mean density of all matter (dm + stars), Msun / kpc^3.
    rho_bar = 1.0e7
    box_size = 4.0  # kpc

    # Stars: mass share, 1D velocity dispersion (km/s), and particle mass.
    frac_stars = 0.15
    rho_stars = frac_stars * rho_bar
    sigma_stars = 20.0
    n_stars = 400
    m_stars = rho_stars * box_size**3 / n_stars  # Msun

    params = {
        "physics": {"particles": True},
        "domain": {
            "resolution_multiplier": resolution_multiplier,
            "box_size": box_size,
        },
        "time": {"end": 0.4},
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "plot_dynamic_range": 2.0,
        },
        "quantum": {"m_22": 1.0},  # axion mass in units of 10^-22 eV
        "particles": {"num_particles": n_stars, "particle_mass": m_stars},
    }
    sim = jaxion.Simulation(params)

    # Dark matter share and velocity dispersion (km/s).
    frac_dm = 1.0 - frac_stars
    sigma = 40.0

    m = sim.axion_mass
    hbar = jaxion.constants["reduced_planck_constant"]
    m_per_hbar = m / hbar

    kx, ky, kz = sim.kgrid
    k_sq = kx**2 + ky**2 + kz**2

    nx = sim.resolution
    G = jaxion.constants["gravitational_constant"]

    # More than one de Broglie wavelength must fit in the box ...
    de_broglie_wavelength = hbar / (m * sigma)
    assert box_size / de_broglie_wavelength > 1
    # ... and the box must be under half a stellar Jeans length.
    jeans_length = sigma_stars * jnp.sqrt(jnp.pi / (G * rho_stars))
    assert box_size / jeans_length < 0.5

    psi = _random_phase_wavefunction(k_sq, nx, sigma, m_per_hbar, frac_dm * rho_bar)

    # Stars: uniform random positions, Gaussian velocities.
    # NOTE(review): re-seeding with the same seed used for the DM phases makes
    # the star positions reuse the same uniform draws as the lowest-k phases.
    # Kept as-is for reproducibility; consider an independent seed/Generator.
    np.random.seed(17)
    pos = jnp.array(np.random.rand(n_stars, 3) * box_size)
    vel = jnp.array(np.random.randn(n_stars, 3) * sigma_stars)

    sim.state["psi"] = psi
    sim.state["pos"] = pos
    sim.state["vel"] = vel
    return sim


def main():
    """Parse ``--res``, run the star-heating example, and print velocity statistics."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    options = cli.parse_args()

    simulation = set_up_simulation(options.res)
    simulation.run()

    vel = simulation.state["vel"]
    print("mean |psi| =", jnp.mean(jnp.abs(simulation.state["psi"])))
    for axis, label in enumerate("xyz"):
        print(f"mean |v{label}| =", jnp.mean(jnp.abs(vel[:, axis])))
    # 3D dispersion: quadrature sum of the per-axis standard deviations.
    std_x, std_y, std_z = (jnp.std(vel[:, axis]) for axis in range(3))
    print("std vel =", jnp.sqrt(std_x**2 + std_y**2 + std_z**2))
    return simulation


if __name__ == "__main__":
    main()

kinetic_condensation#

kinetic_condensation

See on GitHub: examples/kinetic_condensation#

README:

# Kinetic Condensation

Soliton formation in the kinetic regime

Philip Mocz (2025) 

Usage:

```console
python kinetic_condensation.py --res <resolution_multiplier>
```

Takes around 34 (res=1) seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm000.png" alt="dm000" width="32%"/>
  <img src="dm020.png" alt="dm020" width="32%"/>
  <img src="dm040.png" alt="dm040" width="32%"/>
  <img src="dm060.png" alt="dm060" width="32%"/>
  <img src="dm080.png" alt="dm080" width="32%"/>
  <img src="dm100.png" alt="dm100" width="32%"/>
</div>


## References

[Levkov, D.G.; Panin, A.G.; Tkachev, I.I.; Gravitational Bose-Einstein condensation in the kinetic regime. PRL (2018)](https://ui.adsabs.harvard.edu/abs/2018PhRvL.121o1301L)

Script:

import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Soliton formation in the kinetic regime

Philip Mocz (2025)

Usage:
  python kinetic_condensation.py --res <resolution_multiplier>
"""


def _random_phase_wavefunction(k_sq, nx, sigma, m_per_hbar, mean_density):
    """Random-phase FDM wavefunction per Eq. (27) of arXiv:1801.03507.

    Each Fourier mode gets a uniformly random phase (assigned in order of
    increasing k^2, so the lowest-k modes take the first random draws) and an
    amplitude sqrt(exp(-k^2 / (2 sigma^2 (m/hbar)^2))). The inverse transform
    is rescaled so that mean(|psi|^2) == mean_density.
    Re-seeds NumPy's global RNG with 17.
    """
    np.random.seed(17)
    k_order = np.argsort(k_sq.flatten(), stable=True)
    phases = np.zeros((nx**3,), dtype=complex)
    phases[k_order] = np.exp(1.0j * 2.0 * np.pi * np.random.rand(nx**3))
    psi_k = jnp.array(phases.reshape(k_sq.shape))
    psi_k = psi_k * np.sqrt(np.exp(-k_sq / (2.0 * sigma**2 * m_per_hbar**2)))
    psi = jd.fft.pifft3d(psi_k)
    return psi * jnp.sqrt(mean_density / jnp.mean(jnp.abs(psi) ** 2))


def set_up_simulation(resolution_multiplier):
    """Random-phase FDM field expected to condense into a soliton (kinetic regime).

    Args:
        resolution_multiplier: scales jaxion's base resolution and is also used
            in the checkpoint directory name.

    Returns:
        jaxion.Simulation with ``state["psi"]`` set (mean density 1e8 Msun/kpc^3,
        velocity dispersion 100).

    Raises:
        AssertionError: if any kinetic-regime / timescale / Jeans-length check
            (following Levkov et al. 2018) fails.
    """
    params = {
        "domain": {
            "resolution_multiplier": resolution_multiplier,
            "box_size": 6.0,  # kpc
        },
        "time": {"end": 10.0},
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "plot_dynamic_range": 4.0,
        },
    }
    sim = jaxion.Simulation(params)

    rho_bar = 1.0e8  # mean dark matter density, Msun / kpc^3
    sigma = 100.0  # dark matter velocity dispersion

    m = sim.axion_mass
    hbar = jaxion.constants["reduced_planck_constant"]
    m_per_hbar = m / hbar

    kx, ky, kz = sim.kgrid
    k_sq = kx**2 + ky**2 + kz**2

    nx = sim.resolution
    box_size = sim.box_size
    G = jaxion.constants["gravitational_constant"]

    # More than one de Broglie wavelength must fit in the box.
    de_broglie_wavelength = hbar / (m * sigma)
    assert box_size / de_broglie_wavelength > 1

    # Condensation time, eqn 4 of Levkov (lambda_fac = log(m sigma L / hbar)).
    lambda_fac = np.log(m * sigma * box_size / hbar)
    b = 1.0
    n = rho_bar / m
    tau_gr = (
        b
        * np.sqrt(2.0)
        / (12.0 * np.pi**3)
        * m
        * sigma**6
        / (G**2 * n**2 * lambda_fac)
        / hbar**3
    )
    # Condensation should happen well within the simulated time.
    assert 10.0 * tau_gr < sim.params["time"]["end"]

    # Kinetic-regime conditions (eqn 1 of Levkov), both required to be > 1:
    # box scale and condensation time in units set by m, sigma and hbar.
    kin1 = m * sigma * box_size / hbar
    kin2 = m * sigma**2 * tau_gr / hbar
    assert kin1 > 1
    assert kin2 > 1

    # Kinetic Jeans length must exceed the box; quantum Jeans length
    # (eqn 40 of Levkov) must fit inside it.
    jeans_length = sigma * np.sqrt(np.pi / (G * rho_bar))
    jeans_length_Q = np.pi / ((np.pi * G * rho_bar) ** 0.25 * m**0.5 / np.sqrt(hbar))
    assert jeans_length > box_size
    assert jeans_length_Q < box_size

    sim.state["psi"] = _random_phase_wavefunction(k_sq, nx, sigma, m_per_hbar, rho_bar)
    return sim


def main():
    """Parse ``--res`` and run the kinetic-condensation example."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    options = cli.parse_args()

    simulation = set_up_simulation(options.res)
    simulation.run()
    return simulation


if __name__ == "__main__":
    main()

logo_inverse_problem#

logo_inverse_problem

See on GitHub: examples/logo_inverse_problem#

README:

# Logo Inverse Problem

Inverse problem: find the initial star velocities (in a dark matter + stars system) that evolve the projected dark matter density into a target image at t=1

Philip Mocz (2025)

Usage:

```console
python logo_inverse_problem.py
```

Takes around 200 seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm100.png" alt="dm100" width="45%"/>
</div>

Script:

import jax
import jax.numpy as jnp

import jaxion
import chex
from typing import NamedTuple
import optax
import time
import matplotlib.image as img

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Inverse problem: Find initial conditions for velocity (dm+stars) to achieve target density at t=1

Philip Mocz (2025)

Usage:
  python logo_inverse_problem.py
"""


def set_up_simulation(save=False):
    # Parameters added/changed from default values
    params = {
        "physics": {
            "particles": True,
        },
        "domain": {
            "resolution_base": 32,
            "box_size": 20.0,  # kpc
        },
        "time": {
            "end": 1.0,
        },
        "output": {
            "path": "./checkpoints/",
            "num_checkpoints": 100,
            "save": save,
            "plot_dynamic_range": 10.0,
        },
        "particles": {
            "num_particles": 64,
            "particle_mass": 1.0e7,  # Msun
        },
    }

    # Initialize the simulation
    sim = jaxion.Simulation(params)

    # dark matter
    rho_dm = 1.0e5  # (Msun / kpc^3)

    box_size = sim.box_size
    nx = sim.resolution

    # dark matter
    psi = jnp.sqrt(rho_dm) * jnp.ones((nx, nx, nx)) + 0j

    # stars are initially uniformly distributed
    # Arrange positions in a uniform grid in x-y plane, z=box_size/2
    num_stars = params["particles"]["num_particles"]
    side = int(num_stars ** (1 / 2))
    xlin = jnp.linspace(0, box_size, side, endpoint=False) + box_size / (2 * side)
    ylin = jnp.linspace(0, box_size, side, endpoint=False) + box_size / (2 * side)
    zlin = jnp.array([box_size / 2.0])
    xx, yy, zz = jnp.meshgrid(xlin, ylin, zlin, indexing="ij")
    pos = jnp.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T
    vel = jnp.zeros((num_stars, 3))

    sim.state["psi"] = psi
    sim.state["pos"] = pos
    sim.state["vel"] = vel

    return sim


class InfoState(NamedTuple):
    """State of the ``print_info`` optax transformation: a single step counter."""

    iter_num: chex.Numeric  # incremented once per optimizer update


def print_info():
    """Return an optax transformation that logs progress and passes updates through.

    Every ``update`` call prints the iteration count, loss value and gradient
    norm with ``jax.debug.print`` (usable inside traced code), then returns the
    incoming ``updates`` unchanged alongside an incremented counter.
    """

    def init_fn(params):
        del params  # the counter does not depend on the parameters
        return InfoState(iter_num=0)

    def update_fn(updates, state, params, *, value, grad, **extra_args):
        del params, extra_args  # unused by the logger
        grad_norm = optax.tree_utils.tree_norm(grad)
        jax.debug.print(
            "Iteration: {i}, Loss: {v:.2e}, |grad|: {e:.2e}",
            i=state.iter_num,
            v=value,
            e=grad_norm,
        )
        next_state = InfoState(iter_num=state.iter_num + 1)
        return updates, next_state

    return optax.GradientTransformationExtraArgs(init_fn, update_fn)


def run_opt(init_params, fun, opt, max_iter, tol):
    """Minimise ``fun`` with ``opt`` inside a ``jax.lax.while_loop``.

    The loop always takes at least one step, then continues while the step
    count is below ``max_iter`` and the norm of the gradient stored in the
    optimizer state is at least ``tol``.

    Returns:
        ``(final_params, final_state)``.
    """
    # Reuses the value/gradient cached in the optimizer state when available.
    value_and_grad_fun = optax.value_and_grad_from_state(fun)

    def keep_going(carry):
        _, opt_state = carry
        step_count = optax.tree_utils.tree_get(opt_state, "count")
        grad_norm = optax.tree_utils.tree_norm(
            optax.tree_utils.tree_get(opt_state, "grad")
        )
        first_step = step_count == 0
        unconverged = (step_count < max_iter) & (grad_norm >= tol)
        return first_step | unconverged

    def take_step(carry):
        params, opt_state = carry
        value, grad = value_and_grad_fun(params, state=opt_state)
        updates, opt_state = opt.update(
            grad, opt_state, params, value=value, grad=grad, value_fn=fun
        )
        return optax.apply_updates(params, updates), opt_state

    final_params, final_state = jax.lax.while_loop(
        keep_going, take_step, (init_params, opt.init(init_params))
    )
    return final_params, final_state


def solve_inverse_problem():
    """Optimise the initial star velocities so the system evolves into the logo.

    Loads ``target.png``, turns its first channel into a unit-mean target
    image, and runs L-BFGS (with per-iteration logging) on the star velocity
    array ``sim.state["vel"]``. The loss is the mean squared difference between
    the target and the unit-mean dark matter density projected along z at the
    end of the run.

    Returns:
        The optimised velocity array (same shape as ``sim.state["vel"]``).
    """
    # Load the target density field
    target_data = img.imread("target.png")[:, :, 0]
    target_data = target_data[::2, ::2]  # downsample
    # Flip and transpose so image rows/columns line up with the (x, y) axes.
    target = jnp.flipud(jnp.array(target_data, dtype=float)).T
    # Map pixel value 0 -> 1.8 and 1 -> 0.2, then normalise to unit mean.
    target = 1.0 - 1.6 * (target - 0.5)
    target /= jnp.mean(target)

    @jax.jit
    def loss_function(theta):
        sim = set_up_simulation()
        sim.state["vel"] = theta

        sim.run()

        projected_density = jnp.mean(jnp.abs(sim.state["psi"]) ** 2, axis=2)
        projected_density /= jnp.mean(projected_density)

        return jnp.mean((projected_density - target) ** 2)

    # One forward + backward pass gives both numbers for the progress reports,
    # instead of running the forward simulation a second time.
    loss_and_grad = jax.value_and_grad(loss_function)

    opt = optax.chain(print_info(), optax.lbfgs())

    init_params = set_up_simulation().state["vel"]

    init_value, init_grad = loss_and_grad(init_params)
    print(
        f"Initial value: {init_value:.2e} "
        f"Initial gradient norm: {optax.tree_utils.tree_norm(init_grad):.2e}"
    )
    t0 = time.perf_counter()
    final_params, _ = run_opt(init_params, loss_function, opt, max_iter=100, tol=1e-5)
    # JAX dispatches asynchronously; wait for the result so the timing is real.
    final_params = jax.block_until_ready(final_params)
    print("Inverse-problem solve time (s): ", time.perf_counter() - t0)
    final_value, final_grad = loss_and_grad(final_params)
    print(
        f"Final value: {final_value:.2e}, "
        f"Final gradient norm: {optax.tree_utils.tree_norm(final_grad):.2e}"
    )

    return final_params


def main():
    """Solve for the best star velocities, then replay the run with ``save=True``."""
    best_velocities = solve_inverse_problem()

    replay = set_up_simulation(save=True)
    replay.state["vel"] = best_velocities
    replay.run()
    return replay


if __name__ == "__main__":
    main()

self_interaction_collapse#

self_interaction_collapse

See on GitHub: examples/self_interaction_collapse#

README:

# Self-Interaction Collapse

Collapse a soliton due to attractive self-interaction

Philip Mocz (2025)

Usage:

```console
python self_interaction_collapse.py --res <resolution_multiplier>
```

Takes around 5 (res=1) seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm000.png" alt="dm000" width="45%"/>
  <img src="dm064.png" alt="dm064" width="45%"/>
</div>


## References

[Chavanis, P.H.; Phase transitions between dilute and dense axion stars. Phys. Rev. D. (2018)](https://ui.adsabs.harvard.edu/abs/2018PhRvD..98b3009C)

[Mocz, P. et al.; Cosmological Structure Formation and Soliton Phase Transition in Fuzzy Dark Matter with Axion Self-Interactions. MNRAS (2023)](https://ui.adsabs.harvard.edu/abs/2023MNRAS.521.2608M)

Script:

import jax
import jax.numpy as jnp
import argparse
import jaxion

# switch on for double precision
jax.config.update("jax_enable_x64", True)

"""
Collapse a soliton due to attractive self-interaction

Philip Mocz (2025)

Usage:
  python self_interaction_collapse.py --res <resolution_multiplier>
"""


def set_up_simulation(resolution_multiplier):
    """Place a single 1e9 Msun soliton at the box centre with f_15 = -0.8.

    Prints the scattering length and the ratio of the soliton mass to the
    critical mass M_crit = 1.012 hbar / sqrt(G m |a_s|).

    Args:
        resolution_multiplier: scales jaxion's base resolution and is also used
            in the checkpoint directory name.

    Returns:
        jaxion.Simulation with ``state["psi"]`` set to the (real) soliton profile.
    """
    params = {
        "domain": {
            "resolution_multiplier": resolution_multiplier,
            "box_size": 4.0,
        },
        "time": {"end": 0.1},
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "save": True,
        },
        "quantum": {
            "m_22": 1.0,  # axion mass in 10^-22 eV
            "f_15": -0.8,  # decay constant in 10^15 GeV
        },
    }
    sim = jaxion.Simulation(params)

    m_22 = sim.params["quantum"]["m_22"]
    hbar = jaxion.constants["reduced_planck_constant"]
    G = jaxion.constants["gravitational_constant"]
    m = sim.axion_mass
    a_s = sim.scattering_length
    box_size = sim.params["domain"]["box_size"]
    xx, yy, zz = sim.grid

    print("a_s =", a_s)
    M_soliton = 1.0e9
    r_soliton = 2.2e8 * m_22**-2 / M_soliton  # in kpc

    M_crit = 1.012 * hbar / jnp.sqrt(G * m * jnp.abs(a_s))
    print("M_soliton/M_crit =", M_soliton / M_crit)

    # Soliton density profile centred in the box.
    center = 0.5 * box_size
    r = jnp.sqrt((xx - center) ** 2 + (yy - center) ** 2 + (zz - center) ** 2)
    rho = 1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (r / r_soliton) ** 2) ** 8

    sim.state["psi"] = jnp.array(jnp.sqrt(rho)) + 0j
    return sim


def main():
    """Parse ``--res`` and run the self-interaction collapse example."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    options = cli.parse_args()

    simulation = set_up_simulation(options.res)
    simulation.run()
    return simulation


if __name__ == "__main__":
    main()

soliton_binary_merger#

soliton_binary_merger

See on GitHub: examples/soliton_binary_merger#

README:

# Soliton Binary Merger

Collide two solitons with opposite phases

Philip Mocz (2025)

Usage:

```console
python soliton_binary_merger.py --res <resolution_multiplier>
```

Takes around 4 (res=1) seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm000.png" alt="dm000" width="32%"/>
  <img src="dm002.png" alt="dm002" width="32%"/>
  <img src="dm004.png" alt="dm004" width="32%"/>
  <img src="dm006.png" alt="dm006" width="32%"/>
  <img src="dm008.png" alt="dm008" width="32%"/>
  <img src="dm010.png" alt="dm010" width="32%"/>
</div>

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="timing.png" alt="timing" width="45%"/>
</div>


## References

[Schwabe, B.; Niemeyer, J.C.; Engels, J.F.; Simulations of solitonic core mergers in ultralight axion dark matter cosmologies. Phys Rev D. (2016)](https://ui.adsabs.harvard.edu/abs/2016PhRvD..94d3513S)

[Mocz, P. et al.; Galaxy formation with BECDM - I. Turbulence and relaxation of idealized haloes. MNRAS (2017)](https://ui.adsabs.harvard.edu/abs/2017MNRAS.471.4559M)

Script:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import argparse
import os
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Collide two solitons with opposite phases

Philip Mocz (2025)

Usage:
  python soliton_binary_merger.py --res <resolution_multiplier>

  python soliton_binary_merger.py --distributed --emulate
"""


def set_up_simulation(resolution_multiplier, sharding):
    """Initialise two identical solitons on the x axis with opposite phases.

    The solitons sit at x = 0.35 and x = 0.65 of the box (y = z = box centre),
    each with core radius 0.02 * box_size. The real wavefunction is multiplied
    by +1 where x is greater than the box centre and by -1 elsewhere, giving
    the two solitons a relative phase of pi.

    Args:
        resolution_multiplier: scales jaxion's base resolution and is also used
            in the checkpoint directory name.
        sharding: forwarded to ``jaxion.Simulation`` (``None`` when not
            running distributed).

    Returns:
        jaxion.Simulation with ``state["psi"]`` initialised.
    """
    # Parameters added/changed from default values
    params = {
        "domain": {
            "resolution_multiplier": resolution_multiplier,
        },
        "time": {
            "end": 1.0,
        },
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "save": True,
        },
    }

    # Initialize the simulation
    sim = jaxion.Simulation(params, sharding=sharding)

    # Set initial conditions (two solitons)
    m_22 = sim.params["quantum"]["m_22"]
    box_size = sim.params["domain"]["box_size"]
    r_soliton = 0.02 * box_size  # core radius shared by both solitons
    xx, yy, zz = sim.grid

    def rho_soliton(r, r_soliton, m_22):
        """Soliton density at distance ``r`` for core radius ``r_soliton``."""
        return (
            1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (r / r_soliton) ** 2) ** 8
        )

    # Superpose the two density profiles.
    rho = 0.0
    y_soliton = z_soliton = 0.5 * box_size
    for i in range(2):
        x_soliton = (0.5 + 0.3 * (i - 0.5)) * box_size  # 0.35 L and 0.65 L
        r = jnp.sqrt(
            (xx - x_soliton) ** 2 + (yy - y_soliton) ** 2 + (zz - z_soliton) ** 2
        )
        rho += rho_soliton(r, r_soliton, m_22)

    # +1 for x above the box centre, -1 elsewhere: the solitons get opposite phases.
    sign = jnp.where((xx - 0.5 * box_size) > 0, 1.0, -1.0)
    sim.state["psi"] = jnp.array(jnp.sqrt(rho) + 0j) * sign

    return sim


def main():
    """Parse CLI flags, optionally build a sharded device mesh, and run the merger.

    ``--distributed --emulate`` requests 8 virtual CPU devices via ``XLA_FLAGS``
    and hides GPUs (both set before the first JAX device query below);
    ``--distributed`` alone calls ``jax.distributed.initialize()``. The mesh is
    ``(1, n_devices)`` with axes ("x", "y"). Without ``--distributed`` no
    sharding is used.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    cli.add_argument(
        "--distributed", action="store_true", help="run in distributed mode"
    )
    cli.add_argument(
        "--emulate", action="store_true", help="emulate distributed mode on CPU"
    )
    options = cli.parse_args()

    sharding = None
    if options.distributed:
        if options.emulate:
            # Number of virtual devices to emulate; edit to test other layouts.
            xla_flags = os.environ.get("XLA_FLAGS", "")
            xla_flags += " --xla_force_host_platform_device_count=8"
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            os.environ["XLA_FLAGS"] = xla_flags
            mode_message = "Using emulated distributed CPU mode"
        else:
            jax.distributed.initialize()
            mode_message = "Using distributed GPU mode"
        if jax.process_index() == 0:
            print(mode_message)
        # Mesh and sharding for the distributed computation.
        device_grid = mesh_utils.create_device_mesh((1, jax.device_count()))
        mesh = Mesh(device_grid, axis_names=("x", "y"))
        sharding = NamedSharding(mesh, PartitionSpec("x", "y"))

    sim = set_up_simulation(options.res, sharding)
    sim.run()
    return sim


if __name__ == "__main__":
    main()

soliton_gas_star#

soliton_gas_star

See on GitHub: examples/soliton_gas_star#

README:

# Soliton Gas Star

Explore the dynamics of a soliton, gas, and single star

Philip Mocz (2025)

Usage:

```console
python soliton_gas_star.py --res <resolution_multiplier>
```

Takes around 20 (res=1) seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm010.png" alt="dm010" width="45%"/>
  <img src="gas010.png" alt="gas010" width="45%"/>
</div>

Script:

import jax.numpy as jnp
import argparse
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Explore the dynamics of a soliton, gas, and single star

Philip Mocz (2025)

Usage:
  python soliton_gas_star.py --res <resolution_multiplier>
"""


def set_up_simulation(resolution_multiplier):
    """Build a simulation of one soliton, a uniform gas, and a single star.

    Args:
        resolution_multiplier: forwarded to
            ``params["domain"]["resolution_multiplier"]`` and used to name the
            checkpoint directory (``./checkpoints<resolution_multiplier>/``).

    Returns:
        A ``jaxion.Simulation`` whose ``psi``, gas (``rho``, ``vx``, ``vy``,
        ``vz``) and particle (``pos``, ``vel``) state have been initialized.
    """
    # Only settings that differ from jaxion's defaults are listed.
    overrides = {
        "physics": {"hydro": True, "particles": True},
        "domain": {"resolution_multiplier": resolution_multiplier},
        "output": {"path": f"./checkpoints{resolution_multiplier}/"},
        "hydro": {"sound_speed": 40.0},  # km/s
        "particles": {"num_particles": 1, "particle_mass": 1.0e7},
    }
    sim = jaxion.Simulation(overrides)

    # --- soliton centred in the box ---------------------------------------
    soliton_mass = 1.0e9  # M_sun
    m_22 = sim.params["quantum"]["m_22"]
    core_radius = 2.2e8 * m_22**-2 / soliton_mass  # kpc
    box_size = sim.params["domain"]["box_size"]
    n_cells = sim.resolution
    centre = 0.5 * box_size

    xx, yy, zz = sim.grid
    dist = jnp.sqrt((xx - centre) ** 2 + (yy - centre) ** 2 + (zz - centre) ** 2)
    profile = (
        1.9e7 * m_22**-2 * core_radius**-4 / (1.0 + 0.091 * (dist / core_radius) ** 2) ** 8
    )
    sim.state["psi"] = jnp.sqrt(profile) + 0.0j

    # --- uniform gas at rest, holding 10% of the soliton mass --------------
    gas_fraction = 0.1
    gas_density = gas_fraction * soliton_mass / box_size**3
    grid_shape = (n_cells, n_cells, n_cells)
    sim.state["rho"] = jnp.ones(grid_shape) * gas_density
    for component in ("vx", "vy", "vz"):
        sim.state[component] = jnp.zeros(grid_shape)

    # --- one star, offset along x and moving along y -----------------------
    sim.state["pos"] = jnp.array([[0.6 * box_size, centre, centre]])
    sim.state["vel"] = jnp.array([[0.0, 40.0, 0.0]])

    return sim


def main():
    """Parse ``--res``, build the soliton/gas/star simulation and run it.

    Returns:
        The simulation object after ``run()`` has completed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    options = cli.parse_args()

    simulation = set_up_simulation(options.res)
    simulation.run()
    return simulation


if __name__ == "__main__":
    main()

soliton_merger#

soliton_merger

See on GitHub: examples/soliton_merger#

README:

# Soliton Merger

Merge solitons to form an idealized Fuzzy Dark Matter halo

Philip Mocz (2025)

Usage:

```console
python soliton_merger.py --res <resolution_multiplier>
```

Takes around 4 seconds (res=1) to run on my MacBook (CPU).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm000.png" alt="dm000" width="32%"/>
  <img src="dm002.png" alt="dm002" width="32%"/>
  <img src="dm004.png" alt="dm004" width="32%"/>
  <img src="dm006.png" alt="dm006" width="32%"/>
  <img src="dm008.png" alt="dm008" width="32%"/>
  <img src="dm010.png" alt="dm010" width="32%"/>
</div>


## References

[Mocz, P. et al.; Galaxy formation with BECDM - I. Turbulence and relaxation of idealized haloes. MNRAS (2017)](https://ui.adsabs.harvard.edu/abs/2017MNRAS.471.4559M)

Script:

import jax.numpy as jnp
import numpy as np
import argparse
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Merge solitons to form an idealized Fuzzy Dark Matter halo

Philip Mocz (2025)

Usage:
  python soliton_merger.py --res <resolution_multiplier>
"""


def set_up_simulation(resolution_multiplier):
    """Create a simulation whose initial wavefunction is a sum of random solitons.

    Eight solitons (seeded RNG, seed 17) with core radii of 5-8% of the box
    size are placed uniformly in the central half of the box. Their densities
    are summed and ``psi`` is set to the square root of the total density.

    Args:
        resolution_multiplier: forwarded to
            ``params["domain"]["resolution_multiplier"]`` and used to name the
            checkpoint directory.

    Returns:
        The configured ``jaxion.Simulation``.
    """
    config = {
        "domain": {"resolution_multiplier": resolution_multiplier},
        "output": {"path": f"./checkpoints{resolution_multiplier}/"},
    }
    sim = jaxion.Simulation(config)

    m_22 = sim.params["quantum"]["m_22"]
    box_size = sim.params["domain"]["box_size"]
    n_cells = sim.resolution
    xx, yy, zz = sim.grid

    def soliton_density(radius, core_radius, mass_22):
        # Soliton core density profile as a function of distance from its centre.
        return (
            1.9e7
            * mass_22**-2
            * core_radius**-4
            / (1.0 + 0.091 * (radius / core_radius) ** 2) ** 8
        )

    np.random.seed(17)
    n_solitons = 8

    # Complex accumulator so that sqrt() below yields a complex wavefunction.
    density = np.zeros((n_cells, n_cells, n_cells), dtype=complex)
    for _soliton in range(n_solitons):
        # Draw order (radius, then x, y, z) fixes the realization for seed 17.
        core = (0.05 + 0.03 * np.random.rand()) * box_size
        cx, cy, cz = [(0.25 + 0.5 * np.random.rand()) * box_size for _axis in range(3)]

        dist = jnp.sqrt((xx - cx) ** 2 + (yy - cy) ** 2 + (zz - cz) ** 2)
        density += soliton_density(dist, core, m_22)

    sim.state["psi"] = jnp.array(jnp.sqrt(density))

    return sim


def main():
    """Parse ``--res``, run the soliton merger and print the final mean |psi|.

    Returns:
        The simulation object after ``run()`` has completed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    options = cli.parse_args()

    simulation = set_up_simulation(options.res)
    simulation.run()
    mean_abs_psi = jnp.mean(jnp.abs(simulation.state["psi"]))
    print("mean |psi| =", mean_abs_psi)

    return simulation


if __name__ == "__main__":
    main()

tidal_stripping#

tidal_stripping

See on GitHub: examples/tidal_stripping#

README:

# Tidal Stripping

Tidal Stripping of Fuzzy Dark Matter soliton in an external potential

Philip Mocz (2025)

Usage:

```console
python tidal_stripping.py --res <resolution_multiplier>
```

Takes around 4 seconds (res=1) and 40 seconds (res=2) to run on my MacBook (CPU).

Demonstrates adding an external potential, including an imaginary sponge potential to absorb material near the boundaries of the periodic box.


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="dm000.png" alt="dm000" width="32%"/>
  <img src="dm002.png" alt="dm002" width="32%"/>
  <img src="dm004.png" alt="dm004" width="32%"/>
  <img src="dm006.png" alt="dm006" width="32%"/>
  <img src="dm008.png" alt="dm008" width="32%"/>
  <img src="dm010.png" alt="dm010" width="32%"/>
</div>

Script:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import argparse
import os
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Tidal Stripping of Fuzzy Dark Matter soliton in an external potential

Philip Mocz (2025)

Usage:
  python tidal_stripping.py --res <resolution_multiplier>

  python tidal_stripping.py --distributed --emulate
"""


def set_up_simulation(resolution_multiplier, sharding, save):
    """Set up a soliton orbiting inside a fixed external (host-halo) potential.

    The soliton is placed 2 kpc from the box centre along ``y``. A phase
    factor ``exp(i k x)`` gives it a bulk motion along ``x``. The external
    potential is a softened point-mass halo plus an imaginary "sponge" term
    (Schwabe et al. 2016). The sponge is non-zero only beyond ``7/16`` of the
    box size from the centre and absorbs material near the box boundary.

    Args:
        resolution_multiplier: forwarded to the domain parameters and used in
            the checkpoint path.
        sharding: passed to ``jaxion.Simulation`` (``None`` when not running
            distributed).
        save: value for ``params["output"]["save"]``.

    Returns:
        The configured ``jaxion.Simulation``.

    Raises:
        AssertionError: if the derived halo mass is not more than twice the
            soliton mass.
    """
    sim = jaxion.Simulation(
        {
            "physics": {"external_potential": True},
            "domain": {"resolution_multiplier": resolution_multiplier},
            "output": {
                "path": f"./checkpoints{resolution_multiplier}/",
                "save": save,
            },
        },
        sharding=sharding,
    )

    # --- physical parameters ----------------------------------------------
    soliton_mass = 1.0e9  # M_sun
    k_orbit = 4.0  # wave-number of the soliton's bulk motion
    offset = 2.0  # initial distance of the soliton from the centre (kpc)
    m_22 = sim.params["quantum"]["m_22"]
    m = sim.axion_mass
    hbar = jaxion.constants["reduced_planck_constant"]
    G = jaxion.constants["gravitational_constant"]
    core_radius = 2.2e8 * m_22**-2 / soliton_mass  # kpc
    box_size = sim.params["domain"]["box_size"]
    half_box = 0.5 * box_size

    X, Y, Z = sim.grid
    # The x and z terms are shared by both distance fields below.
    x_term = (X - half_box) ** 2
    z_term = (Z - half_box) ** 2

    # --- soliton, displaced along +y --------------------------------------
    dist_soliton = jnp.sqrt(x_term + (Y - half_box - offset) ** 2 + z_term)
    sim.state["psi"] = (
        jnp.sqrt(
            1.9e7
            * m_22**-2
            * core_radius**-4
            / (1.0 + 0.091 * (dist_soliton / core_radius) ** 2) ** 8
        )
        + 0.0j
    )
    # Phase gradient along x -> bulk motion along x (tangential to the offset).
    sim.state["psi"] *= jnp.exp(1.0j * k_orbit * X)

    # --- host halo --------------------------------------------------------
    halo_mass = 0.25 * box_size * k_orbit**2 * hbar**2 / (G * m**2)
    if jax.process_index() == 0:
        print(f"M_halo: {halo_mass:.2e} M_sun")
    assert halo_mass > soliton_mass * 2.0  # the host must dominate the soliton

    dist_centre = jnp.sqrt(x_term + (Y - half_box) ** 2 + z_term)
    V_halo = -G * halo_mass / (dist_centre + 0.5 * sim.dx)  # softened by half a cell

    # --- imaginary sponge near the box edge (Schwabe et al. 2016) ---------
    V_0 = G * halo_mass / (box_size / 4.0)
    sponge_outer = 0.5 * box_size
    sponge_start = (7 / 8) * sponge_outer
    sponge_mid = 0.5 * (sponge_outer + sponge_start)
    sponge_width = sponge_outer - sponge_start
    V_sponge = (
        -0.5j
        * V_0
        * (
            2
            + jnp.tanh((dist_centre - sponge_mid) / sponge_width)
            - jnp.tanh(sponge_mid / sponge_width)
        )
        * jnp.heaviside(dist_centre - sponge_start, 0.0)
    )
    sim.state["V_ext"] = V_halo + V_sponge

    return sim


def _str_to_bool(value):
    """Convert a command-line string such as ``"False"`` or ``"yes"`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"`` and ``"0"``, as ``True``. This converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse CLI options, optionally build a device sharding, and run the example.

    Options:
        --res N          resolution multiplier (default 1).
        --save [BOOL]    save simulation output; accepts true/false, yes/no,
                         1/0, on/off. Defaults to true; a bare ``--save``
                         also means true.
        --distributed    shard the simulation over all available devices.
        --emulate        with --distributed, emulate 8 devices on the CPU.

    Returns:
        The simulation object after ``run()`` has completed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    parser.add_argument(
        "--save",
        type=_str_to_bool,  # NOT type=bool: bool("False") is True
        nargs="?",
        const=True,
        default=True,
        help="Save simulation output (true/false)",
    )
    parser.add_argument(
        "--distributed", action="store_true", help="run in distributed mode"
    )
    parser.add_argument(
        "--emulate", action="store_true", help="emulate distributed mode on CPU"
    )
    args = parser.parse_args()

    if args.distributed:
        if args.emulate:
            # JAX reads these environment variables when its backend
            # initializes, so they are set before jax.process_index() below.
            flags = os.environ.get("XLA_FLAGS", "")
            flags += " --xla_force_host_platform_device_count=8"  # number of emulated CPU devices
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            os.environ["XLA_FLAGS"] = flags
            if jax.process_index() == 0:
                print("Using emulated distributed CPU mode")
        else:
            jax.distributed.initialize()
            if jax.process_index() == 0:
                print("Using distributed GPU mode")
        # 1 x n_devices mesh: arrays are split along the "y" mesh axis.
        n_devices = jax.device_count()
        devices = mesh_utils.create_device_mesh((1, n_devices))
        mesh = Mesh(devices, axis_names=("x", "y"))
        sharding = NamedSharding(mesh, PartitionSpec("x", "y"))
    else:
        sharding = None

    sim = set_up_simulation(args.res, sharding, args.save)
    sim.run()
    # Computed on every process (the array may be sharded); printed once.
    mean_psi = jnp.mean(jnp.abs(sim.state["psi"]))
    if jax.process_index() == 0:
        print("mean |psi| =", mean_psi)

    return sim


if __name__ == "__main__":
    main()

timing#

timing

See on GitHub: examples/timing#

README:

# Timing

Timing test for FDM+hydro based on heating_gas/

Philip Mocz (2025)

Usage:

```console
python timing.py --res <resolution_multiplier>
```

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="timing.png" alt="timing" width="45%"/>
</div>

Script:

import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Timing test based on heating_gas/

Philip Mocz (2025)

Usage:
  python timing.py --res <resolution_multiplier>
"""


def set_up_simulation(resolution_multiplier):
    """Build the FDM + gas timing benchmark (based on the heating_gas example).

    Dark matter starts as a random-phase wavefunction with a Gaussian momentum
    spectrum of dispersion ``sigma``. The gas starts uniform and at rest.
    Checkpoint output is disabled.

    Args:
        resolution_multiplier: forwarded to
            ``params["domain"]["resolution_multiplier"]``. The end time is
            scaled by ``1 / resolution_multiplier**2``.

    Returns:
        The configured ``jaxion.Simulation``.

    Raises:
        AssertionError: if the box holds no more than one de Broglie
            wavelength, or is not smaller than half the gas Jeans length.
    """
    overrides = {
        "physics": {"hydro": True},
        "domain": {
            "resolution_multiplier": resolution_multiplier,
            "box_size": 4.0,  # kpc
        },
        "time": {"end": 0.01 * (1.0 / resolution_multiplier**2)},
        "output": {"num_checkpoints": 100, "save": False},
        "hydro": {"sound_speed": 20.0},  # km/s
    }
    sim = jaxion.Simulation(overrides)

    # Mean total (dark matter + gas) density, in Msun / kpc^3, and its split.
    rho_bar = 1.0e7
    gas_fraction = 0.15
    rho_gas = gas_fraction * rho_bar
    c_sound = sim.sound_speed  # km/s
    dm_fraction = 1.0 - gas_fraction
    sigma = 40.0  # dark matter velocity dispersion

    hbar = jaxion.constants["reduced_planck_constant"]
    G = jaxion.constants["gravitational_constant"]
    m = sim.axion_mass
    m_per_hbar = m / hbar

    kx, ky, kz = sim.kgrid
    k_sq = kx**2 + ky**2 + kz**2
    nx = sim.resolution
    box_size = sim.box_size

    # Sanity checks: more than one de Broglie wavelength must fit in the box,
    # and the box must be smaller than half the gas Jeans length.
    de_broglie_wavelength = hbar / (m * sigma)
    assert box_size / de_broglie_wavelength > 1
    jeans_length = c_sound * jnp.sqrt(jnp.pi / (G * rho_gas))
    assert box_size / jeans_length < 0.5

    # Dark matter, built in Fourier space following Eq (27) of
    # https://arxiv.org/abs/1801.03507
    np.random.seed(17)
    # Hand out random phases in order of increasing |k| (stable sort), so the
    # lowest-k modes always receive the first draws from the RNG.
    order = np.argsort(k_sq.flatten(), stable=True)
    phases = np.zeros((nx**3,), dtype=complex)
    phases[order] = np.exp(1.0j * 2.0 * np.pi * np.random.rand(nx**3))
    psi = jnp.array(phases.reshape(k_sq.shape))
    psi *= np.sqrt(np.exp(-k_sq / (2.0 * sigma**2 * m_per_hbar**2)))
    psi = jd.fft.pifft3d(psi)
    # Rescale so that mean |psi|^2 equals the dark-matter share of rho_bar.
    psi *= jnp.sqrt(dm_fraction * rho_bar / jnp.mean(jnp.abs(psi) ** 2))

    # Gas: uniform density, zero velocity.
    shape = (nx, nx, nx)
    sim.state["psi"] = psi
    sim.state["rho"] = jnp.ones(shape) * rho_gas
    sim.state["vx"] = jnp.zeros(shape)
    sim.state["vy"] = jnp.zeros(shape)
    sim.state["vz"] = jnp.zeros(shape)

    return sim


def main():
    """Parse ``--res`` and run the FDM + hydro timing benchmark.

    Returns:
        The simulation object after ``run()`` has completed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    res = cli.parse_args().res

    benchmark = set_up_simulation(res)
    benchmark.run()
    return benchmark


if __name__ == "__main__":
    main()

two_field#

two_field
two_field2

See on GitHub: examples/two_field#

README:

# Two-Field

Demo for adding custom fields: two-field fuzzy dark matter.
Merge solitons in two-field FDM to form a halo.

Philip Mocz (2025)

Usage:

```console
python two_field.py --res <resolution_multiplier>
```

Takes around 6 seconds (res=1) to run on my MacBook (CPU).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="rho1_070.png" alt="rho1_070" width="45%"/>
  <img src="rho2_070.png" alt="rho2_070" width="45%"/>
</div>


## References

[Luu, H.N.; Mocz, P.; Vogelsberger, M.; Nested solitons in two-field fuzzy dark matter. MNRAS (2024)](https://ui.adsabs.harvard.edu/abs/2024MNRAS.527.4162L)

Script:

import jax
import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion

# switch on for double precision
# jax.config.update("jax_enable_x64", True)

"""
Merge solitons to form an idealized Two-Field Fuzzy Dark Matter halo
Demonstrates adding custom fields

Philip Mocz (2025)

Usage:
  python two_field.py --res <resolution_multiplier>
"""


def set_up_simulation(resolution_multiplier):
    """Set up a two-field FDM soliton merger evolved through custom callbacks.

    jaxion's built-in quantum solver is disabled (``physics.quantum = False``).
    Instead, two wavefunctions ``psi1`` / ``psi2`` with ``m_22 = 1`` and
    ``m_22 = 2`` are stored in ``sim.state``. They are evolved through the
    ``custom_density`` / ``custom_kick`` / ``custom_drift`` hooks and plotted
    by ``custom_plot``.

    Args:
        resolution_multiplier: forwarded to
            ``params["domain"]["resolution_multiplier"]`` and used to name the
            checkpoint directory.

    Returns:
        The configured ``jaxion.Simulation`` with the custom hooks attached.
    """
    # Parameters added/changed from default values
    params = {
        "physics": {
            "quantum": False,  # use custom fields instead for psi1, psi2
            "gravity": True,
        },
        "domain": {
            "resolution_multiplier": resolution_multiplier,
        },
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
        },
    }

    # Initialize the simulation
    sim = jaxion.Simulation(params)

    # Axion masses of the two fields, and m/hbar for the kick/drift operators
    m_22_1 = 1.0
    m_22_2 = 2.0
    hbar = jaxion.constants["reduced_planck_constant"]
    eV = jaxion.constants["electron_volt"]
    c = jaxion.constants["speed_of_light"]
    m1 = m_22_1 * 1.0e-22 * eV / c**2
    m2 = m_22_2 * 1.0e-22 * eV / c**2
    m1_per_hbar = m1 / hbar
    m2_per_hbar = m2 / hbar

    box_size = sim.params["domain"]["box_size"]
    nx = sim.resolution
    xx, yy, zz = sim.grid

    def rho_soliton(r, r_soliton, m_22):
        # Soliton core density profile at distance r from its centre.
        return (
            1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (r / r_soliton) ** 2) ** 8
        )

    # Set initial conditions (randomly placed solitons, same positions for both fields)
    np.random.seed(17)
    n_solitons = 8

    # Complex accumulators so that sqrt() below yields complex wavefunctions.
    rho1 = np.zeros((nx, nx, nx), dtype=complex)
    rho2 = np.zeros((nx, nx, nx), dtype=complex)
    for _ in range(n_solitons):
        r_soliton = (0.05 + 0.03 * np.random.rand()) * box_size
        x_soliton = (0.25 + 0.5 * np.random.rand()) * box_size
        y_soliton = (0.25 + 0.5 * np.random.rand()) * box_size
        z_soliton = (0.25 + 0.5 * np.random.rand()) * box_size

        r = jnp.sqrt(
            (xx - x_soliton) ** 2 + (yy - y_soliton) ** 2 + (zz - z_soliton) ** 2
        )
        rho1 += rho_soliton(r, r_soliton, m_22_1)
        rho2 += rho_soliton(r, r_soliton, m_22_2)

    sim.state["psi1"] = jnp.array(jnp.sqrt(rho1))
    sim.state["psi2"] = jnp.array(jnp.sqrt(rho2))

    # Define custom functions
    def custom_density(state):
        # Total density sourcing gravity: |psi1|^2 + |psi2|^2
        return jnp.abs(state["psi1"]) ** 2 + jnp.abs(state["psi2"]) ** 2

    def custom_kick(state, V, dt):
        # Potential step: each field gets the phase exp(-i (m/hbar) V dt)
        state["psi1"] = jnp.exp(-1j * m1_per_hbar * dt * V) * state["psi1"]
        state["psi2"] = jnp.exp(-1j * m2_per_hbar * dt * V) * state["psi2"]

        return state

    def custom_drift(state, k_sq, dt):
        # Kinetic step in Fourier space: exp(-i dt k^2 / (2 m/hbar)) per field
        psi1_hat = jd.fft.pfft3d(state["psi1"])
        psi1_hat = jnp.exp(dt * (-1.0j * k_sq / m1_per_hbar / 2.0)) * psi1_hat
        state["psi1"] = jd.fft.pifft3d(psi1_hat)

        psi2_hat = jd.fft.pfft3d(state["psi2"])
        psi2_hat = jnp.exp(dt * (-1.0j * k_sq / m2_per_hbar / 2.0)) * psi2_hat
        state["psi2"] = jd.fft.pifft3d(psi2_hat)

        return state

    def custom_plot(state, checkpoint_dir, i, params):
        # Save a log10 projected-density image per field: rho1_XXX.png, rho2_XXX.png
        import os

        import matplotlib.pyplot as plt

        # Import explicitly: `jax.experimental.multihost_utils` is a submodule
        # and is not guaranteed to be reachable via attribute access on `jax`.
        from jax.experimental import multihost_utils

        dynamic_range = params["output"]["plot_dynamic_range"]
        n_cells = state["psi1"].shape[0]

        # The reductions and the all-gather are collective over (possibly
        # distributed) data, so every process computes them.
        panels = []
        for field, label in (("psi1", "rho1"), ("psi2", "rho2")):
            rho = jnp.abs(state[field]) ** 2
            rho_bar = jnp.mean(rho)
            rho_proj = jnp.log10(
                multihost_utils.process_allgather(jnp.mean(rho, axis=2), tiled=True)
            ).T
            panels.append((label, rho_bar, rho_proj))

        # create plots on process 0 only
        if jax.process_index() == 0:
            for label, rho_bar, rho_proj in panels:
                plt.clf()

                ax = plt.gca()
                ax.imshow(
                    rho_proj,
                    cmap="inferno",
                    origin="lower",
                    vmin=jnp.log10(rho_bar / dynamic_range),
                    vmax=jnp.log10(rho_bar * dynamic_range),
                    extent=[0, n_cells, 0, n_cells],
                )
                ax.set_aspect("equal")
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)

                plt.savefig(
                    os.path.join(checkpoint_dir, f"{label}_{i:03d}.png"),
                    bbox_inches="tight",
                    pad_inches=0,
                )
                plt.close()

    sim.custom_density = custom_density
    sim.custom_plot = custom_plot
    sim.custom_kick = custom_kick
    sim.custom_drift = custom_drift

    # Evaluate on every process (it may reduce over distributed data); print once.
    rho_bar = sim.rho_bar
    if jax.process_index() == 0:
        print("rho_bar:", rho_bar)

    return sim


def main():
    """Parse ``--res``, set up the two-field soliton merger and run it.

    Returns:
        The simulation object after ``run()`` has completed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    multiplier = cli.parse_args().res

    simulation = set_up_simulation(multiplier)
    simulation.run()
    return simulation


if __name__ == "__main__":
    main()