A hands-on introduction to predictively oriented posteriors

A self-contained introduction to predictively oriented posteriors, their guarantees, and their computation.
Predictively oriented posteriors
Optimisation
PAC Bayes
Author

Yann McLatchie

Published

January 15, 2026

Set up
# dependencies
from collections import namedtuple
import jax
import jax.numpy as jnp
from jax import random, vmap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# experiment settings
N = 1_000
sigma2 = 1.0
sigma = np.sqrt(sigma2)
lambda_n = 1 / N
num_particles = 2 ** 5
num_mmd2_samples = 2 ** 7
SEED = 1234
key = random.PRNGKey(SEED)

Introduction

Notation

Throughout, we consider data \(x_{1:n}\) where \(x_i\in\mathcal{X}\) for \(i=1,\dots,n\), which we assume to be drawn from an unknown distribution \(P_0\). Let \(P_\Theta = \{P_\theta:\,\theta\in\Theta\}\) denote a class of models. For \(Q\in\mathcal{P}(\Theta)\) the space of distributions over \(\Theta\), let \(P_Q=\int P_\theta \,\mathrm{d} Q(\theta)\) denote the predictive distribution induced by \(Q\). Lastly, we write \(Q^{\otimes m}:=\bigotimes_{j=1}^{m}Q\) to denote the product measure over \(m\) replications of \(\theta\), each with marginal measure \(Q(\vartheta)\), and similarly \(\Pi^{\otimes m}\) the \(m\)-product measure with marginal measure \(\Pi\).
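For instance, in a normal location model \(P_\theta = \mathsf{N}(\theta, \sigma^2)\), the Gaussian mixing distribution \(Q = \mathsf{N}(m, s^2)\) induces the over-dispersed predictive \(P_Q = \mathsf{N}(m, \sigma^2 + s^2)\), while more general choices of \(Q\) can induce multimodal predictives.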

We write \(S:\mathcal{P}(\mathcal{X}) \times \mathcal{X} \to \mathbb{R}\cup\{\infty\}\) to denote a proper scoring rule: the expected score \(\mathcal{S}(P, P') := \mathsf{E}_{X\sim P'}\left\{S(P,X)\right\}\) is finite for all \(P,P'\in\mathcal{P}\); and, \(\mathcal{S}(P',P') \le \mathcal{S}(P,P')\) for all \(P,P' \in \mathcal{P}\), so that \(\mathcal{S}\) induces the statistical divergence \(\mathcal{D}_{S}(P,P') := \mathcal{S}(P,P') - \mathcal{S}(P',P') \geq 0\) for which \(\mathcal{D}_{S}(P,P') = 0 \Longleftrightarrow P = P'\). Given a target distribution \(P_0\), a candidate predictive distribution \(P\) is judged to be more accurate than an alternative candidate \(P'\) by the expected score \(\mathcal{S}\) if and only if \(\mathcal{S}(P, P_0) \leq \mathcal{S}(P', P_0)\).
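For instance, the log score \(S_{\mathrm{log}}(P, x) = -\log p(x)\), with \(p\) the density of \(P\), is a proper scoring rule, and the divergence it induces is the Kullback-Leibler divergence:

\[ \mathcal{D}_{S_{\mathrm{log}}}(P, P') = \mathsf{E}_{X\sim P'}\{\log p'(X) - \log p(X)\} = d_{\mathrm{KL}}(P';\,P). \]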

Motivation

When \(P_{\theta}\) has density \(p_{\theta}\), the Bayes posterior \(\Pi_n\) and its associated posterior predictive distribution are defined as

\[ \Pi_n(\mathrm{d}\theta) \propto p_\theta(x_{1:n}) \Pi(\mathrm{d}\theta), \quad \quad \text{and}\quad\quad P_{\Pi_n}=\int_\Theta P_\theta \,\mathrm{d} \Pi_n(\theta). \tag{1}\]

In this set-up, the predictive distribution \(P_{\Pi_n}\) is derived as a consequence of the Bayes posterior \(\Pi_n\). Yet much of the motivation for Bayesian inference is the ability to integrate over parameter uncertainty when making predictions. And while \(P_{\Pi_n}\) is predictively optimal if \(P_0 \in P_\Theta\) (Aitchison 1975), this optimality breaks down once \(P_0 \notin P_\Theta\). Even in this case, however, as \(n\to\infty\) the posterior distribution will collapse around the KL-minimising parameter value, \(\Pi_n(\mathrm{d}\theta) \to \delta_{\theta^\star_{\mathrm{KL}}}\). And when the model is misspecified, there are no guarantees that this point predictive \(P_{\theta^\star_{\mathrm{KL}}}\) will deliver calibrated or sharp predictions.

Though this is well understood, robust Bayesian methods usually still treat prediction as a second-order concern, focusing instead on parameter uncertainty. One family emblematic of this kind of research is that of Gibbs posteriors (Knoblauch, Jewson, and Damoulas 2022; Bissiri, Holmes, and Walker 2016), which are often based on a scoring rule \(S\) and a so-called learning rate \(\lambda_n\):

\[ Q_n^{\dagger} = \arg\min_{Q \in \mathcal{P}(\Theta)}\left\{\frac{\lambda_n}{n}\sum_{i=1}^n\int S(P_\theta, x_i)\,\mathrm{d} Q(\theta)+d_{\mathrm{KL}}(Q;\,\Pi)\right\}. \tag{2}\]

This formulation makes clear that we are targeting an expected loss with respect to the posterior: an average-case problem. For clarity, we note that the solution of this optimisation program under the choice of \(\lambda_n = n\) and \(S(P_\theta, x) = -\log p_\theta(x)\) is precisely the Bayes posterior of Equation 1. Under very mild conditions (Martin and Syring 2022), the posterior will again collapse around a single value \(\theta^\star_{S}\), reducing the predictive to a singleton \(P_{\theta^\star_S}\); and again, if the model is misspecified then, independently of the choice of \(S\), there are no guarantees of \(P_{\theta^\star_S}\) being calibrated or sharp. Indeed, the opposite is commonly observed in practice.
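To see why the Bayes posterior arises as a special case, recall that the Gibbs variational problem of Equation 2 admits the well-known closed-form minimiser

\[ Q_n^{\dagger}(\mathrm{d}\theta) \propto \exp\left\{-\frac{\lambda_n}{n}\sum_{i=1}^n S(P_\theta, x_i)\right\}\Pi(\mathrm{d}\theta), \]

so that the choice \(\lambda_n = n\) and \(S(P_\theta, x) = -\log p_\theta(x)\) recovers \(Q_n^{\dagger}(\mathrm{d}\theta) \propto p_\theta(x_{1:n})\,\Pi(\mathrm{d}\theta) = \Pi_n(\mathrm{d}\theta)\), exactly Equation 1.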

To resolve this philosophical tension, we reverse that order: the posterior predictive \(P_{Q} := \int_\Theta P_\theta \,\mathrm{d} Q(\theta)\) is the object of interest, and we obtain a posterior distribution over \(\Theta\) as the measure \(Q_n\) which quantifies uncertainty about \(\theta\) in a manner that is optimal for prediction. By inverting the order of integration, we lift the parametric statistical model to a mixture under the mixing distribution \(Q\), \(P_Q = \int P_\theta\,\mathrm{d}Q(\theta)\), and score this mixture directly:

\[ Q_n = \arg\min_{Q \in \mathcal{P}(\Theta)}\left\{\frac{\lambda_n}{n}\sum_{i=1}^n S(P_Q, x_i)+d_{\mathrm{KL}}(Q;\,\Pi)\right\}. \tag{3}\]

Throughout, we refer to \(Q_n\) as the predictively-oriented (PrO) posterior.

Theory

The fundamental difference between the asymptotic behaviour of PrO posteriors and that of standard Gibbs posteriors is that, while the latter will concentrate onto a point mass under very light assumptions, the former will only concentrate onto a point mass when the induced predictive is exactly the true data-generating process. When the model is misspecified, \(Q_n\) will instead adapt to the nature of the misspecification to find the predictively-optimal mixing distribution. This can manifest itself by identifying multi-modal posterior distributions which, once integrated over, lift the asymptotic predictive distribution to the convex hull of \(P_\Theta\): there is no longer the requirement that we reduce to a singleton in \(P_\Theta\) if a mixture of predictives can out-perform it. We refer to the difference between \(P_\Theta\) and its convex hull (coloured in grey in the figure below) as the mixability gap, in line with Grünwald and van Ommen (2017).

Figure 1: The posterior and posterior predictive distributions induced by different belief updates. Panels show the parameter space, \(\Theta\), and the predictive space, \(P_\Theta\), together with the true data-generating process \(P_0\).

Under some regularity conditions on the map \(\theta \mapsto S(P_\theta, x)\), in particular that it is convex in its first argument, and some standard prior mass conditions, McLatchie et al. (2025) prove that the PrO posterior induces predictive performance at least as good as the generalised Bayes posterior fit under the same score, as measured by that score. And as the mixability gap grows, this predictive difference grows too: in the scenario graphically depicted above, for \(n\) sufficiently large but finite, the PrO posterior induces a provably better predictive distribution than the Gibbs posterior:

\[ \mathsf{E}_{0}\{\mathcal{D}_{\mathcal{S}}(P_{Q_n}, P_0)\} < \mathsf{E}_{0}\{\mathcal{D}_{\mathcal{S}}(P_{Q_n^\dagger}, P_0)\}, \]

where \(\mathsf{E}_0\) denotes the expectation over data with respect to \(P_0\). Those authors also formally show that the PrO posterior \(Q_n\) will concentrate around a singleton only under model well-specification and otherwise concentrates around the predictively-optimal mixture.

Running example: a normal location model

Show code
def generate_gauss_mix_data(
    n, sigma, epsilon=0.8, theta_0=-2, theta_1=2, mu_0=0, sigma_0=1, seed=None
):
    """Helper function to sample from a mixture of Gaussians."""
    
    if seed is not None:
        np.random.seed(seed)

    # generate indicator z
    z = np.random.binomial(n=1, p=epsilon, size=n)

    # compute mu depending on z
    if isinstance(theta_0, (int, float)) and isinstance(theta_1, (int, float)):
        mu = np.where(z == 0, theta_0, theta_1)
    else:
        raise ValueError

    # generate response variable y
    y = np.random.normal(loc=mu, scale=sigma)

    # combine into a list
    data = {"z": z, "mu": mu, "y": y, "sigma": sigma, "mu_0": mu_0, "sigma_0": sigma_0}
    data_tuple = namedtuple("x", data.keys())(*data.values())
    return data_tuple
  
# simulate from a mixture of Gaussians
data = generate_gauss_mix_data(n=N, sigma=sigma, seed=SEED)

# plot the empirical data distribution
fig, ax = plt.subplots(figsize=(5, 3))
sns.histplot(data.y, stat='density', color='lightgray', 
             edgecolor='black', ax=ax)
plt.show()

Figure 2: Simulated data distribution.

Suppose we observe the univariate continuous data \(x_{i} \in \mathbb{R}\) for \(i = 1,\ldots,n\), with \(n = 1,000\), shown in Figure 2. We model these data as arising from a Gaussian distribution with fixed standard deviation \(\sigma > 0\) and wish to perform inference on the mean: \(P_\Theta = \{\mathsf{N}(\theta,\,\sigma^2):\,\theta \in \mathbb{R}\}\). We proceed by specifying a prior, say \(\pi(\theta) = \mathsf{N}(\mu_0, \sigma_0^2)\), which for the purposes of exposition we can assume would remain unchanged regardless of whether we were fitting a standard Bayesian posterior or a PrO posterior.

def log_prior(theta, mu_0=0.0, sigma_0=1.0):
    """Compute the log prior density."""
    
    return jax.scipy.stats.norm.logpdf(theta, loc=mu_0, scale=sigma_0).sum()
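For reference, and to make the collapse described in the Motivation concrete, the standard Bayes posterior in this conjugate normal location model is available in closed form; the following is a minimal sketch (the helper `conjugate_posterior` is ours, purely for illustration).

def conjugate_posterior(y, sigma2, mu_0=0.0, sigma2_0=1.0):
    """Closed-form Bayes posterior N(mu_n, sigma2_n) for the normal location model."""

    # precision-weighted combination of prior and likelihood
    n = y.shape[0]
    sigma2_n = 1.0 / (n / sigma2 + 1.0 / sigma2_0)
    mu_n = sigma2_n * (y.sum() / sigma2 + mu_0 / sigma2_0)
    return mu_n, sigma2_n

mu_n, sigma2_n = conjugate_posterior(data.y, sigma2)
# on these simulated data the posterior concentrates tightly around the mixture
# mean (roughly 0.8 * 2 + 0.2 * (-2) = 1.2), a single value which describes
# neither mode of Figure 2 well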

Predictive scoring rules

The fundamental consideration for the practical application of PrO posteriors is the calculation of the predictive scoring rule \(S(P_Q, x)\) in Equation 3. In particular, the objective differs from that of Equation 2 insofar as it is non-linear in \(Q\), and many of the tools for theoretical analysis and computation rely on this linearity. To rectify this, we leverage the convexity of the scoring rule from Section 1.3 and apply Jensen’s inequality, so that

\[ S\left( P_Q, x\right) = \int S\left( P_{\theta}, x \right) \,\mathrm{d} Q(\theta) - \Delta(Q, x), \tag{4}\]

for some non-negative remainder term \(\Delta:\mathcal{P}(\Theta) \times \mathcal{X} \to \mathbb{R}_{+}\) which we assume to be finite. While this remainder term need not simplify in general, McLatchie et al. (2025) show that in some cases it can be written as an expectation over an \(m\)-product measure in \(\Theta\), so that

\[ S\left( P_Q, x\right) = \int \underbrace{\{S(P_{\theta_1}, x) - \delta(\theta_1,\ldots,\theta_m;x) \}}_{= \mathsf{L}\left( \theta_{1:m};\, x \right)} \,\mathrm{d} Q^{\otimes m}(\theta_{1:m}), \tag{5}\]

where now the predictive score is linear in \(Q\), albeit on an extended space, and where \(\delta(\theta_{1:m};x)\) can be computed pointwise for all \(x\in\mathcal{X}\). In a word, this objective is now computationally feasible.

Each choice of predictive score induces a different \(\Delta\), and with it also a different dimensionality \(m\). For instance, Masegosa (2020) attains an upper bound on \(\Delta\) in the case of the log score, under the assumption that \(\sup_\theta p_\theta(x) < \infty\) for all \(x\in\mathcal{X}\), resulting in an objective with \(m = 2\). Elsewhere, Morningstar, Alemi, and Dillon (2022) approximate \(P_Q\) with \(m\) Monte Carlo draws, thereby treating \(m\) as a hyper-parameter where the quality of the approximation improves with \(m\).

An important class of scoring rules which admit the representation in Equation 5 without any such approximation are kernel scoring rules, for which the representation is always available with \(m = 2\) (McLatchie et al. 2025, Lemma 2). Given their simplicity, for exposition purposes, we will consider only kernel scoring rules in what follows.

Running example: maximum mean discrepancy

Gaussian kernel in JAX
def gaussian_kernel(x, y, gamma2):
    """RBF kernel with squared lengthscale `gamma2`."""

    x_norm = jnp.sum(x**2, axis=1).reshape(-1, 1)
    y_norm = jnp.sum(y**2, axis=1).reshape(1, -1)
    dist = x_norm + y_norm - 2 * jnp.dot(x, y.T)

    return jnp.exp(-dist / (2 * gamma2))
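As a quick, purely illustrative sanity check of the kernel above:

x_check = jnp.array([[0.0], [1.0]])
K_check = gaussian_kernel(x_check, x_check, gamma2=1.0)
# unit diagonal, and K_check[0, 1] = exp(-0.5) ≈ 0.61 for points one unit apart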

We presently ground the above in our running example by considering the squared maximum mean discrepancy (MMD) as the predictive scoring rule. In this case, Shen et al. (2024) show that

\[ \begin{aligned} S(P_Q,x_{1:n}) &= \int \mathsf{L}_{\mathrm{MMD}}(\theta_1,\theta_2;x_{1:n})\,\mathrm{d}Q^{\otimes2}(\theta_1,\theta_2) \\ &= \int \langle \mu(P_{\theta_1}) - \mu(\delta_{\{x_{1:n}\}}), \mu(P_{\theta_2}) - \mu(\delta_{\{x_{1:n}\}}) \rangle_{\mathcal{H}}\,\mathrm{d}Q^{\otimes2}(\theta_1,\theta_2)\\ &= \int \left\{\langle \mu(P_{\theta_1}), \mu(P_{\theta_2}) \rangle_{\mathcal{H}} - \langle \mu(P_{\theta_2}), \mu(\delta_{\{x_{1:n}\}}) \rangle_{\mathcal{H}} - \langle \mu(P_{\theta_1}), \mu(\delta_{\{x_{1:n}\}}) \rangle_{\mathcal{H}} \right\}\,\mathrm{d}Q^{\otimes2}(\theta_1,\theta_2) + C, \end{aligned} \]

where \(C\) is a constant which does not depend on \(\theta_1\) or \(\theta_2\). Each of these inner products is precisely an expected kernel: \(\langle \mu(P_{\theta_1}), \mu(P_{\theta_2}) \rangle_{\mathcal{H}} = \mathsf{E}_{y^{(1)}\sim P_{\theta_1}}\mathsf{E}_{y^{(2)}\sim P_{\theta_2}}\{\kappa(y^{(1)}, y^{(2)})\}\). As such, these expectations can be empirically estimated via Monte Carlo with \(N\) samples from \(P_{\theta_1}\) and \(P_{\theta_2}\), leading to \[\begin{multline} \mathsf{L}_{\mathrm{MMD}}(\theta_1,\theta_2;x_{1:n}) \approx \frac{1}{{N}^2} \sum_{j = 1}^N\sum_{\ell = 1}^N \kappa\{y^{(1)}_j, y^{(2)}_\ell\}\\ - \frac{1}{{N}\cdot n} \sum_{j = 1}^N\sum_{i = 1}^n \kappa\{y^{(1)}_j, x_i\} - \frac{1}{{N}\cdot n} \sum_{\ell = 1}^N\sum_{i = 1}^n \kappa\{y^{(2)}_\ell, x_i\}, \end{multline}\] where \(y^{(1)}_{1:N}\sim P_{\theta_1}\) and \(y^{(2)}_{1:N}\sim P_{\theta_2}\). In Python, this is readily implemented as follows.

def pred_score(
    theta_i, theta_j, key, data, sigma=1.0, num_mmd2_samples=100, gamma2=1.0
):
    """
    Compute the predictive score for the normal location model under the MMD.
    """
  
    # split the key for reproducibility
    subkey1, subkey2 = random.split(key)

    # sample from the predictive at `theta_i` and `theta_j`
    ftheta_i_samples = (
        theta_i
        + random.normal(subkey1, shape=(num_mmd2_samples, 1))
        * sigma
    )
    ftheta_j_samples = (
        theta_j
        + random.normal(subkey2, shape=(num_mmd2_samples, 1))
        * sigma
    )

    # compute the Gram matrices
    K_ij = gaussian_kernel(ftheta_i_samples, ftheta_j_samples, gamma2=gamma2)
    K_il = gaussian_kernel(ftheta_i_samples, data, gamma2=gamma2)
    K_jl = gaussian_kernel(ftheta_j_samples, data, gamma2=gamma2)

    # extract relevant dimensions
    assert ftheta_i_samples.shape[0] == ftheta_j_samples.shape[0]
    n = data.shape[0]

    # compute the component terms of the tandem predictive MMD2 and return
    sum_K_ij = (jnp.sum(K_ij)) / (num_mmd2_samples**2)
    sum_K_il = (jnp.sum(K_il)) / (num_mmd2_samples * n)
    sum_K_jl = (jnp.sum(K_jl)) / (num_mmd2_samples * n)
    return sum_K_ij - sum_K_il - sum_K_jl

In the specific running example, under the Gaussian kernel the squared MMD can in fact be computed in closed form (see Chérief-Abdellatif and Alquier 2019, Appendix D). And for large values of \(n\) and \(N\), one could consider employing a more light-weight approximation to the expected kernel for computational reasons. For the purposes of this exposition, we treat only the most naïve, general approach.
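For completeness, a minimal sketch of that closed-form alternative for our normal location model (the helper `pred_score_closed_form` and its defaults are ours, following the standard Gaussian convolution identity rather than the reference's general statement):

def pred_score_closed_form(theta_i, theta_j, data, sigma=1.0, gamma2=1.0):
    """Closed-form expected kernels for the normal location model under the
    Gaussian kernel, replacing the Monte Carlo estimates in `pred_score`."""

    x = data.reshape(-1)
    sigma2 = sigma**2

    # E k(y1, y2) with y1 ~ N(theta_i, sigma2), y2 ~ N(theta_j, sigma2)
    s_pp = gamma2 + 2 * sigma2
    k_ij = jnp.sqrt(gamma2 / s_pp) * jnp.exp(-((theta_i - theta_j) ** 2) / (2 * s_pp))

    # E k(y, x_i) with y ~ N(theta, sigma2), averaged over the observations
    s_px = gamma2 + sigma2
    k_ix = jnp.sqrt(gamma2 / s_px) * jnp.exp(-((theta_i - x) ** 2) / (2 * s_px))
    k_jx = jnp.sqrt(gamma2 / s_px) * jnp.exp(-((theta_j - x) ** 2) / (2 * s_px))

    return jnp.squeeze(k_ij) - k_ix.mean() - k_jx.mean()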

Regression with the MMD

When considering regression models, the kernel should take into account both the covariates and the target variable. Consider a regression setting where the data come as pairs \(\{X_i, y_i\}_{i=1}^n\). In this case, the predictive score takes the form of the efficient estimator proposed by Alquier and Gerber (2024, Equation 6)

\[ \mathsf{L}(\theta_1, \theta_2;\,X_{1:n}, y_{1:n}) = \sum_{i=1}^n \ell (\theta_1, \theta_2;\,X_i,y_i) \]

by slightly overloading notation, where now

\[ \ell(\theta_1, \theta_2;\,X,y) = \mathsf{E}_{\substack{y^{(1)}\sim P_{\theta_1}(\cdot\mid X)\\ y^{(2)}\sim P_{\theta_2}(\cdot\mid X)}}\left[\kappa\{y^{(1)}, y^{(2)}\} - \kappa\{y^{(1)}, y\} - \kappa\{y^{(2)}, y\}\right]. \]

As before, each of these expectations can be approximated by Monte Carlo:

\[ \ell(\theta_1, \theta_2;\,X, y) \approx \frac{1}{N^2}\sum_{j=1}^N\sum_{\ell=1}^N \kappa\{y^{(1)}_j, y^{(2)}_\ell\} - \frac{1}{N}\sum_{j = 1}^N \kappa\{y^{(1)}_j, y\} - \frac{1}{N}\sum_{\ell = 1}^N \kappa\{y^{(2)}_\ell, y\}, \]

for \(y_j^{(1)} \sim P_{\theta_1}(\cdot\mid X)\) and \(y_{\ell}^{(2)}\sim P_{\theta_2}(\cdot\mid X)\).
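To make this concrete, here is a minimal sketch for a hypothetical linear-Gaussian regression model \(P_\theta(\cdot\mid X) = \mathsf{N}(X\theta, \sigma^2)\); the function name `regression_pred_score` and its defaults are illustrative rather than part of any library.

def regression_pred_score(
    theta_1, theta_2, key, X, y, sigma=1.0, num_samples=100, gamma2=1.0
):
    """Monte Carlo estimate of the regression predictive score, summed over
    the observations, for a linear-Gaussian model."""

    # split the key for the two parameter values
    key_1, key_2 = random.split(key)
    n = X.shape[0]

    # conditional means and predictive draws under each parameter
    mu_1, mu_2 = X @ theta_1, X @ theta_2
    y_1 = mu_1[None, :] + sigma * random.normal(key_1, (num_samples, n))
    y_2 = mu_2[None, :] + sigma * random.normal(key_2, (num_samples, n))

    # Gaussian kernel on scalars
    kappa = lambda a, b: jnp.exp(-((a - b) ** 2) / (2 * gamma2))

    # per-observation expected kernels, estimated by Monte Carlo
    k_12 = kappa(y_1[:, None, :], y_2[None, :, :]).mean(axis=(0, 1))
    k_1y = kappa(y_1, y[None, :]).mean(axis=0)
    k_2y = kappa(y_2, y[None, :]).mean(axis=0)

    # sum the per-observation losses over the data
    return jnp.sum(k_12 - k_1y - k_2y)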

Computation

The objective of this section is to give a whistle-stop tour of the computational machinery most immediately relevant to the PrO posterior, exposing its constituent parts rather than explaining their complete derivation. For a more complete overview of Wasserstein gradient flows, one could look at the work of Wild et al. (2023), McLatchie et al. (2025), Shen et al. (2024), and references therein.

We begin by re-writing the objective of Equation 3 as

\[ \mathcal{F}_n(Q) = \underbrace{\frac{\lambda_n}{n}\sum_{i=1}^n S(P_Q, x_i)}_{\lambda_n\mathcal{E}_n(Q)} - \int \log \pi(\theta) \,\mathrm{d}Q(\theta) + \int \log q(\theta) \,\mathrm{d}Q(\theta). \]

In order to construct a sampling algorithm for PrO posteriors, we look to derive a continuous process where particles evolve in the direction of steepest descent of the map \(\theta\mapsto\mathcal{F}_n(Q)(\theta)\). This is precisely achieved by evolving the distribution \(Q\) along its Wasserstein gradient

\[ \nabla_{\mathrm{W}}\mathcal{F}_n(Q)(\theta) = \lambda_n\nabla_{\mathrm{W}}\mathcal{E}_n(Q)(\theta) - \nabla_\theta \log \pi(\theta) + \nabla_\theta \log q(\theta). \]

Considering at this stage the mean-field representation of this process gives rise to the stochastic differential equation, for \(\theta_t\sim Q_t\),

\[ \mathrm{d}\theta_t = -\left\{\lambda_n\nabla_{\mathrm{W}}\mathcal{E}_n(Q_t)(\theta_t) - \nabla_\theta\log\pi(\theta_t) \right\}\mathrm{d}t + \sqrt{2}\,\mathrm{d}B_t, \tag{6}\]

where \((B_t)_{t \ge 0}\) is a standard Brownian motion. This process still depends on the inaccessible distribution \(Q_t\) at each step, and so to allow for computation we approximate it with the discrete distribution of \(p\) interacting particles, \(\widehat{Q}_t = p^{-1}\sum_{j=1}^p \delta_{\{\theta_t^{(j)}\}}\), where \(\delta_{\{x\}}\) denotes the Dirac measure at \(x\). Given this approximation, we evolve each particle \(j = 1,\ldots,p\) at time \(t\) according to

\[ \mathrm{d}\theta_t^{(j)} = -\left[\lambda_n\nabla_{\mathrm{W}}\mathcal{E}_n(\widehat{Q}_t)\{\theta_t^{(j)}\} - \nabla_\theta\log\pi\{\theta_t^{(j)}\} \right]\mathrm{d}t + \sqrt{2}\,\mathrm{d}B_t^{(j)}. \]

It is precisely this Wasserstein gradient flow we use to sample, asymptotically exactly, from the PrO posterior defined in Equation 3.
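In practice the flow is discretised in time; a standard Euler-Maruyama step with step size \(\eta\) (the `dt` appearing in the code below) reads

\[ \theta_{t+1}^{(j)} = \theta_t^{(j)} - \eta\left[\lambda_n\nabla_{\mathrm{W}}\mathcal{E}_n(\widehat{Q}_t)\{\theta_t^{(j)}\} - \nabla_\theta\log\pi\{\theta_t^{(j)}\}\right] + \sqrt{2\eta}\,\xi_t^{(j)}, \qquad \xi_t^{(j)}\sim\mathsf{N}(0, I). \]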

Running example: implementing the Wasserstein gradient flow

Recalling from Section 2.1 that \(\mathsf{L}_{\operatorname{MMD}}(\theta_1,\theta_2;x_{1:n}) = \langle\mu(P_{\theta_1}) - \mu(\delta_{\{x_{1:n}\}}), \mu(P_{\theta_2}) - \mu(\delta_{\{x_{1:n}\}})\rangle_{\mathcal{H}}\), we can write the Wasserstein gradient of the free energy term in Equation 6 as

\[ \nabla_{\mathrm{W}}\mathcal{E}_n(Q_t)(\theta_t) = \frac{1}{n}\sum_{i=1}^n \int \nabla_1\mathsf{L}_{\operatorname{MMD}}(\theta_t,\vartheta;x_i) \,\mathrm{d} Q_t(\vartheta), \]

where \(\nabla_1\) denotes the Euclidean gradient with respect to the first argument, \(\theta_t\). Under the finite-particle approximation \(\widehat{Q}_t\) described above, this now takes the form

\[ \nabla_{\mathrm{W}}\mathcal{E}_n(\widehat{Q}_t)\{\theta^{(j)}_t\} = \frac{1}{n}\sum_{i=1}^n \frac{1}{(p - 1)}\sum_{\ell\ne j} \nabla_1\mathsf{L}_{\operatorname{MMD}}\{\theta_t^{(j)},\theta_t^{(\ell)};x_i\}. \]

As such, swapping the order of integration by Tonelli-Fubini, the particle evolutions are governed by

\[ \mathrm{d} \theta_t^{(j)} = -\left[\frac{ 1}{(p - 1)}\sum_{\substack{\ell \ne j}} \frac{ \lambda_n}{n} \sum_{i=1}^n \nabla_1\mathsf{L}_{\operatorname{MMD}}\{\theta_t^{(j)}, \theta_t^{(\ell)}; x_i\} - \nabla_\theta \log \pi\{\theta_t^{(j)}\}\right] \,\mathrm{d} t + \sqrt{2}\,\mathrm{d} B_t^{(j)}. \]

This particle system is readily implemented in Python as follows.

def compute_drift_field(
    particles, data, key, lambda_n, sigma=1.0, num_mmd2_samples=100, gamma2=1.0
):
    """Compute the particle drift field."""

    # label the particles
    num_particles, _ = particles.shape
    idxs = jnp.arange(num_particles)
    
    # split the key
    particle_keys = random.split(key, num_particles)

    # gradient of predictive loss w.r.t. first argument
    grad_pred_score = jax.grad(pred_score, argnums=0)


    def compute_drift_j(j, theta_j):
        """
        Compute drift for a single particle based on symmetric interactions.
        """

        # masked average over l ≠ j
        grads = vmap(
          lambda theta_l: grad_pred_score(
              theta_j, theta_l, particle_keys[j], data,
              sigma, num_mmd2_samples, gamma2,
          )
        )(particles)
        mask = (idxs != j).astype(jnp.float32)[:, None]
        masked_grads = grads * mask
        interaction_term = masked_grads.sum(axis=0) / mask.sum()
        
        # gradient of the log prior
        grad_log_prior = jax.grad(log_prior, argnums=0)

        # prior influence
        prior_term = grad_log_prior(theta_j)

        return -(lambda_n * interaction_term - prior_term)

    # vectorised drift computation
    drifts = vmap(compute_drift_j, in_axes=(0, 0))(idxs, particles)
    return drifts
      

def flow_kernel(drift_field_fn, particles, data, key, dt, lambda_n):
    """Wasserstein gradient flow kernel."""

    # extract the dimensions
    num_particles, dim = particles.shape
    
    # split the key
    key_noise, key_drift = random.split(key)

    # simulate Brownian noise
    noise = jnp.sqrt(2 * dt) * random.normal(key_noise, (num_particles, dim))

    # vectorised drift computation
    drifts = drift_field_fn(particles, data, key_drift, lambda_n)

    # Euler-Maruyama update
    return particles + dt * drifts + noise

This results in the following particle trajectories and posterior distribution for our running univariate Gaussian example.

Show code
def step_fn(particles, key):
    """Evolve the particles one step ahead."""
    
    new_particles = flow_kernel(compute_drift_field, particles, data.y.reshape(-1, 1), key, dt, lambda_n)
    return new_particles, new_particles

# generate the random keys
num_samples = 1_000
keys = random.split(key, num_samples)

# initialise particles from the prior
num_particles = 20
init_particles = random.normal(key, shape=(num_particles, 1))

# define the step size
dt = 1e-3

# define the learning rate
lambda_n = 1e3

# sample from the flow
_, trajectory = jax.lax.scan(step_fn, init_particles, keys)

# initialise the plot
fig, ax = plt.subplots(1, 2, figsize=(6, 4), gridspec_kw={'width_ratios':[4, 1], 'wspace':0.05})
ax_main, ax_kde = ax

# plot the WGF particle trajectories
for i in range(trajectory.shape[1]):
    ax_main.plot(np.arange(trajectory.shape[0]), trajectory[:, i], color="#4477AA", alpha=0.4)    
        
# titles etc.
ax_main.set_xlabel(r"Time step, $t$")
ax_main.set_ylabel(r"$\theta_t^{(j)}$")
ax_main.set_title("Particle trajectories")
ax_main.set_yticks([-4, -2, 0, 2, 4])
ax_main.yaxis.set_ticks_position('left')
ax_main.tick_params(axis='y', which='both', direction='inout', labelleft=True)
ax_main.set_ylim(-5, 5);

# plot the empirical posterior distributions
sns.kdeplot(y=trajectory[-1, :].reshape(-1), ax=ax_kde, color="#4477AA", fill=False, linewidth=3, bw_adjust=0.6)

# titles etc.
ax_kde.set_title("PrO", color="#4477AA", fontsize=12)
ax_kde.set_ylabel("")
ax_kde.set_yticklabels([]) 
ax_kde.tick_params(axis='y', which='both', left=False, right=False) 
ax_kde.set_ylim(-5, 5);

Figure 3: Wasserstein gradient flow particle trajectories and posterior density.

We note only that the PrO posterior recovers precisely the predictively-optimal mixture over \(\Theta\), whose induced predictive captures the truly bimodal data-generating process responsible for Figure 2.

References

Aitchison, J. 1975. “Goodness of Prediction Fit.” Biometrika 62 (3): 547–54. https://doi.org/10.1093/biomet/62.3.547.
Alquier, P, and M Gerber. 2024. “Universal Robust Regression via Maximum Mean Discrepancy.” Biometrika 111 (1): 71–92. https://doi.org/10.1093/biomet/asad031.
Bissiri, Pier Giovanni, Chris Holmes, and Stephen Walker. 2016. “A General Framework for Updating Belief Distributions.” Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78 (5): 1103–30. https://doi.org/10.1111/rssb.12158.
Chérief-Abdellatif, Badr-Eddine, and Pierre Alquier. 2019. MMD-Bayes: Robust Bayesian Estimation via Maximum Mean Discrepancy.” arXiv. https://doi.org/10.48550/ARXIV.1909.13339.
Grünwald, Peter, and Thijs van Ommen. 2017. “Inconsistency of Bayesian Inference for Misspecified Linear Models, and a Proposal for Repairing It.” Bayesian Analysis 12 (4): 1069–1103. https://doi.org/10.1214/17-BA1085.
Knoblauch, Jeremias, Jack Jewson, and Theodoros Damoulas. 2022. “An Optimization-Centric View on Bayes’ Rule: Reviewing and Generalizing Variational Inference.” Journal of Machine Learning Research 23 (132): 1–109. http://jmlr.org/papers/v23/19-1047.html.
Martin, Ryan, and Nicholas Syring. 2022. “Direct Gibbs Posterior Inference on Risk Minimizers: Construction, Concentration, and Calibration.” In Handbook of Statistics, 47:1–41. Elsevier. https://doi.org/10.1016/bs.host.2022.06.004.
Masegosa, Andres. 2020. “Learning Under Model Misspecification: Applications to Variational and Ensemble Methods.” In Advances in Neural Information Processing Systems, edited by H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, 33:5479–91. Curran Associates, Inc. https://proceedings.neurips.cc/paper_files/paper/2020/file/3ac48664b7886cf4e4ab4aba7e6b6bc9-Paper.pdf.
McLatchie, Yann, Badr-Eddine Cherief-Abdellatif, David T. Frazier, and Jeremias Knoblauch. 2025. “Predictively Oriented Posteriors.” arXiv. https://doi.org/10.48550/arXiv.2510.01915.
Morningstar, Warren R., Alex Alemi, and Joshua V. Dillon. 2022. “PACm-Bayes: Narrowing the Empirical Risk Gap in the Misspecified Bayesian Regime.” In Proceedings of the 25th International Conference on Artificial Intelligence and Statistics, edited by Gustau Camps-Valls, Francisco J. R. Ruiz, and Isabel Valera, 151:8270–98. Proceedings of Machine Learning Research. PMLR. https://proceedings.mlr.press/v151/morningstar22a.html.
Shen, Zheyang, Jeremias Knoblauch, Sam Power, and Chris J. Oates. 2024. “Prediction-Centric Uncertainty Quantification via MMD.” arXiv. http://arxiv.org/abs/2410.11637.
Wild, Veit David, Sahra Ghalebikesabi, Dino Sejdinovic, and Jeremias Knoblauch. 2023. “A Rigorous Link Between Deep Ensembles and (Variational) Bayesian Methods.” In Advances in Neural Information Processing Systems, edited by A. Oh, T. Naumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine, 36:39782–811. Curran Associates, Inc. https://proceedings.neurips.cc/paper_files/paper/2023/file/7d25b1db211d99d5750ec45d65fd6e4e-Paper-Conference.pdf.

Citation

BibTeX citation:
@online{mclatchie2026,
  author = {Yann McLatchie},
  title = {A Hands-on Introduction to Predictively Oriented Posteriors},
  date = {2026-01-15},
  url = {https://yannmclatchie.github.io/blog/posts/pro-tutorial},
  langid = {en}
}
For attribution, please cite this work as:
Yann McLatchie. 2026. “A Hands-on Introduction to Predictively Oriented Posteriors.” January 15, 2026. https://yannmclatchie.github.io/blog/posts/pro-tutorial.