Examples#
The examples/ directory contains a collection of astrophysics simulations demonstrating various applications of the Jaxion library.
Gallery#
cosmological_box#
See on GitHub: examples/cosmological_box#
README:
# Cosmological Box
Fuzzy Dark Matter cosmological box
Philip Mocz (2025)
Usage:
```console
python cosmological_box.py
```
Takes around 870 seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm000.png" alt="dm000" width="32%"/>
<img src="dm060.png" alt="dm060" width="32%"/>
<img src="dm070.png" alt="dm070" width="32%"/>
<img src="dm080.png" alt="dm080" width="32%"/>
<img src="dm090.png" alt="dm090" width="32%"/>
<img src="dm100.png" alt="dm100" width="32%"/>
</div>
## References
[Mocz, P. et al.; Galaxy formation with BECDM - II. Cosmic filaments and first galaxies. MNRAS (2020)](https://ui.adsabs.harvard.edu/abs/2020MNRAS.494.2027M)
Script:
import jax.numpy as jnp
import jaxion
import h5py
# switch on for double precision (uncomment the line below and add `import jax`)
# jax.config.update("jax_enable_x64", True)
"""
Fuzzy Dark Matter cosmological box
Philip Mocz (2025)
Usage:
python cosmological_box.py
"""
def set_up_simulation(ic_file="fdm_1mpc_256_m2.5e-22_z127_ic.hdf5"):
    """Create the FDM cosmological-box simulation and load its initial conditions.

    Args:
        ic_file: HDF5 file holding the precomputed initial wavefunction in the
            datasets ``psiRe`` and ``psiIm``. Defaults to the file this example
            uses (1 Mpc box, 256^3 cells, m = 2.5e-22 eV, z = 127).

    Returns:
        A ``jaxion.Simulation`` whose ``state["psi"]`` is read from ``ic_file``.

    Raises:
        ValueError: If the stored wavefunction does not have the simulation's
            ``(resolution, resolution, resolution)`` shape.
    """
    # Parameters added/changed from default values
    params = {
        "physics": {
            "cosmology": True,
        },
        "domain": {
            "box_size": 1000.0,  # in h^-1 kpc
            "resolution_base": 256,
        },
        "time": {
            "start": 127.0,  # start at z=127
            "end": 7.0,  # end at z=7
        },
        "output": {
            "path": "./checkpoints/",
        },
        "quantum": {
            "m_22": 2.5,  # m = 2.5e-22 eV
        },
        "cosmology": {
            "omega_matter": 0.3,
            "omega_lambda": 0.7,
        },
    }

    # Initialize the simulation
    sim = jaxion.Simulation(params)

    # Read the precomputed cosmological ICs. |psi|^2 in the file is a (volume)
    # density, presumably in units of 1e10 Msun/kpc^3 (TODO confirm h-scaling);
    # multiplying psi by sqrt(1e10) = 1e5 converts |psi|^2 to Msun/kpc^3.
    with h5py.File(ic_file, "r") as f:
        psi_re = jnp.array(f["psiRe"])
        psi_im = jnp.array(f["psiIm"])
    psi = (psi_re + 1.0j * psi_im) * 1e5

    # Fail early with a clear message if the IC file was made for another grid.
    nx = sim.resolution
    if psi.shape != (nx, nx, nx):
        raise ValueError(
            f"initial conditions in {ic_file!r} have shape {psi.shape}, "
            f"expected {(nx, nx, nx)}"
        )

    sim.state["psi"] = psi
    return sim
def main():
    """Build the cosmological box, evolve it from z=127 to z=7, and return it."""
    box_sim = set_up_simulation()
    box_sim.run()
    return box_sim


if __name__ == "__main__":
    main()
dynamical_friction#
See on GitHub: examples/dynamical_friction#
README:
# Dynamical Friction
Dynamical friction of Fuzzy Dark Matter soliton in an external potential
Philip Mocz (2025)
Usage:
```console
python dynamical_friction.py --res <resolution_multiplier>
```
Takes around 3 (res=1) seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm100.png" alt="dm100" width="45%"/>
</div>
## References
[Lancaster, L.; Giovanetti, C.; Mocz, P.; Kahn, Y.; Lisanti, M.; Spergel, D.N.; Dynamical friction in a Fuzzy Dark Matter universe. JCAP (2020)](https://ui.adsabs.harvard.edu/abs/2020JCAP...01..001L)
Script:
import jax.numpy as jnp
import argparse
import jaxion
# switch on for double precision (uncomment the line below and add `import jax`)
# jax.config.update("jax_enable_x64", True)
"""
Dynamical friction of Fuzzy Dark Matter soliton in an external potential
Philip Mocz (2025)
Usage:
python dynamical_friction.py --res <resolution_multiplier>
"""
def set_up_simulation(resolution_multiplier):
    """Build a uniform FDM background streaming past a fixed point-mass halo.

    The wavefunction has constant amplitude and a plane-wave phase along x, so
    the dark matter moves relative to the softened ``-G M / r`` potential of a
    1e9 Msun host halo placed at the box centre.

    Args:
        resolution_multiplier: Factor on the default grid resolution; also
            part of the checkpoint directory name.

    Returns:
        The configured ``jaxion.Simulation``.
    """
    checkpoint_dir = f"./checkpoints{resolution_multiplier}/"
    # only the settings that differ from jaxion's defaults
    params = {
        "physics": {"external_potential": True},
        "domain": {"resolution_multiplier": resolution_multiplier},
        "time": {"end": 0.1},
        "output": {"path": checkpoint_dir, "plot_dynamic_range": 5.0},
    }
    sim = jaxion.Simulation(params)

    grav = jaxion.constants["gravitational_constant"]
    box_len = sim.params["domain"]["box_size"]
    x, y, z = sim.grid

    # dark matter: uniform density (M_sun/kpc^3) with n_wind phase windings in x
    rho_dm = 1.0e6
    n_wind = 5.0
    sim.state["psi"] = jnp.sqrt(rho_dm) * jnp.exp(
        1.0j * (2.0 * jnp.pi / box_len) * n_wind * x
    )

    # host halo: point-mass potential at the centre, softened by half a cell
    M_halo = 1e9  # M_sun
    print(f"M_halo: {M_halo:.2e} M_sun")
    dist = jnp.sqrt(sum((coord - 0.5 * box_len) ** 2 for coord in (x, y, z)))
    sim.state["V_ext"] = -grav * M_halo / (dist + 0.5 * sim.dx)

    return sim
def main():
    """Run the dynamical-friction example at ``--res`` and print the mean |psi|."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    cli_args = arg_parser.parse_args()

    sim = set_up_simulation(cli_args.res)
    sim.run()

    mean_abs_psi = jnp.mean(jnp.abs(sim.state["psi"]))
    print("mean |psi| =", mean_abs_psi)
    return sim


if __name__ == "__main__":
    main()
heating_gas#
See on GitHub: examples/heating_gas#
README:
# Heating Gas
Heating Gas due to Fuzzy Dark Matter fluctuations
Philip Mocz (2025)
Usage:
```console
python heating_gas.py --res <resolution_multiplier>
```
Takes around 37 (res=1) seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm100.png" alt="dm100" width="45%"/>
<img src="gas100.png" alt="gas100" width="45%"/>
</div>
## Analysis
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="power_spectrum.png" alt="power_spectrum" width="45%"/>
</div>
Script:
import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion
# switch on for double precision (uncomment the line below and add `import jax`)
# jax.config.update("jax_enable_x64", True)
"""
Heating Gas due to Fuzzy Dark Matter fluctuations
Philip Mocz (2025)
Usage:
python heating_gas.py --res <resolution_multiplier>
"""
def set_up_simulation(resolution_multiplier):
# Parameters added/changed from default values
params = {
"physics": {
"hydro": True,
},
"domain": {
"resolution_multiplier": resolution_multiplier,
"box_size": 4.0, # kpc
},
"time": {
"end": 0.4,
},
"output": {
"path": f"./checkpoints{resolution_multiplier}/",
"plot_dynamic_range": 2.0,
},
"quantum": {
"m_22": 1.0, # axion mass in units of 10^-22 eV
},
"hydro": {
"sound_speed": 20.0, # km/s
},
}
# Initialize the simulation
sim = jaxion.Simulation(params)
# average density of all matter (dm+gas) in the simulation (in units of Msun / kpc^3)
rho_bar = 1.0e7
# gas
frac_gas = 0.15 # fraction of total mass in gas
rho_gas = frac_gas * rho_bar # average density of gas
c_sound = sim.sound_speed # sound speed (km/s)
# dark matter
frac_dm = 1.0 - frac_gas # fraction of total mass in dark matter
sigma = 40.0 # velocity dispersion of dm
m = sim.axion_mass
hbar = jaxion.constants["reduced_planck_constant"]
m_per_hbar = m / hbar
kx, ky, kz = sim.kgrid
k_sq = kx**2 + ky**2 + kz**2
nx = sim.resolution
box_size = sim.box_size
G = jaxion.constants["gravitational_constant"]
# check that de broglie wavelength fits into box
de_broglie_wavelength = hbar / (m * sigma)
n_wavelengths = box_size / de_broglie_wavelength
assert n_wavelengths > 1
# check the Jeans length
jeans_length = c_sound * jnp.sqrt(jnp.pi / (G * rho_gas))
n_jeans = box_size / jeans_length
assert n_jeans < 0.5
# dark matter
# construct in fourier space according to Eq (27) of our paper [https://arxiv.org/abs/1801.03507]
np.random.seed(17)
# initialize random phases
# (use order to set lowest k modes first)
sid = np.argsort(k_sq.flatten(), stable=True)
psi = np.zeros((nx**3,), dtype=complex)
psi[sid] = np.exp(1.0j * 2.0 * np.pi * np.random.rand(nx**3))
psi = psi.reshape(k_sq.shape)
psi = jnp.array(psi)
psi *= np.sqrt(np.exp(-k_sq / (2.0 * sigma**2 * m_per_hbar**2)))
psi = jd.fft.pifft3d(psi)
# re-normalize it
psi *= jnp.sqrt(frac_dm * rho_bar / jnp.mean(jnp.abs(psi) ** 2))
# gas is initially uniform
rho = jnp.ones((nx, nx, nx)) * rho_gas
vx = jnp.zeros((nx, nx, nx))
vy = jnp.zeros((nx, nx, nx))
vz = jnp.zeros((nx, nx, nx))
sim.state["psi"] = psi
sim.state["rho"] = rho
sim.state["vx"] = vx
sim.state["vy"] = vy
sim.state["vz"] = vz
return sim
def main():
    """Run the gas-heating example and print mean |psi| and mean |v| per axis."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    res = cli.parse_args().res

    sim = set_up_simulation(res)
    sim.run()

    state = sim.state
    print("mean |psi| =", jnp.mean(jnp.abs(state["psi"])))
    for label, key in (
        ("mean |vx| =", "vx"),
        ("mean |vy| =", "vy"),
        ("mean |vz| =", "vz"),
    ):
        print(label, jnp.mean(jnp.abs(state[key])))
    return sim


if __name__ == "__main__":
    main()
heating_stars#
See on GitHub: examples/heating_stars#
README:
# Heating Stars
Heating Stars due to Fuzzy Dark Matter fluctuations
Philip Mocz (2025)
Usage:
```console
python heating_stars.py --res <resolution_multiplier>
```
Takes around 11 (res=1) seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm100.png" alt="dm100" width="45%"/>
</div>
## References
[Church, B.; Mocz, P.; Ostriker, J.P.; Heating of Milky Way disc Stars by Dark Matter Fluctuations in Cold Dark Matter and Fuzzy Dark Matter Paradigms. MNRAS (2019)](https://ui.adsabs.harvard.edu/abs/2019MNRAS.485.2861C)
Script:
import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion
# switch on for double precision (uncomment the line below and add `import jax`)
# jax.config.update("jax_enable_x64", True)
"""
Heating Stars due to Fuzzy Dark Matter fluctuations
Philip Mocz (2025)
Usage:
python heating_stars.py --res <resolution_multiplier>
"""
def set_up_simulation(resolution_multiplier):
# average density of all matter (dm+stars) in the simulation (in units of Msun / kpc^3)
rho_bar = 1.0e7
# stars
frac_stars = 0.15 # fraction of total mass in stars
rho_stars = frac_stars * rho_bar # average density of stars
sigma_stars = 20.0 # velocity dispersion (1d) of stars (km/s)
box_size = 4.0 # kpc
n_stars = 400
m_stars = rho_stars * box_size**3 / n_stars # Msun
# Parameters added/changed from default values
params = {
"physics": {
"particles": True,
},
"domain": {
"resolution_multiplier": resolution_multiplier,
"box_size": box_size,
},
"time": {
"end": 0.4,
},
"output": {
"path": f"./checkpoints{resolution_multiplier}/",
"plot_dynamic_range": 2.0,
},
"quantum": {
"m_22": 1.0, # axion mass in units of 10^-22 eV
},
"particles": {"num_particles": n_stars, "particle_mass": m_stars},
}
# Initialize the simulation
sim = jaxion.Simulation(params)
# dark matter
frac_dm = 1.0 - frac_stars # fraction of total mass in dark matter
sigma = 40.0 # velocity dispersion of dm (km/s)
m = sim.axion_mass
hbar = jaxion.constants["reduced_planck_constant"]
m_per_hbar = m / hbar
kx, ky, kz = sim.kgrid
k_sq = kx**2 + ky**2 + kz**2
nx = sim.resolution
G = jaxion.constants["gravitational_constant"]
# check that de broglie wavelength fits into box
de_broglie_wavelength = hbar / (m * sigma)
n_wavelengths = box_size / de_broglie_wavelength
assert n_wavelengths > 1
# check the Jeans length
jeans_length = sigma_stars * jnp.sqrt(jnp.pi / (G * rho_stars))
n_jeans = box_size / jeans_length
assert n_jeans < 0.5
# dark matter
# construct in fourier space according to Eq (27) of our paper [https://arxiv.org/abs/1801.03507]
np.random.seed(17)
# initialize random phases
# initialize random phases
# (use order to set lowest k modes first)
sid = np.argsort(k_sq.flatten(), stable=True)
psi = np.zeros((nx**3,), dtype=complex)
psi[sid] = np.exp(1.0j * 2.0 * np.pi * np.random.rand(nx**3))
psi = psi.reshape(k_sq.shape)
psi = jnp.array(psi)
psi *= np.sqrt(np.exp(-k_sq / (2.0 * sigma**2 * m_per_hbar**2)))
psi = jd.fft.pifft3d(psi)
# re-normalize it
psi *= jnp.sqrt(frac_dm * rho_bar / jnp.mean(jnp.abs(psi) ** 2))
# stars are initially uniform
np.random.seed(17)
pos = np.random.rand(n_stars, 3) * box_size
vel = np.random.randn(n_stars, 3) * sigma_stars
pos = jnp.array(pos)
vel = jnp.array(vel)
sim.state["psi"] = psi
sim.state["pos"] = pos
sim.state["vel"] = vel
return sim
def main():
    """Run the star-heating example and print summary statistics.

    Prints mean |psi|, the mean absolute value of each star velocity component,
    and the total dispersion sqrt(std(vx)^2 + std(vy)^2 + std(vz)^2).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    res = cli.parse_args().res

    sim = set_up_simulation(res)
    sim.run()

    psi = sim.state["psi"]
    vel = sim.state["vel"]
    print("mean |psi| =", jnp.mean(jnp.abs(psi)))
    for label, axis in (("mean |vx| =", 0), ("mean |vy| =", 1), ("mean |vz| =", 2)):
        print(label, jnp.mean(jnp.abs(vel[:, axis])))
    dispersion = jnp.sqrt(sum(jnp.std(vel[:, axis]) ** 2 for axis in range(3)))
    print("std vel =", dispersion)
    return sim


if __name__ == "__main__":
    main()
kinetic_condensation#
See on GitHub: examples/kinetic_condensation#
README:
# Kinetic Condensation
Soliton formation in the kinetic regime
Philip Mocz (2025)
Usage:
```console
python kinetic_condensation.py --res <resolution_multiplier>
```
Takes around 34 (res=1) seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm000.png" alt="dm000" width="32%"/>
<img src="dm020.png" alt="dm020" width="32%"/>
<img src="dm040.png" alt="dm040" width="32%"/>
<img src="dm060.png" alt="dm060" width="32%"/>
<img src="dm080.png" alt="dm080" width="32%"/>
<img src="dm100.png" alt="dm100" width="32%"/>
</div>
## References
[Levkov, D.G.; Panin, A.G.; Tkachev, I.I.; Gravitational Bose-Einstein condensation in the kinetic regime PRL (2018)](https://ui.adsabs.harvard.edu/abs/2018PhRvL.121o1301L)
Script:
import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion
# switch on for double precision (uncomment the line below and add `import jax`)
# jax.config.update("jax_enable_x64", True)
"""
Soliton formation in the kinetic regime
Philip Mocz (2025)
Usage:
python kinetic_condensation.py --res <resolution_multiplier>
"""
def set_up_simulation(resolution_multiplier):
# Parameters added/changed from default values
params = {
"domain": {
"resolution_multiplier": resolution_multiplier,
"box_size": 6.0, # kpc
},
"time": {
"end": 10.0,
},
"output": {
"path": f"./checkpoints{resolution_multiplier}/",
"plot_dynamic_range": 4.0,
},
}
# Initialize the simulation
sim = jaxion.Simulation(params)
# average density of dark matter in the simulation (in units of Msun / kpc^3)
rho_bar = 1.0e8
# dark matter
sigma = 100.0 # velocity dispersion of dm
m = sim.axion_mass
hbar = jaxion.constants["reduced_planck_constant"]
m_per_hbar = m / hbar
kx, ky, kz = sim.kgrid
k_sq = kx**2 + ky**2 + kz**2
nx = sim.resolution
box_size = sim.box_size
G = jaxion.constants["gravitational_constant"]
# check that de broglie wavelength fits into box
de_broglie_wavelength = hbar / (m * sigma)
n_wavelengths = box_size / de_broglie_wavelength
assert n_wavelengths > 1
# check timescales
lambda_fac = np.log(m * sigma * box_size / hbar)
b = 1.0
n = rho_bar / m
# eqn 4 of Levkov
tau_gr = (
b
* np.sqrt(2.0)
/ (12.0 * np.pi**3)
* m
* sigma**6
/ (G**2 * n**2 * lambda_fac)
/ hbar**3
)
# we want to see condensation happen in the simulation time
assert 10.0 * tau_gr < sim.params["time"]["end"]
kin1 = m * sigma * box_size / hbar # box crossing time >> 1
kin2 = m * sigma**2 * tau_gr / hbar # condensation time >> 1 eqn 1 of Levkov
assert kin1 > 1
assert kin2 > 1
# check Jeans lengths
jeans_length = sigma * np.sqrt(np.pi / (G * rho_bar)) # kinetic Jeans
# quantum Jeans, eqn 40 of Levkov
jeans_length_Q = np.pi / ((np.pi * G * rho_bar) ** 0.25 * m**0.5 / np.sqrt(hbar))
assert jeans_length > box_size
assert jeans_length_Q < box_size
# dark matter
# construct in fourier space according to Eq (27) of our paper [https://arxiv.org/abs/1801.03507]
np.random.seed(17)
# initialize random phases
# (use order to set lowest k modes first)
sid = np.argsort(k_sq.flatten(), stable=True)
psi = np.zeros((nx**3,), dtype=complex)
psi[sid] = np.exp(1.0j * 2.0 * np.pi * np.random.rand(nx**3))
psi = psi.reshape(k_sq.shape)
psi = jnp.array(psi)
psi *= np.sqrt(np.exp(-k_sq / (2.0 * sigma**2 * m_per_hbar**2)))
psi = jd.fft.pifft3d(psi)
# re-normalize it
psi *= jnp.sqrt(rho_bar / jnp.mean(jnp.abs(psi) ** 2))
sim.state["psi"] = psi
return sim
def main():
    """Parse ``--res`` and run the kinetic-condensation example."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    sim = set_up_simulation(cli.parse_args().res)
    sim.run()
    return sim


if __name__ == "__main__":
    main()
logo_inverse_problem#
See on GitHub: examples/logo_inverse_problem#
README:
# Logo Inverse Problem
Inverse problem: Find initial conditions for velocity (dm+stars) to achieve target density at t=1
Philip Mocz (2025)
Usage:
```console
python logo_inverse_problem.py
```
Takes around 200 seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm100.png" alt="dm100" width="45%"/>
</div>
Script:
import jax
import jax.numpy as jnp
import jaxion
import chex
from typing import NamedTuple
import optax
import time
import matplotlib.image as img
# switch on for double precision
# jax.config.update("jax_enable_x64", True)
"""
Inverse problem: Find initial conditions for velocity (dm+stars) to achieve target density at t=1
Philip Mocz (2025)
Usage:
python logo_inverse_problem.py
"""
def set_up_simulation(save=False):
    """Set up the logo inverse-problem simulation with unoptimized initial conditions.

    Dark matter starts uniform (1e5 Msun/kpc^3, zero phase). The stars start at
    rest on the cell-centred points of a square ``side x side`` grid in the
    plane ``z = box_size / 2``.

    Args:
        save: Value passed to ``params["output"]["save"]``.

    Returns:
        The configured ``jaxion.Simulation``.

    Raises:
        ValueError: If the number of particles is not a perfect square (the
            star grid would otherwise not match the velocity array's length).
    """
    import math

    # Parameters added/changed from default values
    params = {
        "physics": {
            "particles": True,
        },
        "domain": {
            "resolution_base": 32,
            "box_size": 20.0,  # kpc
        },
        "time": {
            "end": 1.0,
        },
        "output": {
            "path": "./checkpoints/",
            "num_checkpoints": 100,
            "save": save,
            "plot_dynamic_range": 10.0,
        },
        "particles": {
            "num_particles": 64,
            "particle_mass": 1.0e7,  # Msun
        },
    }

    # Exact integer square root; float sqrt + int() would silently truncate.
    num_stars = params["particles"]["num_particles"]
    side = math.isqrt(num_stars)
    if side * side != num_stars:
        raise ValueError(f"num_particles must be a perfect square, got {num_stars}")

    # Initialize the simulation
    sim = jaxion.Simulation(params)

    # dark matter: uniform density (Msun / kpc^3), zero phase
    rho_dm = 1.0e5
    box_size = sim.box_size
    nx = sim.resolution
    psi = jnp.sqrt(rho_dm) * jnp.ones((nx, nx, nx)) + 0j

    # stars: cell-centred points of a side x side grid in the x-y plane at z = box_size/2
    lin = jnp.linspace(0, box_size, side, endpoint=False) + box_size / (2 * side)
    zlin = jnp.array([box_size / 2.0])
    xx, yy, zz = jnp.meshgrid(lin, lin, zlin, indexing="ij")
    pos = jnp.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T
    vel = jnp.zeros((num_stars, 3))

    sim.state["psi"] = psi
    sim.state["pos"] = pos
    sim.state["vel"] = vel
    return sim
class InfoState(NamedTuple):
    """State carried by the ``print_info`` transformation."""

    # number of optimizer updates seen so far (starts at 0)
    iter_num: chex.Numeric
def print_info():
    """Return an optax transformation that logs progress and passes updates through.

    Its state is an ``InfoState`` iteration counter. On every update it prints
    the iteration number, the loss ``value`` and the norm of ``grad`` with
    ``jax.debug.print``, then returns ``updates`` unchanged.
    """

    def _init(params):
        # the counter does not depend on the parameters
        return InfoState(iter_num=0)

    def _update(updates, state, params, *, value, grad, **extra_args):
        grad_norm = optax.tree_utils.tree_norm(grad)
        jax.debug.print(
            "Iteration: {i}, Loss: {v:.2e}, |grad|: {e:.2e}",
            i=state.iter_num,
            v=value,
            e=grad_norm,
        )
        next_state = InfoState(iter_num=state.iter_num + 1)
        return updates, next_state

    return optax.GradientTransformationExtraArgs(_init, _update)
def run_opt(init_params, fun, opt, max_iter, tol):
    """Minimize ``fun`` from ``init_params`` with the optax optimizer ``opt``.

    Runs inside ``jax.lax.while_loop``. At least one step is always taken;
    after that the loop stops once ``max_iter`` steps have run or the norm of
    the gradient stored in the optimizer state falls below ``tol``.

    Returns:
        ``(final_params, final_state)``.
    """
    value_and_grad_fun = optax.value_and_grad_from_state(fun)

    def keep_going(carry):
        _, opt_state = carry
        count = optax.tree_utils.tree_get(opt_state, "count")
        grad_norm = optax.tree_utils.tree_norm(
            optax.tree_utils.tree_get(opt_state, "grad")
        )
        first_step = count == 0
        unconverged = (count < max_iter) & (grad_norm >= tol)
        return first_step | unconverged

    def one_step(carry):
        params, opt_state = carry
        value, grad = value_and_grad_fun(params, state=opt_state)
        updates, opt_state = opt.update(
            grad, opt_state, params, value=value, grad=grad, value_fn=fun
        )
        return optax.apply_updates(params, updates), opt_state

    start = (init_params, opt.init(init_params))
    final_params, final_state = jax.lax.while_loop(keep_going, one_step, start)
    return final_params, final_state
def solve_inverse_problem():
    """Optimize the initial star velocities so the evolved DM density draws the logo.

    Loads ``target.png``, takes its first color channel (or the only channel of
    a greyscale image), downsamples it by 2, and turns it into a normalized
    target for the z-projected dark matter density at the final time. The mean
    squared difference is minimized with L-BFGS by differentiating through
    ``sim.run()``.

    Returns:
        The optimized velocity array (same shape as ``sim.state["vel"]``).

    Raises:
        ValueError: If the downsampled target does not match the simulation's
            projected ``(resolution, resolution)`` grid.
    """
    sim = set_up_simulation()
    nx = sim.resolution

    # Load the target density field (first channel if the image has several)
    target_data = img.imread("target.png")
    if target_data.ndim == 3:
        target_data = target_data[:, :, 0]
    target_data = target_data[::2, ::2]  # downsample
    target = jnp.flipud(jnp.array(target_data, dtype=float)).T
    if target.shape != (nx, nx):
        raise ValueError(
            f"target image downsamples to {target.shape}, expected {(nx, nx)}"
        )
    # For pixel values in [0, 1] this maps dark pixels to high density
    # (1.8) and bright pixels to low density (0.2), then normalizes to mean 1.
    target = 1.0 - 1.6 * (target - 0.5)
    target /= jnp.mean(target)

    @jax.jit
    def loss_function(theta):
        sim = set_up_simulation()
        sim.state["vel"] = theta
        sim.run()
        projected_density = jnp.mean(jnp.abs(sim.state["psi"]) ** 2, axis=2)
        norm = jnp.mean(projected_density)
        projected_density /= norm
        error_norm = jnp.mean((projected_density - target) ** 2)
        return error_norm

    opt = optax.chain(print_info(), optax.lbfgs())
    loss_and_grad = jax.value_and_grad(loss_function)

    init_params = sim.state["vel"]
    # one forward+backward pass gives both numbers
    init_value, init_grad = loss_and_grad(init_params)
    print(
        f"Initial value: {init_value:.2e} "
        f"Initial gradient norm: {optax.tree_utils.tree_norm(init_grad):.2e}"
    )

    t0 = time.perf_counter()
    final_params, _ = run_opt(init_params, loss_function, opt, max_iter=100, tol=1e-5)
    # JAX dispatches asynchronously; wait for the result before stopping the clock
    final_params = jax.block_until_ready(final_params)
    print("Inverse-problem solve time (s): ", time.perf_counter() - t0)

    final_value, final_grad = loss_and_grad(final_params)
    print(
        f"Final value: {final_value:.2e}, "
        f"Final gradient norm: {optax.tree_utils.tree_norm(final_grad):.2e}"
    )
    return final_params
def main():
    """Solve the inverse problem, then rerun the optimized ICs with output saved."""
    best_vel = solve_inverse_problem()
    final_sim = set_up_simulation(save=True)
    final_sim.state["vel"] = best_vel
    final_sim.run()
    return final_sim


if __name__ == "__main__":
    main()
self_interaction_collapse#
See on GitHub: examples/self_interaction_collapse#
README:
# Self-Interaction Collapse
Collapse a soliton due to attractive self-interaction
Philip Mocz (2025)
Usage:
```console
python self_interaction_collapse.py --res <resolution_multiplier>
```
Takes around 5 (res=1) seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm000.png" alt="dm000" width="45%"/>
<img src="dm064.png" alt="dm064" width="45%"/>
</div>
## References
[Chavanis, P.H.; Phase transitions between dilute and dense axion stars. Phys. Rev. D. (2018)](https://ui.adsabs.harvard.edu/abs/2018PhRvD..98b3009C)
[Mocz, P. et al.; Cosmological Structure Formation and Soliton Phase Transition in Fuzzy Dark Matter with Axion Self-Interactions. MNRAS (2023)](https://ui.adsabs.harvard.edu/abs/2023MNRAS.521.2608M)
Script:
import jax
import jax.numpy as jnp
import argparse
import jaxion
# Switch on double precision: unlike most other examples, this script enables
# 64-bit floats in JAX, and does so at import time before any arrays are created.
jax.config.update("jax_enable_x64", True)
"""
Collapse a soliton due to attractive self-interaction
Philip Mocz (2025)
Usage:
python self_interaction_collapse.py --res <resolution_multiplier>
"""
def set_up_simulation(resolution_multiplier):
    """Set up a single soliton at the box centre with attractive self-interaction.

    Prints the scattering length ``a_s`` and the ratio ``M_soliton / M_crit``
    with ``M_crit = 1.012 hbar / sqrt(G m |a_s|)``, for comparison with the
    critical mass for collapse (Chavanis 2018, see README).

    Args:
        resolution_multiplier: Factor on the default grid resolution; also
            part of the checkpoint directory name.

    Returns:
        The configured ``jaxion.Simulation`` with ``state["psi"]`` set.
    """
    # Parameters added/changed from default values
    params = {
        "domain": {
            "resolution_multiplier": resolution_multiplier,
            "box_size": 4.0,
        },
        "time": {
            "end": 0.1,
        },
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "save": True,
        },
        "quantum": {
            "m_22": 1.0,  # axion mass in 10^-22 eV
            # decay constant in 10^15 GeV; the negative sign presumably selects
            # attractive self-interaction (TODO confirm in jaxion)
            "f_15": -0.8,
        },
    }

    # Initialize the simulation
    sim = jaxion.Simulation(params)

    # Set initial conditions (a single soliton at the box centre)
    m_22 = sim.params["quantum"]["m_22"]
    hbar = jaxion.constants["reduced_planck_constant"]
    G = jaxion.constants["gravitational_constant"]
    m = sim.axion_mass
    a_s = sim.scattering_length
    box_size = sim.params["domain"]["box_size"]
    xx, yy, zz = sim.grid
    print("a_s =", a_s)

    M_soliton = 1.0e9
    # core radius from the relation r_soliton * M_soliton = 2.2e8 m_22^-2
    r_soliton = 2.2e8 * m_22**-2 / M_soliton  # in kpc
    M_crit = 1.012 * hbar / jnp.sqrt(G * m * jnp.abs(a_s))
    print("M_soliton/M_crit =", M_soliton / M_crit)

    def rho_soliton(r, r_soliton, m_22):
        """Soliton density profile at radius r for core radius r_soliton."""
        return (
            1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (r / r_soliton) ** 2) ** 8
        )

    r = jnp.sqrt(
        (xx - 0.5 * box_size) ** 2
        + (yy - 0.5 * box_size) ** 2
        + (zz - 0.5 * box_size) ** 2
    )
    rho = rho_soliton(r, r_soliton, m_22)
    # zero-phase wavefunction with |psi|^2 = rho
    sim.state["psi"] = jnp.sqrt(rho) + 0j
    return sim
def main():
    """Parse ``--res`` and run the self-interaction collapse example."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    opts = cli.parse_args()
    collapse_sim = set_up_simulation(opts.res)
    collapse_sim.run()
    return collapse_sim


if __name__ == "__main__":
    main()
soliton_binary_merger#
See on GitHub: examples/soliton_binary_merger#
README:
# Soliton Binary Merger
Collide two solitons with opposite phases
Philip Mocz (2025)
Usage:
```console
python soliton_binary_merger.py --res <resolution_multiplier>
```
Takes around 4 (res=1) seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm000.png" alt="dm000" width="32%"/>
<img src="dm002.png" alt="dm002" width="32%"/>
<img src="dm004.png" alt="dm004" width="32%"/>
<img src="dm006.png" alt="dm006" width="32%"/>
<img src="dm008.png" alt="dm008" width="32%"/>
<img src="dm010.png" alt="dm010" width="32%"/>
</div>
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="timing.png" alt="timing" width="45%"/>
</div>
## References
[Schwabe, B.; Niemeyer, J.C.; Engels, J.F.; Simulations of solitonic core mergers in ultralight axion dark matter cosmologies. Phys Rev D. (2016)](https://ui.adsabs.harvard.edu/abs/2016PhRvD..94d3513S)
[Mocz, P. et al.; Galaxy formation with BECDM - I. Turbulence and relaxation of idealized haloes. MNRAS (2017)](https://ui.adsabs.harvard.edu/abs/2017MNRAS.471.4559M)
Script:
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import argparse
import os
import jaxion
# switch on for double precision
# jax.config.update("jax_enable_x64", True)
"""
Collide two solitons with opposite phases
Philip Mocz (2025)
Usage:
python soliton_binary_merger.py --res <resolution_multiplier>
python soliton_binary_merger.py --distributed --emulate
"""
def set_up_simulation(resolution_multiplier, sharding):
    """Set up two identical solitons on the x axis with opposite phases.

    The solitons (core radius 0.02 box sizes) sit at x = 0.35 and 0.65 box
    sizes, centred in y and z. The wavefunction is +sqrt(rho) for
    x > box_size/2 and -sqrt(rho) elsewhere, i.e. a relative phase of pi.

    Args:
        resolution_multiplier: Factor on the default grid resolution; also
            part of the checkpoint directory name.
        sharding: Passed to ``jaxion.Simulation`` (``None`` for a
            single-device run).

    Returns:
        The configured ``jaxion.Simulation``.
    """
    # Parameters added/changed from default values
    params = {
        "domain": {
            "resolution_multiplier": resolution_multiplier,
        },
        "time": {
            "end": 1.0,
        },
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "save": True,
        },
    }

    # Initialize the simulation
    sim = jaxion.Simulation(params, sharding=sharding)

    # Set initial conditions (two solitons)
    m_22 = sim.params["quantum"]["m_22"]
    box_size = sim.params["domain"]["box_size"]
    r_soliton = 0.02 * box_size
    xx, yy, zz = sim.grid

    def rho_soliton(r, r_soliton, m_22):
        """Soliton density profile at radius r for core radius r_soliton."""
        return (
            1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (r / r_soliton) ** 2) ** 8
        )

    rho = 0.0
    for i in range(2):
        x_soliton = (0.5 + 0.3 * (i - 0.5)) * box_size  # 0.35 and 0.65 box sizes
        y_soliton = 0.5 * box_size
        z_soliton = 0.5 * box_size
        r = jnp.sqrt(
            (xx - x_soliton) ** 2 + (yy - y_soliton) ** 2 + (zz - z_soliton) ** 2
        )
        rho += rho_soliton(r, r_soliton, m_22)

    # +1 on the right half of the box, -1 on the left half (phase difference pi)
    sign = jnp.where((xx - 0.5 * box_size) > 0, 1.0, -1.0)
    sim.state["psi"] = jnp.array(jnp.sqrt(rho) + 0j) * sign
    return sim
def main():
    """Parse the command line, optionally configure distributed execution, and run.

    Flags:
        --res: resolution multiplier (default 1).
        --distributed: shard the grid over all devices on a (1, n_devices)
            mesh with axes ("x", "y").
        --emulate: together with --distributed, emulate 8 CPU devices instead
            of calling ``jax.distributed.initialize()``; ignored on its own.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    parser.add_argument(
        "--distributed", action="store_true", help="run in distributed mode"
    )
    parser.add_argument(
        "--emulate", action="store_true", help="emulate distributed mode on CPU"
    )
    args = parser.parse_args()

    if args.distributed:
        if args.emulate:
            # Split the host CPU into 8 virtual devices and hide any GPUs.
            # These settings must be in place before JAX initializes its
            # backends, i.e. before the first device query below.
            flags = os.environ.get("XLA_FLAGS", "")
            flags += " --xla_force_host_platform_device_count=8"
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            os.environ["XLA_FLAGS"] = flags
            if jax.process_index() == 0:
                print("Using emulated distributed CPU mode")
        else:
            jax.distributed.initialize()
            if jax.process_index() == 0:
                print("Using distributed GPU mode")
        # Create mesh and sharding for distributed computation
        n_devices = jax.device_count()
        devices = mesh_utils.create_device_mesh((1, n_devices))
        mesh = Mesh(devices, axis_names=("x", "y"))
        sharding = NamedSharding(mesh, PartitionSpec("x", "y"))
    else:
        sharding = None

    sim = set_up_simulation(args.res, sharding)
    sim.run()
    return sim


if __name__ == "__main__":
    main()
soliton_gas_star#
See on GitHub: examples/soliton_gas_star#
README:
# Soliton Gas Star
Explore the dynamics of a soliton, gas, and single star
Philip Mocz (2025)
Usage:
```console
python soliton_gas_star.py --res <resolution_multiplier>
```
Takes around 20 (res=1) seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm010.png" alt="dm010" width="45%"/>
<img src="gas010.png" alt="gas010" width="45%"/>
</div>
Script:
import jax.numpy as jnp
import argparse
import jaxion
# switch on for double precision (uncomment the line below and add `import jax`)
# jax.config.update("jax_enable_x64", True)
"""
Explore the dynamics of a soliton, gas, and single star
Philip Mocz (2025)
Usage:
python soliton_gas_star.py --res <resolution_multiplier>
"""
def set_up_simulation(resolution_multiplier):
    """Place a soliton, a uniform gas and a single star in the box.

    * dark matter: a 1e9 M_sun soliton at the box centre (zero phase);
    * gas: uniform and at rest, total mass 10% of the soliton's;
    * star: one 1e7 M_sun particle at x = 0.6 box sizes with velocity
      (0, 40, 0) (presumably km/s, matching the sound speed units).

    Args:
        resolution_multiplier: Factor on the default grid resolution; also
            part of the checkpoint directory name.

    Returns:
        The configured ``jaxion.Simulation``.
    """
    # only the settings that differ from jaxion's defaults
    params = {
        "physics": {"hydro": True, "particles": True},
        "domain": {"resolution_multiplier": resolution_multiplier},
        "output": {"path": f"./checkpoints{resolution_multiplier}/"},
        "hydro": {"sound_speed": 40.0},  # km/s
        "particles": {"num_particles": 1, "particle_mass": 1.0e7},
    }
    sim = jaxion.Simulation(params)

    box_size = sim.params["domain"]["box_size"]
    m_22 = sim.params["quantum"]["m_22"]
    nx = sim.resolution
    xx, yy, zz = sim.grid
    centre = 0.5 * box_size

    # dark matter: soliton with core radius from r_soliton * M = 2.2e8 m_22^-2
    M_soliton = 1.0e9  # M_sun
    r_soliton = 2.2e8 * m_22**-2 / M_soliton  # kpc
    r = jnp.sqrt((xx - centre) ** 2 + (yy - centre) ** 2 + (zz - centre) ** 2)
    rho_dm = 1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (r / r_soliton) ** 2) ** 8
    sim.state["psi"] = jnp.sqrt(rho_dm) + 0.0j

    # gas: uniform density holding frac_gas of the soliton mass, at rest
    frac_gas = 0.1
    rho_gas = frac_gas * M_soliton / box_size**3
    grid_shape = (nx, nx, nx)
    sim.state["rho"] = jnp.ones(grid_shape) * rho_gas
    sim.state["vx"] = jnp.zeros(grid_shape)
    sim.state["vy"] = jnp.zeros(grid_shape)
    sim.state["vz"] = jnp.zeros(grid_shape)

    # star: offset along x, moving in +y
    sim.state["pos"] = jnp.array([[0.6 * box_size, 0.5 * box_size, 0.5 * box_size]])
    sim.state["vel"] = jnp.array([[0.0, 40.0, 0.0]])
    return sim
def main():
    """Parse ``--res`` and run the soliton + gas + star example."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    res = cli.parse_args().res
    sim = set_up_simulation(res)
    sim.run()
    return sim


if __name__ == "__main__":
    main()
soliton_merger#
See on GitHub: examples/soliton_merger#
README:
# Soliton Merger
Merge solitons to form an idealized Fuzzy Dark Matter halo
Philip Mocz (2025)
Usage:
```console
python soliton_merger.py --res <resolution_multiplier>
```
Takes around 4 (res=1) seconds to run on my macbook (cpu).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm000.png" alt="dm000" width="32%"/>
<img src="dm002.png" alt="dm002" width="32%"/>
<img src="dm004.png" alt="dm004" width="32%"/>
<img src="dm006.png" alt="dm006" width="32%"/>
<img src="dm008.png" alt="dm008" width="32%"/>
<img src="dm010.png" alt="dm010" width="32%"/>
</div>
## References
[Mocz, P. et al.; Galaxy formation with BECDM - I. Turbulence and relaxation of idealized haloes. MNRAS (2017)](https://ui.adsabs.harvard.edu/abs/2017MNRAS.471.4559M)
Script:
import jax.numpy as jnp
import numpy as np
import argparse
import jaxion
# switch on for double precision (uncomment the line below and add `import jax`)
# jax.config.update("jax_enable_x64", True)
"""
Merge solitons to form an idealized Fuzzy Dark Matter halo
Philip Mocz (2025)
Usage:
python soliton_merger.py --res <resolution_multiplier>
"""
def set_up_simulation(resolution_multiplier):
    """Scatter several random solitons that merge into an idealized FDM halo.

    Eight solitons with random core radii (0.05-0.08 box sizes) and random
    centres in the middle half of the box are superposed in density. The
    wavefunction is the zero-phase square root of the total density.

    Args:
        resolution_multiplier: Factor on the default grid resolution; also
            part of the checkpoint directory name.

    Returns:
        The configured ``jaxion.Simulation``.
    """
    # Parameters added/changed from default values
    params = {
        "domain": {
            "resolution_multiplier": resolution_multiplier,
        },
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
        },
    }

    # Initialize the simulation
    sim = jaxion.Simulation(params)

    # Set initial conditions (randomly placed solitons)
    m_22 = sim.params["quantum"]["m_22"]
    box_size = sim.params["domain"]["box_size"]
    nx = sim.resolution
    xx, yy, zz = sim.grid

    def rho_soliton(r, r_soliton, m_22):
        """Soliton density profile at radius r for core radius r_soliton."""
        return (
            1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (r / r_soliton) ** 2) ** 8
        )

    np.random.seed(17)
    n_solitons = 8
    # density is real; accumulate it as a JAX array (no complex host buffer)
    rho = jnp.zeros((nx, nx, nx))
    for _ in range(n_solitons):
        # draw order (radius, x, y, z) defines the random configuration
        r_soliton = (0.05 + 0.03 * np.random.rand()) * box_size
        x_soliton = (0.25 + 0.5 * np.random.rand()) * box_size
        y_soliton = (0.25 + 0.5 * np.random.rand()) * box_size
        z_soliton = (0.25 + 0.5 * np.random.rand()) * box_size
        r = jnp.sqrt(
            (xx - x_soliton) ** 2 + (yy - y_soliton) ** 2 + (zz - z_soliton) ** 2
        )
        rho = rho + rho_soliton(r, r_soliton, m_22)

    # zero-phase complex wavefunction with |psi|^2 = rho
    sim.state["psi"] = jnp.sqrt(rho) + 0j
    return sim
def main():
    """Parse ``--res``, run the soliton-merger simulation and print mean |psi|."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    resolution = cli.parse_args().res

    sim = set_up_simulation(resolution)
    sim.run()

    mean_abs_psi = jnp.mean(jnp.abs(sim.state["psi"]))
    print("mean |psi| =", mean_abs_psi)
    return sim


if __name__ == "__main__":
    main()
tidal_stripping#
See on GitHub: examples/tidal_stripping#
README:
# Tidal Stripping
Tidal Stripping of Fuzzy Dark Matter soliton in an external potential
Philip Mocz (2025)
Usage:
```console
python tidal_stripping.py --res <resolution_multiplier>
```
Takes around 4 seconds (res=1) and 40 seconds (res=2) to run on my MacBook (CPU).
Demonstrates adding an external potential, including an imaginary sponge potential to absorb material near the boundaries of the periodic box.
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="dm000.png" alt="dm000" width="32%"/>
<img src="dm002.png" alt="dm002" width="32%"/>
<img src="dm004.png" alt="dm004" width="32%"/>
<img src="dm006.png" alt="dm006" width="32%"/>
<img src="dm008.png" alt="dm008" width="32%"/>
<img src="dm010.png" alt="dm010" width="32%"/>
</div>
Script:
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import argparse
import os
import jaxion
# switch on for double precision
# jax.config.update("jax_enable_x64", True)
"""
Tidal Stripping of Fuzzy Dark Matter soliton in an external potential
Philip Mocz (2025)
Usage:
python tidal_stripping.py --res <resolution_multiplier>
python tidal_stripping.py --distributed --emulate
"""
def set_up_simulation(resolution_multiplier, sharding, save):
    """Build the tidal-stripping simulation: a soliton moving in a host-halo potential.

    The soliton is placed ``r_separation`` kpc from the box centre along +y and
    is given a uniform phase gradient ``exp(i k x)``. A softened point-mass host
    potential plus an imaginary sponge potential near the box edge are stored
    together in ``sim.state["V_ext"]``.

    Args:
        resolution_multiplier: stored as
            ``params["domain"]["resolution_multiplier"]``; also names the
            checkpoint directory ``./checkpoints<resolution_multiplier>/``.
        sharding: forwarded to ``jaxion.Simulation`` (``None`` when not distributed).
        save: stored as ``params["output"]["save"]``.

    Returns:
        The configured ``jaxion.Simulation``.

    Raises:
        ValueError: if the derived host-halo mass does not exceed twice the
            soliton mass. This is an explicit check rather than an ``assert`` so
            that it still runs under ``python -O``.
    """
    # Parameters added/changed from default values
    params = {
        "physics": {
            "external_potential": True,
        },
        "domain": {
            "resolution_multiplier": resolution_multiplier,
        },
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
            "save": save,
        },
    }
    sim = jaxion.Simulation(params, sharding=sharding)

    # Set initial conditions (orbiting soliton in external potential)
    M_soliton = 1.0e9  # mass of soliton (M_sun)
    k_soliton = 4.0  # wave-number for orbital motion of soliton
    r_separation = 2.0  # separation of soliton from center (kpc)
    m_22 = sim.params["quantum"]["m_22"]
    m = sim.axion_mass
    hbar = jaxion.constants["reduced_planck_constant"]
    G = jaxion.constants["gravitational_constant"]
    r_soliton = 2.2e8 * m_22**-2 / M_soliton  # in kpc
    box_size = sim.params["domain"]["box_size"]
    X, Y, Z = sim.grid

    # distance from the soliton centre (offset by r_separation along +y)
    R = jnp.sqrt(
        (X - 0.5 * box_size) ** 2
        + (Y - 0.5 * box_size - r_separation) ** 2
        + (Z - 0.5 * box_size) ** 2
    )
    sim.state["psi"] = (
        jnp.sqrt(
            1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (R / r_soliton) ** 2) ** 8
        )
        + 0.0j
    )
    # add orbital velocity: the phase gradient k along x gives v = hbar * k / m,
    # perpendicular to the soliton's +y offset from the centre
    sim.state["psi"] *= jnp.exp(1.0j * k_soliton * X)

    # add external potential (host halo). M_halo = r * v^2 / G with
    # v = hbar * k / m and r = box_size / 4, i.e. v is the circular velocity at
    # r = box_size / 4.
    # NOTE(review): the soliton sits at r_separation, so the orbit is only
    # circular if r_separation == box_size / 4. Confirm against the default box_size.
    M_halo = 0.25 * box_size * k_soliton**2 * hbar**2 / (G * m**2)
    if jax.process_index() == 0:
        print(f"M_halo: {M_halo:.2e} M_sun")
    # the halo should be much more massive than the soliton
    if not M_halo > M_soliton * 2.0:
        raise ValueError(
            f"host halo mass M_halo={M_halo:.2e} M_sun must exceed twice the "
            f"soliton mass ({M_soliton:.2e} M_sun)"
        )

    # distance from the box centre (host-halo centre)
    R = jnp.sqrt(
        (X - 0.5 * box_size) ** 2
        + (Y - 0.5 * box_size) ** 2
        + (Z - 0.5 * box_size) ** 2
    )
    V_halo = -G * M_halo / (R + 0.5 * sim.dx)  # point mass softened by dx/2

    # add sponge potential to prevent material periodically crossing the box
    # see Schwabe et al. (2016)
    V_0 = G * M_halo / (box_size / 4.0)
    r_N = 0.5 * box_size
    r_p = (7 / 8) * r_N
    r_s = 0.5 * (r_N + r_p)
    delta = r_N - r_p
    V_sponge = (
        -0.5j
        * V_0
        * (2 + jnp.tanh((R - r_s) / delta) - jnp.tanh(r_s / delta))
        * jnp.heaviside(R - r_p, 0.0)  # sponge is active only for R > r_p
    )
    sim.state["V_ext"] = V_halo + V_sponge
    return sim
