Generalised Bayesian inference in Stan

Stan can easily handle arbitrary loss functions.
Stan
Generalised Bayesian inference
Author

Yann McLatchie

Published

May 10, 2023

Generalising Bayesian inference

Suppose we have a parameter \(\theta \in \Theta\) that we wish to infer. Under the Bayesian paradigm, one updates one’s prior beliefs about the parameter, \(\pi(\theta)\), through some data observations \(x_{1:n}\) to arrive at a posterior belief \(p(\theta\mid x_{1:n})\). We do so after defining the likelihood \(p(x_{1:n} \mid \theta)\) and by performing the belief update

\[\begin{equation} p(\theta\mid x_{1:n}) \propto p(x_{1:n} \mid \theta) \pi(\theta). \end{equation}\]

Under this regime, however, the statistician is subject to the assumption that the likelihood is correctly specified, meaning that there exists \(\theta_0\) such that \(p(x_{1:n} \mid \theta_0) = \mathbb{P}(x_{1:n})\) exactly (Bernardo and Smith 1994).

While a misspecified model may fit some training data sufficiently well, inference under it becomes unreliable given future heterogeneous or atypical observations (Knoblauch, Jewson, and Damoulas 2019). Generalised Bayesian inference (Bissiri, Holmes, and Walker 2016) offers an extension to the standard Bayesian learning paradigm in which one performs one’s belief update in terms of an arbitrary loss function \(\ell\),

\[\begin{equation} p(\theta\mid x) \propto \exp\{-\ell(\theta, x)\} \pi(\theta). \end{equation}\]

This use of a loss function allows us to relax the assumption of a well-specified likelihood since it no longer features explicitly in the belief update.1 Naturally, setting this loss function equal to the negative log-likelihood recovers standard Bayesian inference.
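
To see this concretely, choosing the negative log-likelihood as the loss collapses the generalised update back to the standard one:

\[\begin{equation} \ell(\theta, x_{1:n}) = -\log p(x_{1:n} \mid \theta) \quad\Longrightarrow\quad \exp\{-\ell(\theta, x_{1:n})\}\,\pi(\theta) = p(x_{1:n} \mid \theta)\,\pi(\theta). \end{equation}\]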

The primary aim of this post is to demonstrate that these new ways of achieving posteriors can be easily implemented in Stan in the hope that they may find more widespread application.

We concern ourselves herein with proper scoring rule loss functions, a subset of loss functions which continue to depend on the likelihood. We do so because of their stability and computational feasibility (Jewson, Smith, and Holmes 2023), but note that other loss functions should also work with HMC.2

Hamiltonian Monte Carlo and the Stan language

Bayesian posteriors are often computed with so-called Monte Carlo methods, and in particular Hamiltonian Monte Carlo (HMC), which simulates a physical system through the Gibbs distribution (Betancourt and Girolami 2013; Neal 2011)

\[\begin{equation} p(\theta) \propto \exp\left(-\frac{U(\theta)}{T}\right) \end{equation}\]

where \(U(\theta)\) is the so-called energy of the system at state \(\theta\), and \(T\) the temperature.

We need not worry ourselves too much with the algorithmic details, since a state-of-the-art, general and efficient HMC sampler is already implemented in the Stan programming language (Carpenter et al. 2017). What is important for us is that the general Bayesian update can be written as such a Gibbs distribution in just the same way as the standard Bayesian update: we need only replace the negative log-likelihood with an arbitrary loss inside the exponential, so the very same tools are immediately available to us. In short, we can perform generalised Bayesian inference in Stan out of the box.
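
Concretely, taking the energy to be the loss plus the negative log prior and fixing the temperature to one recovers the generalised posterior:

\[\begin{equation} U(\theta) = \ell(\theta, x_{1:n}) - \log \pi(\theta), \quad T = 1 \quad\Longrightarrow\quad \exp\left(-\frac{U(\theta)}{T}\right) = \exp\{-\ell(\theta, x_{1:n})\}\,\pi(\theta). \end{equation}\]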

A motivating example: linear regression in Stan

Consider now the simple linear regression

\[\begin{equation} y \sim \textrm{normal}(X\beta,\sigma). \end{equation}\]

Suppose we have observed \(n\) realisations of the data. We are Bayesian, and capitalise on this to define some priors over the regression coefficients \(\beta\) and the residual standard deviation \(\sigma\). For simplicity, we take the regression coefficients to be independent and identically distributed according to the standard Gaussian, and \(\sigma\) to follow a truncated standard Gaussian.3 This defines the prior over our model parameters, leaving us only to choose our loss function.
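
In symbols, the prior used throughout this post is

\[\begin{equation} \beta_j \sim \textrm{normal}(0, 1) \;\; \text{independently for } j = 1, \dots, p, \qquad \sigma \sim \textrm{normal}^{+}(0, 1), \end{equation}\]

where \(\textrm{normal}^{+}(0, 1)\) denotes the standard Gaussian truncated to \(\sigma > 0\), which is what the lower=0 constraint in the Stan programs below implements.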

We presently consider three different loss functions:

  1. the negative log likelihood;
  2. the CRPS (Gneiting and Raftery 2007);
  3. the Hyvärinen score (Hyvärinen 2005).

In generalised Bayesian inference, we often multiply the loss function by some constant called the “learning rate” to ensure calibration of the posterior. Choosing this parameter is in general non-trivial, and has a direct interpretation in terms of its target (Wu and Martin 2023). For simplicity, and since this is not the primary focus of this post, we fix the learning rate to be 1.4
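
Written out explicitly, with \(\eta > 0\) denoting the learning rate (a symbol introduced here only for this display), the tempered update reads

\[\begin{equation} p(\theta\mid x_{1:n}) \propto \exp\{-\eta\,\ell(\theta, x_{1:n})\}\,\pi(\theta), \end{equation}\]

and fixing \(\eta = 1\) gives back the update used in the rest of this post.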

The standard Bayesian treatment

Begin then with the negative log-likelihood, or the standard Bayesian treatment of the inference problem. In Stan, the complete model is implemented as follows.

data {
  int<lower=0> N;
  int<lower=0> p;
  matrix[N, p] x;
  vector[N] y;
}
parameters {
  real<lower=0> sigma;
  vector[p] beta;
}
model {
  // priors
  target += normal_lpdf(sigma | 0, 1);
  for (j in 1:p) {
    target += normal_lpdf(beta[j] | 0, 1);
  }
  // likelihood
  target += normal_lpdf(y | x * beta, sigma);
}

Here, the data block defines the data input to the model, the parameters block introduces our model parameters, and the model block defines the priors and likelihood, whose log densities Stan accumulates in target to form the (unnormalised) log posterior of the model. Then fitting the model in R requires only a few lines of code.

library(simstudy)
library(cmdstanr)
library(posterior)
library(dplyr)
library(ggplot2)
SEED <- 1234
set.seed(SEED)

# define experiment vars 
N <- 1e2
p <- 2
sigma <- 1

# generate data
def <- defRepeat(nVars = p, prefix = "x", formula = "0",
                 variance = "1", dist = "normal")
def <- defData(def, "y", formula = "x1 * 1 - x2 * 0.5", 
               variance = "..sigma", dist = "normal")
dd <- genData(N, def)

# produce Stan data
y <- dd$y
x <- as.matrix(dd)[, paste0("x", 1:p)]
stan_data <- list(N = N,
                  p = p,
                  x = x,
                  y = y)

# compile and fit the stan model
exec_llk <- cmdstan_model(stan_file = "stan/linreg_llk.stan")
fit_llk <- exec_llk$sample(data = stan_data,
                           chains = 4,
                           parallel_chains = 4,
                           refresh = 0,
                           seed = SEED)
Running MCMC with 4 parallel chains...

Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.

All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.2 seconds.
# investigate parameter posteriors
summarise_draws(fit_llk$draws(), "mean", "sd", "rhat", "mcse_mean", "mcse_sd")
# A tibble: 104 × 6
   variable       mean     sd  rhat mcse_mean  mcse_sd
   <chr>         <dbl>  <dbl> <dbl>     <dbl>    <dbl>
 1 lp__       -143.    1.21    1.00  0.0266   0.0280  
 2 sigma         0.978 0.0691  1.00  0.00111  0.00103 
 3 beta[1]       1.05  0.0963  1.00  0.00143  0.00158 
 4 beta[2]      -0.403 0.0944  1.00  0.00142  0.00152 
 5 log_lik[1]   -1.03  0.0831  1.00  0.00135  0.00126 
 6 log_lik[2]   -1.18  0.0496  1.00  0.000734 0.000703
 7 log_lik[3]   -0.908 0.0709  1.00  0.00116  0.00101 
 8 log_lik[4]   -1.31  0.215   1.00  0.00334  0.00377 
 9 log_lik[5]   -0.971 0.0688  1.00  0.00107  0.000969
10 log_lik[6]   -1.17  0.0506  1.00  0.000794 0.000740
# ℹ 94 more rows

Here we show the mean and standard deviation of each parameter’s posterior, along with the \(\hat{R}\) convergence diagnostic and the Monte Carlo standard errors of the mean and standard deviation. Heuristically, we can be confident that the chains have mixed sufficiently well if the \(\hat{R}\) values are all smaller than 1.01 (Gelman and Rubin 1992; Vehtari et al. 2021). As such, we conclude that the chains have mixed well, and that the posterior estimates can be trusted.
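
The summary also reports per-observation log_lik values. These are not produced by the model block shown above but by a generated quantities block, which for the Gaussian likelihood looks something like the following sketch.

generated quantities {
  // pointwise log-likelihood, reported in the summaries above and below
  vector[N] log_lik;
  for (n in 1:N) {
    log_lik[n] = normal_lpdf(y[n] | x[n] * beta, sigma);
  }
}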

CRPS-based inference

The continuous ranked probability score (CRPS) is defined, for a probabilistic forecast \(y\) and its associated cumulative distribution function \(F\), as

\[\begin{equation} \mathrm{CRPS}(F, y) = \int_{-\infty}^\infty (F(z) - \mathbb{1}(z \geq y))^2\,\mathrm{d}z \end{equation}\]

and can be expressed in closed form for the Gaussian observation family (Gneiting and Raftery 2007). The CRPS was originally motivated as an alternative to density-based scoring rules for predictive distributions admitting a point mass (at zero, say), and as a means of comparing probabilistic and non-probabilistic models.
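
For \(F = \textrm{normal}(\mu, \sigma)\) and \(z = (y - \mu)/\sigma\), this closed form reads

\[\begin{equation} \mathrm{CRPS}(F, y) = \sigma\left\{ z\,(2\Phi(z) - 1) + 2\varphi(z) - \frac{1}{\sqrt{\pi}} \right\}, \end{equation}\]

where \(\Phi\) and \(\varphi\) denote the standard Gaussian distribution and density functions; this is exactly the expression implemented in the Stan functions block below.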

We have so far hidden the fact that the full model likelihood is the product of the likelihood contributions of the individual data points. Likewise, when we perform score-based inference, our posterior is written more explicitly as

\[\begin{equation} p(\theta\mid x_{1:n}) \propto \exp\left\{-\sum_{i=1}^n\ell(\theta, x_i)\right\} \pi(\theta). \end{equation}\]

As such, in Stan we need only define the CRPS in a new functions block, and sum over the score evaluated at each observation in the model block.

functions {
  real crps_norm(real y, real location, real scale) {
    // normalise the data
    real y_norm = y - location;
    real z = y_norm / scale;
    // compute CRPS
    return (
      y_norm * (2 * normal_cdf(y_norm | 0, scale) - 1)
      + scale * (sqrt(2) * exp(-0.5 * z^2) - 1) / sqrt(pi())
    );
  }
}
...
model {
  // priors (unchanged)
  target += normal_lpdf(sigma | 0, 1);
  for (j in 1:p) {
    target += normal_lpdf(beta[j] | 0, 1);
  }
  // the CRPS
  for (n in 1:N) {
    target += -1 * crps_norm(y[n], x[n] * beta, sigma);
  }
}
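
As a quick sanity check on crps_norm (a minimal sketch, assuming the scoringRules package is installed; it is not needed anywhere else), we can compare an R transcription of the function against an independent implementation.

# R transcription of the Stan crps_norm function
crps_gauss <- function(y, mu, sigma) {
  z <- (y - mu) / sigma
  sigma * (z * (2 * pnorm(z) - 1) + 2 * dnorm(z) - 1 / sqrt(pi))
}

# compare against scoringRules' closed-form CRPS; the two should agree
y_test <- c(-1.3, 0.2, 2.5)
crps_gauss(y_test, mu = 0.5, sigma = 1.2)
scoringRules::crps_norm(y_test, mean = 0.5, sd = 1.2)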

We have no need to change the rest of the model, since our priors and the data we feed into the model remain unchanged.5 Again, fitting this model in R requires only a few lines of code.

# compile and fit the stan model
exec_crps <- cmdstan_model(stan_file = "stan/linreg_crps.stan")
fit_crps <- exec_crps$sample(data = stan_data,
                             chains = 4,
                             parallel_chains = 4,
                             refresh = 0,
                             seed = SEED)
Running MCMC with 4 parallel chains...

Chain 1 finished in 0.2 seconds.
Chain 2 finished in 0.2 seconds.
Chain 3 finished in 0.2 seconds.
Chain 4 finished in 0.2 seconds.

All 4 chains finished successfully.
Mean chain execution time: 0.2 seconds.
Total execution time: 0.3 seconds.
# investigate parameter posteriors
summarise_draws(fit_crps$draws(), "mean", "sd", "rhat", "mcse_mean", "mcse_sd")
# A tibble: 204 × 6
   variable      mean     sd  rhat mcse_mean mcse_sd
   <chr>        <dbl>  <dbl> <dbl>     <dbl>   <dbl>
 1 lp__       -58.8   1.26   1.00    0.0284  0.0304 
 2 sigma        0.921 0.179  0.999   0.00295 0.00280
 3 beta[1]      0.984 0.131  1.00    0.00211 0.00200
 4 beta[2]     -0.425 0.126  1.00    0.00199 0.00203
 5 log_lik[1]  -0.957 0.178  1.00    0.00306 0.00258
 6 log_lik[2]  -1.18  0.0958 1.00    0.00151 0.00194
 7 log_lik[3]  -0.858 0.185  0.999   0.00310 0.00305
 8 log_lik[4]  -1.20  0.321  1.00    0.00561 0.00637
 9 log_lik[5]  -0.922 0.168  0.999   0.00281 0.00274
10 log_lik[6]  -1.20  0.0990 1.00    0.00167 0.00418
# ℹ 194 more rows

The individual parameter inferences do not differ greatly between the standard Bayesian posterior and the CRPS posterior, and the \(\hat{R}\) metrics remain below the heuristic threshold. This is all good news! Now let’s change tack.

Hyvärinen inference

The second proper scoring rule we investigate is the Hyvärinen score (Hyvärinen 2005). While the CRPS is defined in terms of the cumulative distribution function, the Hyvärinen score is instead expressed in terms of the first- and second-order derivatives of the log density with respect to the observation. Formally,

\[\begin{equation} \mathcal{H}(p, y) = 2\Delta_y \log p(y) + \lVert\nabla_y \log p(y)\rVert^2, \end{equation}\]

where \(\nabla_y\) and \(\Delta_y\) denote the gradient and Laplacian with respect to \(y\) (here, simply the first and second derivatives). The Hyvärinen score evaluates the unnormalised predictive density, bypassing the requirement to compute the normalising constant, and has previously been used to produce posterior distributions by Giummolè et al. (2019, although not with HMC) and more recently by Altamirano, Briol, and Knoblauch (2023, using a robust version of the score). Once again, in Stan this requires only adding a definition of the scoring rule evaluated at one observation in the functions block, and updating the model block accordingly.
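
For the Gaussian observation model used here, these derivatives are available in closed form,

\[\begin{equation} \nabla_y \log p(y) = -\frac{y - \mu}{\sigma^2}, \qquad \Delta_y \log p(y) = -\frac{1}{\sigma^2}, \qquad \mathcal{H}(p, y) = \frac{(y - \mu)^2}{\sigma^4} - \frac{2}{\sigma^2}, \end{equation}\]

and these are exactly the expressions implemented in the functions block that follows.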

functions {
  real fprime(real y, real m, real s){
    return(-(y-m)/square(s));
  }
  real fprime_prime(real y, real m, real s){
    return(-1 / square(s));
  }
  real hyva_norm(real y, real location, real scale) {
    // compute Hyvarinen score
    real fp = fprime(y, location, scale);
    real fpp = fprime_prime(y, location, scale);
    return 2*fpp + square(fp);
  }
}
...
model {
  // priors (unchanged)
  target += normal_lpdf(sigma | 0, 1);
  for (j in 1:p) {
    target += normal_lpdf(beta[j] | 0, 1);
  }
  // the Hyvärinen score
  for (n in 1:N) {
    target += -1 * hyva_norm(y[n], x[n] * beta, sigma);
  }
}

The Stan math library exposes the gradient and Hessian methods, which might be used to sample from the target Hyvärinen posterior more generally. See this terrific implementation by Andrew Johnson of the Hyvärinen score for the Gaussian family in Stan with external C++ autodiff for such an example.
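
Before fitting, here is a lightweight numerical check of those closed-form derivatives (a base-R sketch using central finite differences; nothing beyond base R is assumed).

# closed-form derivatives of the Gaussian log-density, mirroring the Stan functions
fprime <- function(y, m, s) -(y - m) / s^2
fprime_prime <- function(y, m, s) -1 / s^2

# central finite differences of the log-density with respect to y
log_p <- function(y, m, s) dnorm(y, m, s, log = TRUE)
h <- 1e-4
y0 <- 0.7; m <- 0.2; s <- 1.3
num_fp <- (log_p(y0 + h, m, s) - log_p(y0 - h, m, s)) / (2 * h)
num_fpp <- (log_p(y0 + h, m, s) - 2 * log_p(y0, m, s) + log_p(y0 - h, m, s)) / h^2

# the analytic and numerical values should agree to several decimal places
c(analytic = fprime(y0, m, s), numeric = num_fp)
c(analytic = fprime_prime(y0, m, s), numeric = num_fpp)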

Finally then, we fit the Hyvärinen model in R in the same manner as the two losses before it.

# compile and fit the stan model
exec_hyva <- cmdstan_model(stan_file = "stan/linreg_hyva.stan")
fit_hyva <- exec_hyva$sample(data = stan_data,
                             chains = 4,
                             parallel_chains = 4,
                             refresh = 0,
                             seed = SEED)
Running MCMC with 4 parallel chains...

Chain 1 finished in 0.2 seconds.
Chain 2 finished in 0.2 seconds.
Chain 3 finished in 0.2 seconds.
Chain 4 finished in 0.2 seconds.

All 4 chains finished successfully.
Mean chain execution time: 0.2 seconds.
Total execution time: 0.2 seconds.
# investigate parameter posteriors
summarise_draws(fit_hyva$draws(), "mean", "sd", "rhat", "mcse_mean", "mcse_sd")
# A tibble: 204 × 6
   variable      mean     sd  rhat mcse_mean  mcse_sd
   <chr>        <dbl>  <dbl> <dbl>     <dbl>    <dbl>
 1 lp__       103.    1.27    1.00  0.0285   0.0333  
 2 sigma        0.970 0.0335  1.00  0.000534 0.000557
 3 beta[1]      1.05  0.0671  1.00  0.00106  0.00100 
 4 beta[2]     -0.405 0.0649  1.00  0.00107  0.00105 
 5 log_lik[1]  -1.03  0.0522  1.00  0.000822 0.000843
 6 log_lik[2]  -1.17  0.0317  1.00  0.000527 0.000546
 7 log_lik[3]  -0.899 0.0357  1.00  0.000580 0.000632
 8 log_lik[4]  -1.30  0.150   1.00  0.00236  0.00233 
 9 log_lik[5]  -0.962 0.0383  1.00  0.000634 0.000642
10 log_lik[6]  -1.16  0.0314  1.00  0.000502 0.000494
# ℹ 194 more rows

These chains also seem to have mixed well, and recover posterior estimates close to the true parameter values. The posterior parameter distributions achieved with the Hyvärinen loss function (and the negative log likelihood) are much more concentrated around their posterior means than the CRPS-based posterior. This may be because the Hyvärinen score is unbounded and grows quadratically in the standardised residual, whereas the CRPS grows only linearly in it. Because we fix the learning rate across all losses, our posteriors will be more sensitive to the loss/prior trade-off regulated by the learning rate.
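
To make that growth comparison tangible, here is a small illustration of how each per-observation loss scales with the standardised residual under a standard Gaussian model, reusing the crps_gauss sketch from above (the Hyvärinen expression is the closed form derived earlier).

# growth of each loss in the standardised residual z (sketch)
z <- 0:5
nll <- -dnorm(z, 0, 1, log = TRUE)        # negative log-likelihood: quadratic in z
crps <- crps_gauss(z, mu = 0, sigma = 1)  # CRPS: asymptotically linear in z
hyva <- z^2 - 2                           # Hyvarinen score for normal(0, 1): quadratic in z
round(cbind(z, nll, crps, hyva), 2)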

Visualising the posteriors

Performing full Bayesian inference means that we achieve more than just the point estimates of the first two moments. Indeed, we can visualise the full parameter posteriors induced by the three different loss functions we have investigated.

# extract the parameter draws
draws_llk <- as_draws_df(fit_llk$draws(variables = c("sigma", "beta"))) 
draws_crps <- as_draws_df(fit_crps$draws(variables = c("sigma", "beta")))
draws_hyva <- as_draws_df(fit_hyva$draws(variables = c("sigma", "beta")))
draws_llk$loss <- "log-likelihood"
draws_crps$loss <- "CRPS"
draws_hyva$loss <- "Hyvärinen"
draws_df <- rbind(draws_llk, draws_crps, draws_hyva) |> 
  dplyr::select(-c(.chain, .iteration, .draw)) |>
  reshape2::melt()

# plot the densities
draws_df |> ggplot(aes(value, colour = loss)) +
  geom_density(linewidth = 1.5) +
  facet_wrap(~variable, scales = "free") +
  ylab(NULL) +
  xlab(NULL) +
  theme_bw() +
  theme(plot.background = element_blank(),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank())

We now see more clearly the different shapes the posteriors assume under the different losses (all with the same fixed learning rate). The Hyvärinen posterior is much more concentrated than those under the other two losses, while the CRPS parameter posteriors are the widest. Again, this may be due to the learning rate, which we have not considered in great detail here.

Recap

Generalised Bayesian inference alleviates the tension between the assumptions underpinning the standard Bayesian update and the reality of model misspecification through an optimisation-centric generalisation: replacing the negative log likelihood with an arbitrary loss function. Proper scoring rules are apt candidates for loss functions given their computational feasibility and theoretical characteristics. Since the generalised Bayesian posterior can still be written as a Gibbs distribution, we can still use HMC to compute it numerically. This is easy to implement in the Stan programming language.

It is worth noting, however, that the model we considered in this post was very simple, and that as model complexity increases, certain loss functions may struggle to converge as easily.6

References

Altamirano, Matias, François-Xavier Briol, and Jeremias Knoblauch. 2023. “Robust and Scalable Bayesian Online Changepoint Detection.” arXiv. http://arxiv.org/abs/2302.04759.
Bernardo, José M., and Adrian F. M. Smith. 1994. Bayesian Theory. John Wiley & Sons.
Betancourt, M. J., and Mark Girolami. 2013. “Hamiltonian Monte Carlo for Hierarchical Models.” arXiv. http://arxiv.org/abs/1312.0906.
Bissiri, Pier Giovanni, Chris Holmes, and Stephen Walker. 2016. “A General Framework for Updating Belief Distributions.” Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78 (5): 1103–30. https://doi.org/10.1111/rssb.12158.
Carpenter, Bob, Andrew Gelman, Matthew D. Hoffman, Daniel Lee, Ben Goodrich, Michael Betancourt, Marcus Brubaker, Jiqiang Guo, Peter Li, and Allen Riddell. 2017. “Stan: A Probabilistic Programming Language.” Journal of Statistical Software 76 (1). https://doi.org/10.18637/jss.v076.i01.
Gelman, Andrew, and Donald B. Rubin. 1992. “Inference from Iterative Simulation Using Multiple Sequences.” Statistical Science 7 (4). https://doi.org/10.1214/ss/1177011136.
Giummolè, Federica, Valentina Mameli, Erlis Ruli, and Laura Ventura. 2019. “Objective Bayesian Inference with Proper Scoring Rules.” TEST 28 (3): 728–55. https://doi.org/10.1007/s11749-018-0597-z.
Gneiting, Tilmann, and Adrian E Raftery. 2007. “Strictly Proper Scoring Rules, Prediction, and Estimation.” Journal of the American Statistical Association 102 (477): 359–78. https://doi.org/10.1198/016214506000001437.
Hyvärinen, Aapo. 2005. “Estimation of Non-Normalized Statistical Models by Score Matching.” Journal of Machine Learning Research 6 (24): 695–709. http://jmlr.org/papers/v6/hyvarinen05a.html.
Jewson, Jack, Jim Q. Smith, and Chris Holmes. 2023. “On the Stability of General Bayesian Inference.” arXiv. http://arxiv.org/abs/2301.13701.
Knoblauch, Jeremias, Jack Jewson, and Theodoros Damoulas. 2019. “Generalized Variational Inference: Three Arguments for Deriving New Posteriors.” arXiv. https://doi.org/10.48550/arXiv.1904.02063.
Neal, Radford M. 2011. “MCMC Using Hamiltonian Dynamics.” In Handbook of Markov Chain Monte Carlo. Chapman & Hall/CRC. https://doi.org/10.1201/b10905.
Vehtari, Aki, Andrew Gelman, Daniel Simpson, Bob Carpenter, and Paul-Christian Bürkner. 2021. “Rank-Normalization, Folding, and Localization: An Improved \(\hat{R}\) for Assessing Convergence of MCMC (with Discussion).” Bayesian Analysis 16 (2): 667–718. https://doi.org/10.1214/20-BA1221.
Wu, Pei-Shien, and Ryan Martin. 2023. “A Comparison of Learning Rate Selection Methods in Generalized Bayesian Inference.” Bayesian Analysis 18 (1). https://doi.org/10.1214/21-BA1302.

Footnotes

  1. It’s worth noting that the log score may still be useful for diagnosing misspecified models, or if we’re interested in learning tail shapes, to which other loss functions are less sensitive.↩︎

  2. I do not know exactly which characteristics of the loss function are theoretically sufficient for compatibility with HMC, but we will see that it is not a very restrictive method.↩︎

  3. In general, such arbitrary definition of the priors is at best a waste of the potential they afford, and at worst putting oneself at risk of over-fitting. This may be the subject of a later post.↩︎

  4. This induces some local and global conditions on the prior for posterior consistency, but we aren’t too worried about these for this toy example.↩︎

  5. One might wonder whether the statistician’s prior beliefs change depending on the loss function they choose to optimise over: perhaps the topic of a later blog post.↩︎

  6. This is something that has been discussed in the literature, but I might demonstrate it numerically in a later post.↩︎

Citation

BibTeX citation:
@online{mclatchie2023,
  author = {Yann McLatchie},
  title = {Generalised {Bayesian} Inference in {Stan}},
  date = {2023-05-10},
  url = {https://yannmclatchie.github.io/blog/posts/gen-bayes-stan/},
  langid = {en}
}
For attribution, please cite this work as:
Yann McLatchie. 2023. “Generalised Bayesian Inference in Stan.” May 10, 2023. https://yannmclatchie.github.io/blog/posts/gen-bayes-stan/.