Simulating Coulomb gases and log-gases in Python
First, import jax and numpyro, and configure numpyro to use the CPU and 64-bit floating point numbers.
```python
import os

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# numpyro configuration
numpyro.set_platform("cpu")
numpyro.set_host_device_count(os.process_cpu_count())
numpyro.enable_x64(True)
```
Modelling (for $\R$)
Set the parameters $N$ and $\beta$.
```python
# Simulation parameters
n = 50      # number of particles
beta = 2.0  # inverse temperature
```
Then define the external potential $V: \R \to \R$. The potential must be differentiable, because the HMC algorithm relies on the gradient of $V$; jump-type singularities are not possible. For hard walls it is better to reparametrize the point coordinates (see below). Here are some examples:
```python
# Harmonic trap: V(x) = 1/2 * |x|^2
V = lambda x: 1 / 2 * x**2

# M^2 + M^4 potential, has a bounded droplet only for gamma > -1 / 12
gamma = -1 / 12
V = lambda x: 1 / 2 * x**2 + gamma / 4 * x**4

# Logarithmic weakly confining potential,
# the equilibrium measure is 1 / (pi * (1 + x^2))
V = lambda x: jnp.log(1 + x**2)
```
One way to probe the particles is to include a fixed charge $c$ at some position $p$.
```python
c, p = -0.8, 1.0
Vcharge = lambda x: -2 * c * jnp.log(jnp.abs(x - p))
```
When we set up the numpyro sampler later, we will use the logarithm of the probability density instead of the density itself.
For the external potential this means that instead of the factor $\prod_{i = 1}^N e^{-\frac{\beta N}{2} V(x_i)}$, we work with the energy $N \sum_{i = 1}^N V(x_i)$.
(The $N$ prefactor ensures that the equilibrium measure for $N \to \infty$ has bounded support.
The factor $\frac{\beta}{2}$ is introduced later, in energy1d.)
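Putting the pieces together, the target density and its logarithm (up to an additive constant) are

$$p(x_1, \dots, x_N) \propto \prod_{1 \le i < j \le N} |x_j - x_i|^{\beta} \prod_{i = 1}^N e^{-\frac{\beta N}{2} V(x_i)}, \qquad \log p(x) = \beta \sum_{1 \le i < j \le N} \log |x_j - x_i| - \frac{\beta N}{2} \sum_{i = 1}^N V(x_i) + \mathrm{const},$$

which is exactly the negative of the function energy1d defined below.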
In JAX, we first use jax.vmap to vectorize the potential V, then we use jax.numpy.sum to sum the components.
Together with the fixed charge, we have
```python
def external_potential(points):
    external = jnp.sum(jax.vmap(V)(points))
    external += jnp.sum(jax.vmap(Vcharge)(points))  # add charge potential
    return n * external
```
The factor $\prod_{1 \le i < j \le N} |x_j - x_i|^{\beta} = \exp\bigl(\beta \sum_{1 \le i < j \le N} \log |x_j - x_i|\bigr)$ is modelled by
```python
def particle_interaction(points):
    ilist, jlist = jnp.triu_indices(len(points), k=1)  # index pairs i < j
    # Pairwise interaction -log(|x_i - x_j|), summed over all pairs i < j
    dists = jnp.abs(points[ilist] - points[jlist])
    dists = jnp.maximum(dists, 1e-8)  # keep distances away from 0 for numerical stability
    return -jnp.sum(jnp.log(dists))  # log-gas
```
For a Riesz-gas, assign a positive number to s and replace the last line of the particle_interaction function with
```python
    return jnp.sum(dists ** (-s))  # Riesz gas: the kernel |x|^(-s) is itself the (positive) interaction energy
```
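For reference, here is the full Riesz variant as a sketch; the exponent s = 1.0 is an arbitrary example value, not part of the original recipe:

```python
s = 1.0  # Riesz exponent, s > 0 (arbitrary example value)

def particle_interaction(points):
    ilist, jlist = jnp.triu_indices(len(points), k=1)  # index pairs i < j
    dists = jnp.abs(points[ilist] - points[jlist])
    dists = jnp.maximum(dists, 1e-8)  # keep distances away from 0
    return jnp.sum(dists ** (-s))  # Riesz gas
```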
The total energy, i.e. the sum of the external potential and the particle interaction energies, is computed by the following function.
```python
@jax.jit
def energy1d(points, beta):
    # The 1/2 on the external part yields the factor exp(-beta * N/2 * sum V(x_i))
    return beta * (particle_interaction(points) + external_potential(points) / 2)
```
Note that I scale both terms with beta, so that the support of the equilibrium measure is stable in the $\beta \to \infty$ limit.
I also pass beta as a function parameter, so that I can set it later.
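As a quick sanity check (not part of the original recipe), energy1d can be evaluated and differentiated on a random configuration; the gradient is exactly what HMC needs. The name test_points is a hypothetical helper:

```python
# Hypothetical smoke test: evaluate the energy and its gradient
test_points = jax.random.uniform(jax.random.PRNGKey(1), (n,), minval=-0.8, maxval=0.8)
print("energy:", energy1d(test_points, beta))
print("gradient shape:", jax.grad(energy1d)(test_points, beta).shape)  # (n,)
```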
Sampling (for $\R$)
To sample with numpyro, we need to define a model function that specifies the random variables (here: points) and the factors of the probabilistic model (here: the negative of energy1d as the log probability).
```python
# Do sampling with numpyro
def model(beta):
    point_dist = dist.ImproperUniform(dist.constraints.real, (), event_shape=(1,))
    points = numpyro.sample("points", point_dist.expand([n]))
    numpyro.factor("logp", -energy1d(points, beta))  # log probability
```
Replace point_dist in the first line of the function by
```python
point_dist = dist.ImproperUniform(dist.constraints.greater_than_eq(0.0), (), event_shape=(1,))
```
for one hard wall at $0$ (so that the points lie in $[0, \infty)^N$) and
```python
point_dist = dist.Uniform(-1.0, 1.0)
```
for two hard walls at $-1$ and $1$.
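Putting this together, a sketch of the complete model for the gas confined to $[-1, 1]$ could look as follows; numpyro transforms constrained coordinates to unconstrained ones internally, which is the reparametrization mentioned above:

```python
def model_two_walls(beta):
    point_dist = dist.Uniform(-1.0, 1.0)  # hard walls at -1 and 1
    points = numpyro.sample("points", point_dist.expand([n]))
    numpyro.factor("logp", -energy1d(points, beta))
```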
We also need an initial state, which I sample here from the uniform distribution on $[-0.8, 0.8]^N$:
```python
init_dist = dist.Uniform(-0.8, 0.8).expand([n])
init_state = init_dist.sample(jax.random.PRNGKey(0))
init_strategy = numpyro.infer.util.init_to_value(values={"points": init_state})
```
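If you use one of the hard-wall variants, the initial state must lie inside the support. For instance, for the wall at $0$ (the interval bounds here are an arbitrary choice):

```python
# Hypothetical initial distribution inside [0, oo) for the hard wall at 0
init_dist = dist.Uniform(0.1, 1.0).expand([n])
```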
To set up the MCMC sampler, I use the NUTS kernel, a variant of HMC that automatically tunes its step size and trajectory length.
Other kernels are listed in the numpyro documentation.
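For instance, the gradient-free SA (sample-adaptive) kernel could be swapped in with a one-line change; this is just an illustration, not part of the original setup:

```python
kernel = numpyro.infer.SA(model, init_strategy=init_strategy)  # gradient-free alternative
```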
The class MCMC is the actual sampler.
The sampler needs a few warmup steps to get from the initial sample to typical samples; the actual number of generated data points is num_samples.
```python
# Set up
kernel = numpyro.infer.NUTS(model, init_strategy=init_strategy)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=10000,
    num_chains=jax.local_device_count(),
    progress_bar=True,
)

# Compute samples
mcmc.run(jax.random.PRNGKey(0), beta)
samples = mcmc.get_samples(group_by_chain=True)["points"]
print("Shape of samples:", samples.shape)  # shape: (num_chains, num_samples, n)
```
Since every Markov chain Monte Carlo sampler will produce somewhat correlated samples, it is important to thin out the generated data to get (almost) independent samples.
The diagnostics effective_sample_size and split_gelman_rubin can help to see if the sampler was successful and how much thinning is needed.
```python
n_eff = numpyro.diagnostics.effective_sample_size(samples).min()
print("min n_eff:", n_eff)
print("max r_hat:", numpyro.diagnostics.split_gelman_rubin(samples).max())
```
The effective sample size should be comparable to the num_samples parameter above.
If it is much smaller, the sampler got stuck somewhere and the data is of low quality.
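If in doubt, the autocorrelation of a single coordinate can be inspected directly; a quick sketch using numpyro.diagnostics.autocorrelation:

```python
# Autocorrelation (as a function of the lag) of the first particle
# coordinate in the first chain
acf = numpyro.diagnostics.autocorrelation(samples[0, :, 0])
print("lag-1 autocorrelation:", acf[1])
```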
To get a thinned sample, use
```python
from math import ceil

thin = max(ceil(samples.shape[1] / n_eff), 1)
data = samples[:, ::thin].flatten()
print(f"{len(data):,} points (thinning: {thin})")
```
and to quickly visualize their distribution with matplotlib, use
```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(data, bins=500, density=True)
plt.show()
```
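As a consistency check, for the harmonic trap $V(x) = \frac12 x^2$ without the fixed charge (set $c = 0$), the histogram should approach the semicircle density $\frac{1}{2\pi}\sqrt{4 - x^2}$ on $[-2, 2]$; a sketch under these assumptions that overlays it on the histogram:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(data, bins=500, density=True, label="samples")
x = jnp.linspace(-2, 2, 400)
ax.plot(x, jnp.sqrt(4 - x**2) / (2 * jnp.pi), label="semicircle law")
ax.legend()
plt.show()
```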