def str_to_bool(value):
    """Convert a command-line string such as "true", "False" or "0" to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    "False", as True, so boolean-valued options need an explicit converter.

    Args:
        value: the raw option string (a ``bool`` is returned unchanged).

    Returns:
        True for true/t/yes/y/1/on and False for false/f/no/n/0/off
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse CLI options, optionally set up sharding, and run the tidal-stripping demo.

    Options:
        --res N        resolution multiplier (default 1)
        --save [BOOL]  whether to save output; accepts true/false, yes/no, 1/0,
                       on/off (default true; a bare ``--save`` means true)
        --distributed  shard the grid over all devices on a (1, n_devices) mesh
        --emulate      with --distributed, emulate 8 devices on the host CPU
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    # type=bool would turn "--save False" into True (bool("False") is True),
    # so the string is parsed explicitly.
    parser.add_argument(
        "--save",
        type=str_to_bool,
        nargs="?",
        const=True,
        default=True,
        help="Save simulation output",
    )
    parser.add_argument(
        "--distributed", action="store_true", help="run in distributed mode"
    )
    parser.add_argument(
        "--emulate", action="store_true", help="emulate distributed mode on CPU"
    )
    args = parser.parse_args()

    if args.distributed:
        if args.emulate:
            # Expose 8 virtual host-CPU devices and hide any GPUs. These
            # variables only take effect if they are set before JAX initializes
            # its backend, which happens at the first jax.process_index() call
            # below, provided nothing earlier (e.g. the jaxion import) already
            # initialized it. TODO confirm.
            flags = os.environ.get("XLA_FLAGS", "")
            flags += " --xla_force_host_platform_device_count=8"
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            os.environ["XLA_FLAGS"] = flags
            if jax.process_index() == 0:
                print("Using emulated distributed CPU mode")
        else:
            jax.distributed.initialize()
            if jax.process_index() == 0:
                print("Using distributed GPU mode")
        # Create mesh and sharding for distributed computation
        n_devices = jax.device_count()
        devices = mesh_utils.create_device_mesh((1, n_devices))
        mesh = Mesh(devices, axis_names=("x", "y"))
        sharding = NamedSharding(mesh, PartitionSpec("x", "y"))
    else:
        sharding = None

    sim = set_up_simulation(args.res, sharding, args.save)
    sim.run()
    # compute the mean on every process, but print it only once
    mean_psi = jnp.mean(jnp.abs(sim.state["psi"]))
    if jax.process_index() == 0:
        print("mean |psi| =", mean_psi)
    return sim


if __name__ == "__main__":
    main()
timing#
See on GitHub: examples/timing#
README:
# Timing
Timing test for FDM + hydro, based on the heating_gas/ example
Philip Mocz (2025)
Usage:
```console
python timing.py --res <resolution_multiplier>
```
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="timing.png" alt="timing" width="45%"/>
</div>
Script:
import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion
# switch on for double precision
# jax.config.update("jax_enable_x64", True)
"""
Timing test based on heating_gas/
Philip Mocz (2025)
Usage:
python timing.py --res <resolution_multiplier>
"""
def set_up_simulation(resolution_multiplier):
    """Set up the FDM + hydro timing benchmark (based on the heating_gas example).

    Dark matter (85% of the mean density ``rho_bar`` = 1e7) is a random-phase
    wavefunction with a Gaussian velocity spectrum of dispersion ``sigma``,
    built in Fourier space following Eq. (27) of arXiv:1801.03507. Gas (15%)
    starts uniform and at rest. The end time scales as
    ``1 / resolution_multiplier**2``.

    Args:
        resolution_multiplier: stored as
            ``params["domain"]["resolution_multiplier"]``.

    Returns:
        The configured ``jaxion.Simulation`` with psi, rho, vx, vy, vz set.

    Raises:
        ValueError: if the box does not span more than one de Broglie
            wavelength, or if the gas Jeans length is not more than twice the
            box size. These are explicit checks rather than ``assert`` so they
            still run under ``python -O``.
    """
    # Parameters added/changed from default values
    params = {
        "physics": {
            "hydro": True,
        },
        "domain": {
            "resolution_multiplier": resolution_multiplier,
            "box_size": 4.0,  # kpc
        },
        "time": {
            "end": 0.01 * (1.0 / resolution_multiplier**2),
        },
        "output": {
            "num_checkpoints": 100,
            "save": False,
        },
        "hydro": {
            "sound_speed": 20.0,  # km/s
        },
    }
    sim = jaxion.Simulation(params)

    # average density of all matter (dm+gas) in the simulation (in units of Msun / kpc^3)
    rho_bar = 1.0e7
    # gas
    frac_gas = 0.15  # fraction of total mass in gas
    rho_gas = frac_gas * rho_bar  # average density of gas
    c_sound = sim.sound_speed  # sound speed (km/s)
    # dark matter
    frac_dm = 1.0 - frac_gas  # fraction of total mass in dark matter
    sigma = 40.0  # velocity dispersion of dm
    m = sim.axion_mass
    hbar = jaxion.constants["reduced_planck_constant"]
    m_per_hbar = m / hbar
    kx, ky, kz = sim.kgrid
    k_sq = kx**2 + ky**2 + kz**2
    nx = sim.resolution
    box_size = sim.box_size
    G = jaxion.constants["gravitational_constant"]

    # check that the de Broglie wavelength fits into the box
    de_broglie_wavelength = hbar / (m * sigma)
    n_wavelengths = box_size / de_broglie_wavelength
    if not n_wavelengths > 1:
        raise ValueError(
            "box must span more than one de Broglie wavelength "
            f"(box_size / lambda_dB = {float(n_wavelengths):.3g})"
        )

    # check the Jeans length: the gas must be Jeans-stable on the box scale
    jeans_length = c_sound * jnp.sqrt(jnp.pi / (G * rho_gas))
    n_jeans = box_size / jeans_length
    if not n_jeans < 0.5:
        raise ValueError(
            "gas Jeans length must exceed twice the box size "
            f"(box_size / lambda_J = {float(n_jeans):.3g}, need < 0.5)"
        )

    # dark matter: construct in Fourier space according to Eq (27) of
    # https://arxiv.org/abs/1801.03507
    np.random.seed(17)
    # random phases; sorting by |k|^2 makes the first random draws go to the
    # lowest-k modes (stable sort keeps ties deterministic)
    sid = np.argsort(k_sq.flatten(), stable=True)
    psi = np.zeros((nx**3,), dtype=complex)
    psi[sid] = np.exp(1.0j * 2.0 * np.pi * np.random.rand(nx**3))
    psi = jnp.array(psi.reshape(k_sq.shape))
    # Gaussian amplitude; computed with jnp so it stays a JAX operation instead
    # of round-tripping the k-grid through NumPy
    psi *= jnp.sqrt(jnp.exp(-k_sq / (2.0 * sigma**2 * m_per_hbar**2)))
    psi = jd.fft.pifft3d(psi)
    # re-normalize so that mean |psi|^2 equals the dark-matter density
    psi *= jnp.sqrt(frac_dm * rho_bar / jnp.mean(jnp.abs(psi) ** 2))

    sim.state["psi"] = psi
    # gas is initially uniform and at rest
    sim.state["rho"] = jnp.full((nx, nx, nx), rho_gas)
    for component in ("vx", "vy", "vz"):
        sim.state[component] = jnp.zeros((nx, nx, nx))
    return sim
def main():
    """Run the FDM + hydro timing benchmark at the ``--res`` resolution multiplier."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    options = cli.parse_args()

    sim = set_up_simulation(options.res)
    sim.run()
    return sim


