Simulating Coulomb gases and log-gases in Python
First, import jax and numpyro, and configure numpyro to use the CPU and 64-bit floating point numbers.
```python
import os

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# numpyro configuration
numpyro.set_platform("cpu")
numpyro.set_host_device_count(os.process_cpu_count())
numpyro.enable_x64(True)
```
Modelling (for $\R$)
Set the parameters $N$ and $\beta$.
```python
# Simulation parameters
n = 50      # number of particles
beta = 2.0  # inverse temperature
```
Then define the external potential $V: \R \to \R$. The potential must be differentiable, because the HMC algorithm relies on the gradient of $V$; jump-type singularities are not possible. For hard walls it is better to reparametrize the point coordinates (see below). Here are some examples:
```python
# Harmonic trap: V(x) = 1/2 * |x|^2
V = lambda x: 1 / 2 * x**2

# M^2 + M^4 potential, has a bounded droplet only for gamma > -1 / 12
gamma = -1 / 12
V = lambda x: 1 / 2 * x**2 + gamma / 4 * x**4

# Logarithmic weakly confining potential,
# the equilibrium measure is 1 / (pi * (1 + x^2))
V = lambda x: jnp.log(1 + x**2)
```
One way to probe the particles is to include a fixed charge $c$ at some position $p$.
```python
c, p = -0.8, 1.0
Vcharge = lambda x: -2 * c * jnp.log(jnp.abs(x - p))
```
When we set up the numpyro sampler later, we will use the logarithm of the probability density instead of the density itself.
For the external potential this means that instead of the factor $\prod_{i = 1}^N e^{-\frac{\beta N}{2} V(x_i)}$, we work with the energy $N \sum_{i = 1}^N V(x_i)$.
(The $N$ prefactor ensures that the equilibrium measure for $N \to \infty$ has bounded support.
The factor $\frac{\beta}{2}$ is introduced later, in energy1d.)
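Putting the pieces together, the target density and its logarithm (up to an additive constant) are

$$p(x_1, \dots, x_N) \propto \prod_{1 \le i < j \le N} |x_j - x_i|^{\beta} \prod_{i = 1}^N e^{-\frac{\beta N}{2} V(x_i)}, \qquad \log p(x) = \beta \sum_{1 \le i < j \le N} \log |x_j - x_i| - \frac{\beta N}{2} \sum_{i = 1}^N V(x_i) + \mathrm{const},$$

which is exactly the negative of the function energy1d defined below.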
In JAX, we first use jax.vmap to vectorize the potential V, then we use jax.numpy.sum to sum the components.
Together with the fixed charge, we have
```python
def external_potential(points):
    external = jnp.sum(jax.vmap(V)(points))
    external += jnp.sum(jax.vmap(Vcharge)(points))  # add charge potential
    return n * external
```
The factor $\prod_{1 \le i < j \le N} |x_j - x_i|^{\beta} = \exp\bigl(\beta \sum_{1 \le i < j \le N} \log |x_j - x_i|\bigr)$ is modelled by
```python
def particle_interaction(points):
    ilist, jlist = jnp.triu_indices(len(points), k=1)  # index pairs i < j
    # Pairwise interaction -log(|x_i - x_j|), summed over all pairs i < j
    dists = jnp.abs(points[ilist] - points[jlist])
    dists = jnp.maximum(dists, 1e-8)  # keep distances away from 0 for numerical stability
    return -jnp.sum(jnp.log(dists))  # log-gas
```
For a Riesz-gas, assign a positive number to s and replace the last line of the particle_interaction function with
```python
    return jnp.sum(dists ** (-s))  # Riesz gas: the kernel |x|^(-s) is itself the (positive) interaction energy
```
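For reference, here is the full Riesz variant as a sketch; the exponent s = 1.0 is an arbitrary example value, not part of the original recipe:

```python
s = 1.0  # Riesz exponent, s > 0 (arbitrary example value)

def particle_interaction(points):
    ilist, jlist = jnp.triu_indices(len(points), k=1)  # index pairs i < j
    dists = jnp.abs(points[ilist] - points[jlist])
    dists = jnp.maximum(dists, 1e-8)  # keep distances away from 0
    return jnp.sum(dists ** (-s))  # Riesz gas
```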
The total energy, i.e. the sum of the external potential and the particle interaction energies, is computed by the following function.
```python
@jax.jit
def energy1d(points, beta):
    # The 1/2 on the external part yields the factor exp(-beta * N/2 * sum V(x_i))
    return beta * (particle_interaction(points) + external_potential(points) / 2)
```
Note that I scale both terms with beta, so that the support of the equilibrium measure is stable in the $\beta \to \infty$ limit.
I also pass beta as a function parameter, so that I can set it later.
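As a quick sanity check (not part of the original recipe), energy1d can be evaluated and differentiated on a random configuration; the gradient is exactly what HMC needs. The name test_points is a hypothetical helper:

```python
# Hypothetical smoke test: evaluate the energy and its gradient
test_points = jax.random.uniform(jax.random.PRNGKey(1), (n,), minval=-0.8, maxval=0.8)
print("energy:", energy1d(test_points, beta))
print("gradient shape:", jax.grad(energy1d)(test_points, beta).shape)  # (n,)
```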
Sampling (for $\R$)
To sample with numpyro, we need to define a model function that specifies the random variables (here: points) and the factors of the probabilistic model (here: the negative of energy1d as the log probability).
```python
# Do sampling with numpyro
def model(beta):
    point_dist = dist.ImproperUniform(dist.constraints.real, (), event_shape=(1,))
    points = numpyro.sample("points", point_dist.expand([n]))
    numpyro.factor("logp", -energy1d(points, beta))  # log probability
```
Replace point_dist in the first line of the function by
```python
point_dist = dist.ImproperUniform(dist.constraints.greater_than_eq(0.0), (), event_shape=(1,))
```
for one hard wall at $0$ (so that the points lie in $[0, \infty)^N$) and
```python
point_dist = dist.Uniform(-1.0, 1.0)
```
for two hard walls at $-1$ and $1$.
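Putting this together, a sketch of the complete model for the gas confined to $[-1, 1]$ could look as follows; numpyro transforms constrained coordinates to unconstrained ones internally, which is the reparametrization mentioned above:

```python
def model_two_walls(beta):
    point_dist = dist.Uniform(-1.0, 1.0)  # hard walls at -1 and 1
    points = numpyro.sample("points", point_dist.expand([n]))
    numpyro.factor("logp", -energy1d(points, beta))
```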
We also need an initial state, which I sample here from the uniform distribution on $[-0.8, 0.8]^N$:
```python
init_dist = dist.Uniform(-0.8, 0.8).expand([n])
init_state = init_dist.sample(jax.random.PRNGKey(0))
init_strategy = numpyro.infer.util.init_to_value(values={"points": init_state})
```
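If you use one of the hard-wall variants, the initial state must lie inside the support. For instance, for the wall at $0$ (the interval bounds here are an arbitrary choice):

```python
# Hypothetical initial distribution inside [0, oo) for the hard wall at 0
init_dist = dist.Uniform(0.1, 1.0).expand([n])
```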
To set up the MCMC sampler, I use the NUTS kernel, a variant of HMC that automatically tunes its step size and trajectory length.
Other kernels are listed in the numpyro documentation.
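For instance, the gradient-free SA (sample-adaptive) kernel could be swapped in with a one-line change; this is just an illustration, not part of the original setup:

```python
kernel = numpyro.infer.SA(model, init_strategy=init_strategy)  # gradient-free alternative
```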
The class MCMC is the actual sampler.
The sampler needs a few warmup steps to get from the initial sample to typical samples; the actual number of generated data points is num_samples.
```python
# Set up
kernel = numpyro.infer.NUTS(model, init_strategy=init_strategy)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=10000,
    num_chains=jax.local_device_count(),
    progress_bar=True,
)

# Compute samples
mcmc.run(jax.random.PRNGKey(0), beta)
samples = mcmc.get_samples(group_by_chain=True)["points"]
print("Shape of samples:", samples.shape)  # shape: (num_chains, num_samples, n)
```
Since every Markov chain Monte Carlo sampler will produce somewhat correlated samples, it is important to thin out the generated data to get (almost) independent samples.
The diagnostics effective_sample_size and split_gelman_rubin can help to see if the sampler was successful and how much thinning is needed.
```python
n_eff = numpyro.diagnostics.effective_sample_size(samples).min()
print("min n_eff:", n_eff)
print("max r_hat:", numpyro.diagnostics.split_gelman_rubin(samples).max())
```
The effective sample size should be comparable to the num_samples parameter above.
If it is much smaller, the sampler got stuck somewhere and the data is of low quality.
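If in doubt, the autocorrelation of a single coordinate can be inspected directly; a quick sketch using numpyro.diagnostics.autocorrelation:

```python
# Autocorrelation (as a function of the lag) of the first particle
# coordinate in the first chain
acf = numpyro.diagnostics.autocorrelation(samples[0, :, 0])
print("lag-1 autocorrelation:", acf[1])
```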
To get a thinned sample, use
```python
from math import ceil

thin = max(ceil(samples.shape[1] / n_eff), 1)
data = samples[:, ::thin].flatten()
print(f"{len(data):,} points (thinning: {thin})")
```
and to quickly visualize their distribution with matplotlib, use
```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(data, bins=500, density=True)
plt.show()
```
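As a consistency check, for the harmonic trap $V(x) = \frac12 x^2$ without the fixed charge (set $c = 0$), the histogram should approach the semicircle density $\frac{1}{2\pi}\sqrt{4 - x^2}$ on $[-2, 2]$; a sketch under these assumptions that overlays it on the histogram:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(data, bins=500, density=True, label="samples")
x = jnp.linspace(-2, 2, 400)
ax.plot(x, jnp.sqrt(4 - x**2) / (2 * jnp.pi), label="semicircle law")
ax.legend()
plt.show()
```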