A hands-on introduction to predictively oriented posteriors
A self-contained introduction to predictively oriented posteriors, their guarantees, and their computation.
Predictively oriented posteriors
Optimisation
PAC Bayes
Author
Yann McLatchie
Published
January 15, 2026
Set up
# dependencies
from collections import namedtuple

import jax
import jax.numpy as jnp
from jax import random, vmap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# experiment settings
N = 1_000
sigma2 = 1.0
sigma = np.sqrt(sigma2)
lambda_n = 1 / N
num_particles = 2**5
num_mmd2_samples = 2**7
SEED = 1234
key = random.PRNGKey(SEED)
Introduction
Notation
Throughout, we consider data \(x_{1:n}\) where \(x_i\in\mathcal{X}\) for \(i=1,\dots,n\), which we assume to be drawn from an unknown distribution \(P_0\). Let \(P_\Theta = \{P_\theta:\,\theta\in\Theta\}\) denote a class of models. For \(Q\in\mathcal{P}(\Theta)\) the space of distributions over \(\Theta\), let \(P_Q=\int P_\theta \,\mathrm{d} Q(\theta)\) denote the predictive distribution induced by \(Q\). Lastly, we write \(Q^{\otimes m}:=\bigotimes_{j=1}^{m}Q\) to denote the product measure over \(m\) replications of \(\theta\), each with marginal measure \(Q(\vartheta)\), and similarly \(\Pi^{\otimes m}\) the \(m\)-product measure with marginal measure \(\Pi\).
We write \(S:\mathcal{P}(\mathcal{X}) \times \mathcal{X} \to \mathbb{R}\cup\{\infty\}\) to denote a proper scoring rule: the expected score \(\mathcal{S}(P, P') := \mathsf{E}_{X\sim P'}\left\{S(P,X)\right\}\) is finite for all \(P,P'\in\mathcal{P}\); and, \(\mathcal{S}(P',P') \le \mathcal{S}(P,P')\) for all \(P,P' \in \mathcal{P}\), so that \(\mathcal{S}\) induces the statistical divergence \(\mathcal{D}_{S}(P,P') := \mathcal{S}(P,P') - \mathcal{S}(P',P') \geq 0\) for which \(\mathcal{D}_{S}(P,P') = 0 \Longleftrightarrow P = P'\). Given a target distribution \(P_0\), a candidate predictive distribution \(P\) is judged to be more accurate than an alternative candidate \(P'\) by the expected score \(\mathcal{S}\) if and only if \(\mathcal{S}(P, P_0) \leq \mathcal{S}(P', P_0)\).
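For concreteness (a standard example, included here rather than drawn from the text above): under the logarithmic score \(S(P, x) = -\log p(x)\), the induced divergence is the Kullback–Leibler divergence,
\[
\mathcal{D}_{S}(P, P') = \mathsf{E}_{X\sim P'}\left\{\log \frac{p'(X)}{p(X)}\right\} = \mathrm{KL}(P' \,\|\, P) \geq 0.
\]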
Motivation
When \(P_{\theta}\) has density \(p_{\theta}\) and given a prior \(\pi\in\mathcal{P}(\Theta)\), the Bayes posterior \(\Pi_n\) and its associated posterior predictive distribution are defined as
\[
\Pi_n(\mathrm{d}\theta) \propto \left\{\prod_{i=1}^{n} p_\theta(x_i)\right\} \pi(\mathrm{d}\theta), \qquad
P_{\Pi_n} = \int_\Theta P_\theta \,\mathrm{d}\Pi_n(\theta). \tag{1}
\]
In this set-up, the predictive distribution \(P_{\Pi_n}\) is derived as a consequence of the Bayes posterior \(\Pi_n\). Yet much of the motivation for Bayesian inference is the ability to integrate over parameter uncertainty when making predictions. And while \(P_{\Pi_n}\) is predictively optimal if \(P_0 \in P_\Theta\) (Aitchison 1975), this optimality breaks down once \(P_0 \notin P_\Theta\). Even in this case, however, as \(n\to\infty\) the posterior distribution will collapse around the KL-minimising parameter value, \(\Pi_n(\mathrm{d}\theta) \to \delta_{\theta^\star_{\mathrm{KL}}}\). And when the model is misspecified, there are no guarantees that this point predictive \(P_{\theta^\star_{\mathrm{KL}}}\) will deliver calibrated or sharp predictions.
Though this is well understood, robust Bayesian methods usually still treat prediction as a second-order concern, and are more concerned with parameter uncertainty. One family emblematic of this kind of research is that of Gibbs posteriors (Knoblauch, Jewson, and Damoulas 2022; Bissiri, Holmes, and Walker 2016), which are often based on a scoring rule \(S\) and a so-called learning rate \(\lambda_n\):
\[
\Pi_n^{S} \in \operatorname*{arg\,min}_{Q \in \mathcal{P}(\Theta)} \left\{ \frac{\lambda_n}{n} \sum_{i=1}^{n} \mathsf{E}_{\theta\sim Q}\left\{ S(P_\theta, x_i) \right\} + \mathrm{KL}(Q \,\|\, \pi) \right\}. \tag{2}
\]
This formulation makes clear that we are targeting an expected loss with respect to the posterior: an average-case problem. For clarity, we note that the solution of this optimisation program under the choice of \(\lambda_n = n\) and \(S(P_\theta, x) = -\log p_\theta(x)\) is precisely the Bayes posterior of Equation 1. Under very mild conditions (Martin and Syring 2022), the posterior will again collapse around a single value \(\theta^\star_{S}\), reducing the predictive to a singleton \(P_{\theta^\star_S}\), and again if the model is misspecified, then independently of the choice of \(S\) there are no guarantees of \(P_{\theta^\star_S}\) being calibrated or sharp. Indeed, the opposite is commonly observed in practice.
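To see this reduction (a standard argument, sketched here for completeness), substitute \(\lambda_n = n\) and the negative log-likelihood into the objective above; the Gibbs variational principle then identifies its unique minimiser:
\[
\operatorname*{arg\,min}_{Q \in \mathcal{P}(\Theta)} \left\{ \mathsf{E}_{\theta\sim Q}\left\{\sum_{i=1}^{n} -\log p_\theta(x_i)\right\} + \mathrm{KL}(Q \,\|\, \pi) \right\} = \Pi_n,
\]
since the objective equals \(\mathrm{KL}(Q \,\|\, \Pi_n)\) up to an additive constant that does not depend on \(Q\).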
To resolve this philosophical tension, we reverse that order: the posterior predictive \(P_{Q} := \int_\Theta P_\theta \,\mathrm{d} Q(\theta)\) is the object of interest, and we obtain a posterior distribution over \(\Theta\) as the measure \(Q_n\) which quantifies uncertainty about \(\theta\) in a manner that is optimal for prediction. By inverting the order of integration, we lift the parametric statistical model to a mixing distribution, \(P_Q = \int P_\theta\,\mathrm{d}Q(\theta)\):
\[
Q_n \in \operatorname*{arg\,min}_{Q \in \mathcal{P}(\Theta)} \left\{ \frac{\lambda_n}{n} \sum_{i=1}^{n} S(P_Q, x_i) + \mathrm{KL}(Q \,\|\, \pi) \right\}. \tag{3}
\]
Throughout, we refer to \(Q_n\) as the predictively-oriented (PrO) posterior.
Theory
The fundamental difference between the asymptotic behaviour of PrO posteriors and that of standard Gibbs posteriors is that, while the latter will concentrate onto a point mass under very light assumptions, the former will only concentrate onto a point mass when the induced predictive is exactly the true data-generating process. And when the model is misspecified, \(Q_n\) will adapt to the nature of the misspecification to find the predictively-optimal mixing distribution. This can manifest itself by identifying multi-modal posterior distributions which, once integrated over, lift the asymptotic predictive distribution into the convex hull of \(P_\Theta\): there is no longer the requirement that we reduce to a singleton in \(P_\Theta\) if a mixture of predictives can outperform it. We refer to the difference between \(P_\Theta\) and its convex hull (coloured in grey in the figure below) as the mixability gap, in line with Grünwald and van Ommen (2017).
Figure 1: The posterior and posterior predictive distributions induced by different belief updates. One panel shows the parameter space, \(\Theta\); the other shows the predictive space, \(P_\Theta\), together with the true data-generating process \(P_0\).
Under some regularity conditions on the map \(\theta \mapsto S(P_\theta, x)\), in particular that it is convex in its first argument, and some standard prior mass conditions, McLatchie et al. (2025) prove that the PrO posterior induces predictive performance at least as good as the generalised Bayes posterior fit under the same score, as measured by that score. And as the mixability gap grows, this predictive difference grows too: in the scenario graphically depicted above, for \(n\) sufficiently large but finite, the PrO posterior induces a provably better predictive distribution than the Gibbs posterior:
\[
\mathsf{E}_0\left\{\mathcal{S}\left(P_{Q_n}, P_0\right)\right\} < \mathsf{E}_0\left\{\mathcal{S}\left(P_{\Pi_n^{S}}, P_0\right)\right\},
\]
where \(\mathsf{E}_0\) denotes the expectation over data with respect to \(P_0\). Those authors also formally show that the PrO posterior \(Q_n\) will concentrate around a singleton only under model well-specification and otherwise concentrates around the predictively-optimal mixture.
Running example: a normal location model
def generate_gauss_mix_data(
    n, sigma, epsilon=0.8, theta_0=-2, theta_1=2, mu_0=0, sigma_0=1, seed=None
):
    """Helper function to sample from a mixture of Gaussians."""
    if seed is not None:
        np.random.seed(seed)
    # generate indicator z
    z = np.random.binomial(n=1, p=epsilon, size=n)
    # compute mu depending on z
    if isinstance(theta_0, (int, float)) and isinstance(theta_1, (int, float)):
        mu = np.where(z == 0, theta_0, theta_1)
    else:
        raise ValueError
    # generate response variable y
    y = np.random.normal(loc=mu, scale=sigma)
    # combine into a list
    data = {"z": z, "mu": mu, "y": y, "sigma": sigma, "mu_0": mu_0, "sigma_0": sigma_0}
    data_tuple = namedtuple("x", data.keys())(*data.values())
    return data_tuple


# simulate from a mixture of Gaussians
data = generate_gauss_mix_data(n=N, sigma=sigma, seed=SEED)

# plot the empirical data distribution
fig, ax = plt.subplots(figsize=(5, 3))
sns.histplot(data.y, stat='density', color='lightgray', edgecolor='black', ax=ax)
plt.show()
Figure 2: Simulated data distribution.
Suppose we observe the univariate continuous data \(x_{i} \in \mathbb{R}\) for \(i = 1,\ldots,n\) for \(n = 1,000\) shown in Figure 2. We model these data as arising from a Gaussian distribution with fixed standard deviation \(\sigma > 0\) and wish to perform inference on the mean: \(P_\Theta = \{\mathsf{N}(\theta,\,\sigma^2):\,\theta \in \mathbb{R}\}\). We proceed by specifying a prior, say \(\pi(\theta) = \mathsf{N}(\mu_0, \sigma_0^2)\), which for the purposes of exposition we can assume would remain unchanged regardless of whether we were fitting a standard Bayesian posterior or a PrO posterior.
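The particle-flow code later in the post calls a `log_prior` helper that is not defined in any of the listings. A minimal sketch consistent with the \(\mathsf{N}(\mu_0, \sigma_0^2)\) prior above (the name and signature are assumptions, and the normalising constant is dropped since only the gradient is used downstream) is the following.

def log_prior(theta, mu_0=0.0, sigma_0=1.0):
    """Log-density of the Gaussian prior N(mu_0, sigma_0^2), up to an additive constant.

    NOTE: assumed helper for the drift-field code below; only its gradient
    matters there, so the normalising constant is omitted.
    """
    return -0.5 * jnp.sum((theta - mu_0) ** 2) / sigma_0**2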
The fundamental consideration for the practical application of PrO posteriors is the calculation of the predictive scoring rule \(S(P_Q, x)\) in Equation 3. In particular, the objective differs from that of Equation 2 insofar as it is non-linear in \(Q\), and many of the tools for theoretical analysis and computation rely on this linearity. To rectify this, we leverage the convexity of the scoring rule from Section 1.3 to apply Jensen’s inequality so that
\[
S(P_Q, x) = \mathsf{E}_{\theta\sim Q}\left\{ S(P_\theta, x) \right\} - \Delta(Q, x), \tag{4}
\]
for some non-negative remainder term \(\Delta:\mathcal{P}(\Theta) \times \mathcal{X} \to \mathbb{R}_{+}\) which we assume to be finite. While this remainder term need not simplify in general, McLatchie et al. (2025) show that in some cases it can be written as an expectation over an \(m\)-product measure in \(\Theta\), so that
\[
S(P_Q, x) = \mathsf{E}_{\theta_{1:m}\sim Q^{\otimes m}}\left\{ \delta(\theta_{1:m}; x) \right\}, \tag{5}
\]
where now the predictive score is linear in \(Q\), albeit on an extended space, and where \(\delta(\theta_{1:m};x)\) can be computed pointwise for all \(x\in\mathcal{X}\). In a word, this objective is now computationally feasible.
Each choice of predictive score induces a different \(\Delta\), and with it also a different dimensionality \(m\). For instance, Masegosa (2020) attains an upper bound on \(\Delta\) in the case of the log score under the assumption that \(\sup_\theta p_\theta(x) < \infty\) for all \(x\in\mathcal{X}\), resulting in an objective with \(m = 2\). Elsewhere, Morningstar, Alemi, and Dillon (2022) approximate \(P_Q\) with \(m\) Monte Carlo draws, thereby treating \(m\) as a hyper-parameter where the quality of the approximation improves with \(m\).
An important class of scoring rules which admits the representation in Equation 5 without any such approximation is that of kernel scoring rules, for which the representation is always available with \(m = 2\) (McLatchie et al. 2025, Lemma 2). Given their simplicity, for exposition purposes, we will continue considering only kernel scoring rules.
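To see why kernel scores always admit an \(m = 2\) representation (a standard mean-embedding computation, spelled out here for completeness): the mean embedding of the mixture is the mixture of the mean embeddings, so the squared kernel discrepancy between \(P_Q\) and an observation is a bilinear functional of \(Q\),
\[
\left\lVert \mu(P_Q) - \mu(\delta_{\{x\}}) \right\rVert_{\mathcal{H}}^2
= \mathsf{E}_{\theta_{1:2}\sim Q^{\otimes 2}}\left\{ \left\langle \mu(P_{\theta_1}) - \mu(\delta_{\{x\}}),\, \mu(P_{\theta_2}) - \mu(\delta_{\{x\}}) \right\rangle_{\mathcal{H}} \right\},
\]
which is exactly of the form of Equation 5 with \(m = 2\).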
We presently ground the above in our running example by considering the squared maximum mean discrepancy (MMD) as the predictive scoring rule. In this case, Shen et al. (2024) show that
\[\begin{multline}
\delta(\theta_1, \theta_2; x_{1:n}) = \mathsf{L}_{\mathrm{MMD}}(\theta_1, \theta_2; x_{1:n}) := \left\langle \mu(P_{\theta_1}) - \mu(\delta_{\{x_{1:n}\}}),\, \mu(P_{\theta_2}) - \mu(\delta_{\{x_{1:n}\}}) \right\rangle_{\mathcal{H}} \\
= \left\langle \mu(P_{\theta_1}), \mu(P_{\theta_2}) \right\rangle_{\mathcal{H}} - \left\langle \mu(P_{\theta_1}), \mu(\delta_{\{x_{1:n}\}}) \right\rangle_{\mathcal{H}} - \left\langle \mu(P_{\theta_2}), \mu(\delta_{\{x_{1:n}\}}) \right\rangle_{\mathcal{H}} + C,
\end{multline}\]
where \(C\) is a constant which does not depend on \(\theta_1\) or \(\theta_2\). Each of these inner products is precisely an expected kernel: \(\langle \mu(P_{\theta_1}), \mu(P_{\theta_2}) \rangle_{\mathcal{H}} = \mathsf{E}_{y^{(1)}\sim P_{\theta_1}}\mathsf{E}_{y^{(2)}\sim P_{\theta_2}}\{\kappa(y^{(1)}, y^{(2)})\}\). As such, they can be empirically estimated via Monte Carlo with \(N\) samples from each of \(P_{\theta_1}\) and \(P_{\theta_2}\), leading (up to the constant \(C\)) to \[\begin{multline}
\mathsf{L}_{\mathrm{MMD}}(\theta_1,\theta_2;x_{1:n}) \approx \frac{1}{{N}^2} \sum_{j = 1}^N\sum_{\ell = 1}^N \kappa\{y^{(1)}_j, y^{(2)}_\ell\}\\ - \frac{1}{{N}\cdot n} \sum_{j = 1}^N\sum_{i = 1}^n \kappa\{y^{(1)}_j, x_i\} - \frac{1}{{N}\cdot n} \sum_{\ell = 1}^N\sum_{i = 1}^n \kappa\{y^{(2)}_\ell, x_i\},
\end{multline}\] where \(y^{(1)}_{1:N}\sim P_{\theta_1}\) and \(y^{(2)}_{1:N}\sim P_{\theta_2}\). In Python, this is readily implemented as follows.
def pred_score(
    theta_i, theta_j, key, data, sigma=1.0, num_mmd2_samples=100, gamma2=1.0
):
    """
    Compute the predictive score for the normal location model under the MMD.
    """
    # split the key for reproducibility
    subkey1, subkey2 = random.split(key)
    # sample from the predictive at `theta_i` and `theta_j`
    ftheta_i_samples = (
        theta_i + random.normal(subkey1, shape=(num_mmd2_samples, 1)) * sigma
    )
    ftheta_j_samples = (
        theta_j + random.normal(subkey2, shape=(num_mmd2_samples, 1)) * sigma
    )
    # compute the Gram matrices
    K_ij = gaussian_kernel(ftheta_i_samples, ftheta_j_samples, gamma2=gamma2)
    K_il = gaussian_kernel(ftheta_i_samples, data, gamma2=gamma2)
    K_jl = gaussian_kernel(ftheta_j_samples, data, gamma2=gamma2)
    # extract relevant dimensions
    assert ftheta_i_samples.shape[0] == ftheta_j_samples.shape[0]
    n = data.shape[0]
    # compute the component terms of the tandem predictive MMD2 and return
    sum_K_ij = (jnp.sum(K_ij)) / (num_mmd2_samples**2)
    sum_K_il = (jnp.sum(K_il)) / (num_mmd2_samples * n)
    sum_K_jl = (jnp.sum(K_jl)) / (num_mmd2_samples * n)
    return sum_K_ij - sum_K_il - sum_K_jl
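The listing above also relies on a `gaussian_kernel` helper that is not shown. A minimal sketch (the exact bandwidth parameterisation via `gamma2` is an assumption) computing the Gram matrix between two sets of samples is the following.

def gaussian_kernel(x, y, gamma2=1.0):
    """Gaussian (RBF) Gram matrix between the rows of `x` and `y`.

    NOTE: assumed helper; `gamma2` is treated as a squared length-scale,
    i.e. k(x, y) = exp(-||x - y||^2 / (2 * gamma2)).
    """
    # pairwise squared Euclidean distances, shape (x.shape[0], y.shape[0])
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * gamma2))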
In the specific running example, under the Gaussian kernel the squared MMD can in fact be computed in closed form (see Chérief-Abdellatif and Alquier 2019, Appendix D). And for large values of \(n\) and \(N\), one could consider employing a more lightweight approximation to the expected kernel for computational reasons. For the purposes of this exposition, we treat only the most naïve, general approach.
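For reference, the following is a minimal sketch of such a closed-form evaluation for this particular normal location model, assuming the Gaussian kernel \(\kappa(y, y') = \exp\{-(y - y')^2/(2\gamma^2)\}\) (the same bandwidth convention as the `gaussian_kernel` sketch above) and using the standard identity \(\mathsf{E}\{\kappa(Y, Y')\} = \sqrt{\gamma^2/(\gamma^2 + s^2)}\,\exp\{-m^2/(2(\gamma^2 + s^2))\}\) when \(Y - Y' \sim \mathsf{N}(m, s^2)\); the function name is hypothetical and this is not the implementation used in the post.

def pred_score_closed_form(theta_i, theta_j, data, sigma=1.0, gamma2=1.0):
    """Closed-form analogue of `pred_score` for the normal location model.

    NOTE: a sketch under the Gaussian-kernel convention k(y, y') =
    exp(-(y - y')^2 / (2 * gamma2)); expected kernels are computed exactly
    rather than by Monte Carlo.
    """
    x = data.reshape(-1)
    # expected kernel between the two predictives: Y1 - Y2 ~ N(theta_i - theta_j, 2 sigma^2)
    s2_pred = gamma2 + 2.0 * sigma**2
    k_ij = jnp.sqrt(gamma2 / s2_pred) * jnp.exp(-((theta_i - theta_j) ** 2) / (2.0 * s2_pred))
    # expected kernel between each predictive and a data point: Y - x_i ~ N(theta - x_i, sigma^2)
    s2_data = gamma2 + sigma**2
    k_il = jnp.sqrt(gamma2 / s2_data) * jnp.exp(-((theta_i - x) ** 2) / (2.0 * s2_data))
    k_jl = jnp.sqrt(gamma2 / s2_data) * jnp.exp(-((theta_j - x) ** 2) / (2.0 * s2_data))
    return jnp.squeeze(k_ij) - jnp.mean(k_il) - jnp.mean(k_jl)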
Regression with the MMD
When considering regression models, the kernel should take into account both the covariates and the target variable. Consider a regression setting where data come as pairs \(\{X_i, y_i\}_{i=1}^n\). In this case, the predictive score takes the form of the efficient estimator proposed by Alquier and Gerber (2024, Equation 6)
for \(y_j^{(1)} \sim P_{\theta_1}(\cdot\mid X)\) and \(y_{\ell}^{(2)}\sim P_{\theta_2}(\cdot\mid X)\).
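As a rough illustration only (a generic conditional construction, not claimed to reproduce Alquier and Gerber's Equation 6 exactly), one could evaluate a per-covariate MMD-style score by simulating responses at each observed \(X_i\) and comparing them to the observed \(y_i\). Here `simulate_response` is an assumed user-supplied sampler from \(P_\theta(\cdot\mid X)\), and the `gaussian_kernel` sketch above is reused.

def pred_score_regression(theta_i, theta_j, key, X, y, simulate_response,
                          num_mmd2_samples=100, gamma2=1.0):
    """Illustrative conditional MMD-style predictive score for regression.

    NOTE: `simulate_response(theta, X, key, num_samples)` is an assumed sampler
    returning draws of shape (num_samples, n, 1) from P_theta(. | X); this is a
    sketch, not necessarily the efficient estimator of Alquier and Gerber (2024).
    """
    key_i, key_j = random.split(key)
    # simulated responses at the observed covariates, shape (N, n, 1)
    y1 = simulate_response(theta_i, X, key_i, num_mmd2_samples)
    y2 = simulate_response(theta_j, X, key_j, num_mmd2_samples)
    N = num_mmd2_samples

    def per_covariate_score(y1_i, y2_i, y_i):
        # y1_i, y2_i: (N, 1) simulated responses at X_i; y_i: (1,) observation
        K_12 = gaussian_kernel(y1_i, y2_i, gamma2=gamma2)
        K_1y = gaussian_kernel(y1_i, y_i.reshape(1, 1), gamma2=gamma2)
        K_2y = gaussian_kernel(y2_i, y_i.reshape(1, 1), gamma2=gamma2)
        return jnp.sum(K_12) / N**2 - jnp.sum(K_1y) / N - jnp.sum(K_2y) / N

    # average the per-covariate scores over the n observations
    scores = vmap(per_covariate_score, in_axes=(1, 1, 0))(y1, y2, y.reshape(-1, 1))
    return jnp.mean(scores)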
Computation
The objective of this section is to give a whistle-stop tour of the computational machinery most immediate to the PrO posterior, exposing its constituent parts more than explaining their complete derivation. For a more complete overview of Wasserstein gradient flows one could look at the work of Wild et al. (2023), McLatchie et al. (2025), Shen et al. (2024), and references therein.
We begin by re-writing the objective of Equation 3 as
\[
\mathcal{F}_n(Q) := \frac{\lambda_n}{n} \sum_{i=1}^{n} \mathsf{E}_{\theta_{1:m}\sim Q^{\otimes m}}\left\{ \delta(\theta_{1:m}; x_i) \right\} + \mathrm{KL}(Q \,\|\, \pi), \qquad Q_n \in \operatorname*{arg\,min}_{Q \in \mathcal{P}(\Theta)} \mathcal{F}_n(Q). \tag{6}
\]
In order to construct a sampling algorithm for PrO posteriors, we look to derive a continuous process where particles evolve in the direction of steepest descent of the map \(\theta\mapsto\mathcal{F}_n(Q)(\theta)\). This is precisely achieved by evolving the distribution \(Q\) along its Wasserstein gradient
where \((B_t)_{t \geq 0}\) denotes a standard Brownian motion. This process still depends on the inaccessible distribution \(Q_t\) at each step, and so to allow for computation we approximate it with the discrete distribution of \(p\) interacting particles, \(\widehat{Q}_t = p^{-1}\sum_{j=1}^p \delta_{\{\theta_t^{j}\}}\), where \(\delta_{\{x\}}\) denotes the Dirac measure at \(x\). Given this approximation, we evolve each particle \(j = 1,\ldots,p\) at time \(t\) according to
It is precisely this Wasserstein gradient flow we use to sample, asymptotically exactly, from the PrO posterior defined in Equation 3.
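Before specialising to the MMD case, the discretised update implemented further below can be sketched generically: an Euler–Maruyama step of size \(\eta\) applied to each particle, with the drift supplied by an estimate of the negative Wasserstein gradient. The helper below is a schematic sketch only; `drift_fn` is an assumed placeholder for whatever gradient estimate is available.

def euler_maruyama_step(particles, drift_fn, key, step_size):
    """One generic Euler-Maruyama step for an interacting particle system.

    NOTE: `drift_fn` is an assumed placeholder returning the drift (negative
    Wasserstein-gradient estimate) at each particle; this is a schematic
    sketch of the update, not the post's implementation.
    """
    # Gaussian noise matching the particle array's shape
    noise = random.normal(key, particles.shape)
    # deterministic drift plus a sqrt(2 * step_size)-scaled Brownian increment
    return particles + step_size * drift_fn(particles) + jnp.sqrt(2.0 * step_size) * noise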
Running example: implementing the Wasserstein gradient flow
Recalling from Section 2.1 that \(\mathsf{L}_{\operatorname{MMD}}(\theta_1,\theta_2;x) = \langle\mu(P_{\theta_1}) - \mu(\delta_{\{x_{1:n}\}}), \mu(P_{\theta_2}) - \mu(\delta_{\{x_{1:n}\}})\rangle_{\mathcal{H}}\), we denote the Wasserstein gradient of the free energy term in Equation 6
where \(\nabla_1\) denotes the Euclidean gradient with respect to the first argument, \(\theta_t\). Under the finite-particle approximation \(\widehat{Q}_t\) described above, this now takes the form
This particle system is readily implemented in Python as follows.
def compute_drift_field(
    particles, data, key, sigma=1.0, num_mmd2_samples=100, gamma2=1.0
):
    """Compute the particle drift field."""
    # label the particles
    num_particles, _ = particles.shape
    idxs = jnp.arange(num_particles)
    # split the key
    particle_keys = random.split(key, num_particles)
    # gradient of predictive loss w.r.t. first argument
    grad_pred_score = jax.grad(pred_score, argnums=0)

    def compute_drift_j(j, theta_j):
        """
        Compute drift for a single particle based on symmetric interactions.
        """
        # masked average over l ≠ j
        grads = vmap(
            lambda theta_l: grad_pred_score(theta_j, theta_l, particle_keys[j], data)
        )(particles)
        mask = (idxs != j).astype(jnp.float32)[:, None]
        masked_grads = grads * mask
        interaction_term = masked_grads.sum(axis=0) / mask.sum()
        # gradient of the log prior
        grad_log_prior = jax.grad(log_prior, argnums=0)
        # prior influence
        prior_term = grad_log_prior(theta_j)
        return -(lambda_n * interaction_term - prior_term)

    # vectorised drift computation
    drifts = vmap(compute_drift_j, in_axes=(0, 0))(idxs, particles)
    return drifts


def flow_kernel(drift_field_fn, particles, data, key, dt, lambda_n):
    """Wasserstein gradient flow kernel."""
    # extract the dimensions
    num_particles, dim = particles.shape
    # split the key
    key_noise, key_drift = random.split(key)
    # simulate Brownian noise
    noise = jnp.sqrt(2 * dt) * random.normal(key_noise, (num_particles, dim))
    # vectorised drift computation
    drifts = drift_field_fn(particles, data, key_drift)
    # Euler-Maruyama update
    return particles + dt * drifts + noise
This results in the following particle trajectories and posterior distribution in our running univariate Gaussian example.
def step_fn(particles, key):
    """Evolve the particles one step ahead."""
    new_particles = flow_kernel(
        compute_drift_field, particles, data.y.reshape(-1, 1), key, dt, lambda_n
    )
    return new_particles, new_particles


# generate the random keys
num_samples = 1_000
keys = random.split(key, num_samples)

# initialise particles from the prior
num_particles = 20
init_particles = random.normal(key, shape=(num_particles, 1))

# define the step size
dt = 1e-3

# define the learning rate
lambda_n = 1e3

# sample from the flow
_, trajectory = jax.lax.scan(step_fn, init_particles, keys)

# initialise the plot
fig, ax = plt.subplots(
    1, 2, figsize=(6, 4), gridspec_kw={'width_ratios': [4, 1], 'wspace': 0.05}
)
ax_main, ax_kde = ax

# plot the WGF particle trajectories
for i in range(trajectory.shape[1]):
    ax_main.plot(
        np.arange(trajectory.shape[0]), trajectory[:, i], color="#4477AA", alpha=0.4
    )
# titles etc.
ax_main.set_xlabel(r"Time step, $t$")
ax_main.set_ylabel(r"$\theta_t^{(j)}$")
ax_main.set_title("Particle trajectories")
ax_main.set_yticks([-4, -2, 0, 2, 4])
ax_main.yaxis.set_ticks_position('left')
ax_main.tick_params(axis='y', which='both', direction='inout', labelleft=True)
ax_main.set_ylim(-5, 5)

# plot the empirical posterior distribution
sns.kdeplot(
    y=trajectory[-1, :].reshape(-1), ax=ax_kde, color="#4477AA",
    fill=False, linewidth=3, bw_adjust=0.6
)
# titles etc.
ax_kde.set_title("PrO", color="#4477AA", fontsize=12)
ax_kde.set_ylabel("")
ax_kde.set_yticklabels([])
ax_kde.tick_params(axis='y', which='both', left=False, right=False)
ax_kde.set_ylim(-5, 5)
Figure 3: Wasserstein gradient flow particle trajectories and posterior density.
We only note here that the PrO posterior recovers precisely the predictively-optimal mixture over \(\Theta\), reproducing the truly bimodal data-generating process responsible for Figure 2.
Alquier, Pierre, and Mathieu Gerber. 2024. “Universal Robust Regression via Maximum Mean Discrepancy.” Biometrika 111 (1): 71–92. https://doi.org/10.1093/biomet/asad031.
Bissiri, Pier Giovanni, Chris Holmes, and Stephen Walker. 2016. “A General Framework for Updating Belief Distributions.” Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78 (5): 1103–30. https://doi.org/10.1111/rssb.12158.
Chérief-Abdellatif, Badr-Eddine, and Pierre Alquier. 2019. “MMD-Bayes: Robust Bayesian Estimation via Maximum Mean Discrepancy.” arXiv. https://doi.org/10.48550/ARXIV.1909.13339.
Grünwald, Peter, and Thijs van Ommen. 2017. “Inconsistency of Bayesian Inference for Misspecified Linear Models, and a Proposal for Repairing It.” Bayesian Analysis 12 (4): 1069–1103. https://doi.org/10.1214/17-BA1085.
Knoblauch, Jeremias, Jack Jewson, and Theodoros Damoulas. 2022. “An Optimization-Centric View on Bayes’ Rule: Reviewing and Generalizing Variational Inference.” Journal of Machine Learning Research 23 (132): 1–109. http://jmlr.org/papers/v23/19-1047.html.
Martin, Ryan, and Nicholas Syring. 2022. “Direct Gibbs Posterior Inference on Risk Minimizers: Construction, Concentration, and Calibration.” In Handbook of Statistics, 47:1–41. Elsevier. https://doi.org/10.1016/bs.host.2022.06.004.
McLatchie, Yann, Badr-Eddine Cherief-Abdellatif, David T. Frazier, and Jeremias Knoblauch. 2025. “Predictively Oriented Posteriors.” arXiv. https://doi.org/10.48550/arXiv.2510.01915.
Morningstar, Warren R., Alex Alemi, and Joshua V. Dillon. 2022. “PACm-Bayes: Narrowing the Empirical Risk Gap in the Misspecified Bayesian Regime.” In Proceedings of the 25th International Conference on Artificial Intelligence and Statistics, edited by Gustau Camps-Valls, Francisco J. R. Ruiz, and Isabel Valera, 151:8270–98. Proceedings of Machine Learning Research. PMLR. https://proceedings.mlr.press/v151/morningstar22a.html.
Shen, Zheyang, Jeremias Knoblauch, Sam Power, and Chris J. Oates. 2024. “Prediction-Centric Uncertainty Quantification via MMD.” arXiv. http://arxiv.org/abs/2410.11637.