if __name__ == "__main__":
    main()
two_field#
See on GitHub: examples/two_field#
README:
# Two-Field
Demo for adding custom fields: Two-field fuzzy dark matter
Merge solitons in two-field FDM to form a halo
Philip Mocz (2025)
Usage:
```console
python two_field.py --res <resolution_multiplier>
```
Takes around 6 seconds (res=1) to run on my MacBook (CPU).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="rho1_070.png" alt="rho1_070" width="45%"/>
<img src="rho2_070.png" alt="rho2_070" width="45%"/>
</div>
## References
[Luu, H.N.; Mocz, P.; Vogelsberger, M.; Nested solitons in two-field fuzzy dark matter. MNRAS (2024)](https://ui.adsabs.harvard.edu/abs/2024MNRAS.527.4162L)
Script:
import jax
import jax.numpy as jnp
import numpy as np
import jaxdecomp as jd
import argparse
import jaxion
# switch on for double precision
# jax.config.update("jax_enable_x64", True)
"""
Merge solitons to form an idealized Two-Field Fuzzy Dark Matter halo
Demonstrates adding custom fields
Philip Mocz (2025)
Usage:
python two_field.py --res <resolution_multiplier>
"""
def set_up_simulation(resolution_multiplier):
    """Set up a two-field FDM soliton merger using jaxion's custom-field hooks.

    The built-in quantum solver is switched off (``physics.quantum = False``).
    Instead, two wavefunctions ``psi1`` (m_22 = 1) and ``psi2`` (m_22 = 2) live
    in ``sim.state`` and are evolved by the ``custom_kick`` / ``custom_drift``
    closures attached to ``sim``. Both fields source gravity through
    ``custom_density``, and ``custom_plot`` writes one projection image per field.

    Args:
        resolution_multiplier: stored as
            ``params["domain"]["resolution_multiplier"]``; also names the
            checkpoint directory ``./checkpoints<resolution_multiplier>/``.

    Returns:
        The configured ``jaxion.Simulation``.
    """
    # Parameters added/changed from default values
    params = {
        "physics": {
            "quantum": False,  # use custom fields instead for psi1, psi2
            "gravity": True,
        },
        "domain": {
            "resolution_multiplier": resolution_multiplier,
        },
        "output": {
            "path": f"./checkpoints{resolution_multiplier}/",
        },
    }
    sim = jaxion.Simulation(params)

    # particle masses of the two fields: m = m_22 * 1e-22 eV / c^2
    m_22_1 = 1.0
    m_22_2 = 2.0
    hbar = jaxion.constants["reduced_planck_constant"]
    eV = jaxion.constants["electron_volt"]
    c = jaxion.constants["speed_of_light"]
    m1 = m_22_1 * 1.0e-22 * eV / c**2
    m2 = m_22_2 * 1.0e-22 * eV / c**2
    m1_per_hbar = m1 / hbar
    m2_per_hbar = m2 / hbar
    box_size = sim.params["domain"]["box_size"]
    nx = sim.resolution
    xx, yy, zz = sim.grid

    def rho_soliton(r, r_soliton, m_22):
        # soliton profile 1.9e7 m_22^-2 r_c^-4 / (1 + 0.091 (r / r_c)^2)^8
        return (
            1.9e7 * m_22**-2 * r_soliton**-4 / (1.0 + 0.091 * (r / r_soliton) ** 2) ** 8
        )

    # Set initial conditions (randomly placed solitons, same positions and
    # core radii for both fields)
    np.random.seed(17)
    n_solitons = 8
    rho1 = np.zeros((nx, nx, nx), dtype=complex)
    rho2 = np.zeros((nx, nx, nx), dtype=complex)
    for _ in range(n_solitons):
        r_soliton = (0.05 + 0.03 * np.random.rand()) * box_size
        x_soliton = (0.25 + 0.5 * np.random.rand()) * box_size
        y_soliton = (0.25 + 0.5 * np.random.rand()) * box_size
        z_soliton = (0.25 + 0.5 * np.random.rand()) * box_size
        r = jnp.sqrt(
            (xx - x_soliton) ** 2 + (yy - y_soliton) ** 2 + (zz - z_soliton) ** 2
        )
        rho1 += rho_soliton(r, r_soliton, m_22_1)
        rho2 += rho_soliton(r, r_soliton, m_22_2)
    sim.state["psi1"] = jnp.array(jnp.sqrt(rho1))
    sim.state["psi2"] = jnp.array(jnp.sqrt(rho2))

    # Define custom functions
    def custom_density(state):
        # total density sourcing gravity: |psi1|^2 + |psi2|^2
        return jnp.abs(state["psi1"]) ** 2 + jnp.abs(state["psi2"]) ** 2

    def custom_kick(state, V, dt):
        # potential step: psi <- exp(-i (m / hbar) V dt) psi, per field mass
        state["psi1"] = jnp.exp(-1j * m1_per_hbar * dt * V) * state["psi1"]
        state["psi2"] = jnp.exp(-1j * m2_per_hbar * dt * V) * state["psi2"]
        return state

    def custom_drift(state, k_sq, dt):
        # kinetic step in Fourier space: psi_hat <- exp(-i dt k^2 / (2 m / hbar)) psi_hat
        psi1_hat = jd.fft.pfft3d(state["psi1"])
        psi1_hat = jnp.exp(dt * (-1.0j * k_sq / m1_per_hbar / 2.0)) * psi1_hat
        state["psi1"] = jd.fft.pifft3d(psi1_hat)
        psi2_hat = jd.fft.pfft3d(state["psi2"])
        psi2_hat = jnp.exp(dt * (-1.0j * k_sq / m2_per_hbar / 2.0)) * psi2_hat
        state["psi2"] = jd.fft.pifft3d(psi2_hat)
        return state

    def custom_plot(state, checkpoint_dir, i, params):
        # Writes rho1_<i>.png and rho2_<i>.png: log10 of each field's density
        # averaged along axis 2, with colour limits rho_bar / or * dynamic_range.
        import os

        import matplotlib.pyplot as plt

        # explicit submodule import: `import jax` alone does not guarantee that
        # jax.experimental.multihost_utils has been loaded
        from jax.experimental import multihost_utils

        dynamic_range = params["output"]["plot_dynamic_range"]
        nx = state["psi1"].shape[0]

        # gather projections on every process first (process_allgather is
        # called by all processes), then plot on process 0 only
        projections = []
        for label, key in (("rho1", "psi1"), ("rho2", "psi2")):
            density = jnp.abs(state[key]) ** 2
            rho_bar = jnp.mean(density)
            rho_proj = jnp.log10(
                multihost_utils.process_allgather(
                    jnp.mean(density, axis=2), tiled=True
                )
            ).T
            projections.append((label, rho_bar, rho_proj))

        if jax.process_index() != 0:
            return

        for label, rho_bar, rho_proj in projections:
            plt.clf()
            ax = plt.gca()
            ax.imshow(
                rho_proj,
                cmap="inferno",
                origin="lower",
                vmin=jnp.log10(rho_bar / dynamic_range),
                vmax=jnp.log10(rho_bar * dynamic_range),
                extent=[0, nx, 0, nx],
            )
            ax.set_aspect("equal")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            plt.savefig(
                os.path.join(checkpoint_dir, f"{label}_{i:03d}.png"),
                bbox_inches="tight",
                pad_inches=0,
            )
            plt.close()

    sim.custom_density = custom_density
    sim.custom_plot = custom_plot
    sim.custom_kick = custom_kick
    sim.custom_drift = custom_drift

    # evaluate on every process in case it involves a cross-device reduction,
    # but print it only once
    rho_bar = sim.rho_bar
    if jax.process_index() == 0:
        print("rho_bar:", rho_bar)
    return sim
def main():
    """Run the two-field FDM soliton merger at the ``--res`` resolution multiplier."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--res", type=int, default=1, help="Resolution multiplier")
    options = cli.parse_args()

    sim = set_up_simulation(options.res)
    sim.run()
    return sim


if __name__ == "__main__":
    main()