Can your samples go the distance?

Empirical approximations of statistical distances and how they stack up.
Divergence
Approximations
Author

Yann McLatchie

Published

August 15, 2024

Setup code
library(ggplot2)
library(tidyverse)
set.seed(1234)

I recently came across a paper by Drovandi and Frazier (2022) where they investigate different sample approximations to statistical distances for approximate Bayesian computation (ABC). I’ve needed to compute some of these sample approximations in the past, so today we’re going to have a look at how we might implement them1, and how they behave.

Some statistical distances

Let’s limit our discussion a bit: we’re only going to consider continuous distributions to keep things simple.

Wasserstein distance

We will follow Drovandi and Frazier (2022) and use \(\mathcal{P}_p(\mathcal{Y})\) to denote the collection of probability measures \(\mu\) on \(\mathcal{Y}\) admitting a finite \(p\)-th moment. Then for two probability measures \(\mu,\nu \in\mathcal{P}_p(\mathcal{Y})\) with cumulative distribution functions \(F_\mu\) and \(F_\nu\) respectively, the \(p\)-Wasserstein distance between them is defined as

\[ \mathcal{W}_p(\mu, \nu) = \left( \int_{0}^{1} \vert F_\mu^{-1}(x) - F_\nu^{-1}(x) \vert^p \,\mathrm{d}x \right)^{1/p}. \]

When it comes to a sample approximation, the \(1\)-Wasserstein distance between \(y_{1:n}\sim\mu\)2 and \(z_{1:n}\sim\nu\)3 reduces neatly to

\[ \widehat{\mathcal{W}}_1(y_{1:n}, z_{1:n}) = n^{-1}\sum_{i=1}^n\vert y_{(i)} - z_{(i)} \vert, \]

where \(y_{(i)}\) and \(z_{(i)}\) denote the order statistics of the two samples; that is, the \(L_1\) norm between the sorted samples, scaled by \(n^{-1}\). There is a lot more theory behind this family of distances, for which the book by Villani (2009) is a good reference.

Implementation in R
# compute the sample 1-Wasserstein distance
sample_wass <- function(y, z) {
  # sort both samples to match their order statistics
  dist <- mean(abs(sort(y) - sort(z)))
  return(dist)
}
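
As a quick sanity check, for two Gaussians that differ only in their means the \(1\)-Wasserstein distance is exactly the difference in means, which the sample approximation should recover (the sample size and parameter values below are arbitrary illustrative choices).

y <- rnorm(1e4, 0, 1)
z <- rnorm(1e4, 3, 1)
sample_wass(y, z) # should be close to |0 - 3| = 3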

Energy distance

We now let \(Y_1, Z_1 \in \mathcal{Y}\) be two independent random variables distributed according to \(\mu\) and \(\nu\) respectively. Likewise, \(Y_2, Z_2 \in \mathcal{Y}\) are random variables with the same distributions as \(Y_1\) and \(Z_1\) respectively, but independent of them. Then for \(p\geq 1\), the \(p\)-th energy distance between \(\mu\) and \(\nu\) is

\[ \mathcal{E}_p(\mu, \nu) = 2\mathbb{E}\lVert Y_1 - Z_1 \rVert_p - \mathbb{E}\lVert Z_1 - Z_2 \rVert_p - \mathbb{E}\lVert Y_1 - Y_2 \rVert_p. \]

Nguyen et al. (2020) help us out here and propose to approximate it with

\[ \widehat{\mathcal{E}}_p(y_{1:n}, z_{1:n}) = \frac{2}{n^2} \sum_{i=1}^n \sum_{j=1}^n \lVert y_i - z_j \rVert_p - \frac{1}{n^2} \sum_{i=1}^n \sum_{j=1}^n \lVert z_i - z_j \rVert_p - \frac{1}{n^2} \sum_{i=1}^n \sum_{j=1}^n \lVert y_i - y_j \rVert_p. \]

We’re only going to consider the univariate case with \(p = 1\) going forward.

Implementation in R
# compute the sample energy distance
sample_energy <- function(y, z) {
  n <- length(y)
  EYZ <- outer(y, z, function(a, b) abs(a - b)) |> sum()
  EYY <- outer(y, y, function(a, b) abs(a - b)) |> sum()
  EZZ <- outer(z, z, function(a, b) abs(a - b)) |> sum()
  dist <- 2 / n^2 * EYZ - 1 / n^2 * EYY - 1 / n^2 * EZZ
  return(dist)
}
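
As a rough check on the implementation, the energy distance between two samples drawn from the same distribution should be close to zero (the sample size below is an arbitrary choice).

y <- rnorm(1e3)
z <- rnorm(1e3)
sample_energy(y, z) # expect a small value near 0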

Maximum mean discrepancy

The energy distance is a specific case of the family of maximum mean discrepancy (MMD) distances. They’ve been used, for example, by Dellaporta et al. (2022) in generalised Bayesian inference.4 For \(Y_1, Y_2, Z_1, Z_2\) defined as above, and \(k\) a symmetric, continuous, and positive-definite kernel, the MMD between \(\mu\) and \(\nu\) is

\[ \mathrm{MMD}^2(\mu,\nu) = \mathbb{E}[k(Y_1, Y_2)] + \mathbb{E}[k(Z_1, Z_2)] - 2\mathbb{E}[k(Y_1, Z_1)]. \]

This can be approximated5 by

\[\begin{multline} \widehat{\mathrm{MMD}}^2(y_{1:n}, z_{1:n}) = \frac{1}{n(n - 1)} \sum_{i=1}^n \sum_{j\ne i} k(y_i, y_j) + \frac{1}{n(n - 1)} \sum_{i=1}^n \sum_{j\ne i} k(z_i, z_j) \\ - \frac{1}{n^2} \sum_{i=1}^n\sum_{j=1}^n k(y_i, z_j). \end{multline}\]
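
As an aside, we can see directly why the energy distance is a member of this family: taking \(k(x, y) = -\lVert x - y \rVert_p\) in the population definition above gives

\[ \mathbb{E}[k(Y_1, Y_2)] + \mathbb{E}[k(Z_1, Z_2)] - 2\mathbb{E}[k(Y_1, Z_1)] = 2\mathbb{E}\lVert Y_1 - Z_1 \rVert_p - \mathbb{E}\lVert Y_1 - Y_2 \rVert_p - \mathbb{E}\lVert Z_1 - Z_2 \rVert_p = \mathcal{E}_p(\mu, \nu), \]

with the caveat that this negative distance kernel is only conditionally positive-definite, so it sits slightly outside the kernel conditions stated above.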

Going forward, we’re going to take the kernel to be a squared exponential, so that for lengthscale \(\gamma\) (the bandwidth argument in the code below),

\[ k(x, y) = \exp\left(-\frac{\lVert y - x\rVert^2}{2\gamma^2}\right). \]

Implementation in R
# define a squared exponential kernel
sq_exp_kernel <- function(x, y, bandwidth = 1) {
  k <- exp(-sum((x - y)^2) / (2 * bandwidth^2))
  return(k)
}
# vectorise the kernel so it can be applied element-wise by outer()
sq_exp_kernel <- Vectorize(sq_exp_kernel)
# compute the sample MMD
sample_mmd <- function(y, z){
  n <- length(y)
  
  # compute the kernel matrices
  k_yy <- outer(y, y, sq_exp_kernel)
  k_zz <- outer(z, z, sq_exp_kernel)
  k_yz <- outer(y, z, sq_exp_kernel)

  # remove diagonal elements for summation
  diag(k_yy) <- NA
  diag(k_zz) <- NA
  
  # compute the MMD components
  mmd2_yy <- sum(k_yy, na.rm = T) / (n * (n - 1))
  mmd2_zz <- sum(k_zz, na.rm = T) / (n * (n - 1))
  mmd2_yz <- -sum(k_yz) / (n^2)
  
  # compute and return distance
  dist <- mmd2_yy + mmd2_zz + mmd2_yz
  return(dist)
}
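
As with the energy distance, a rough check is that the squared MMD between samples from the same distribution should be near zero, and should grow as the distributions separate (the sample sizes and means below are arbitrary, and the bandwidth is left at its default of \(1\) inside sample_mmd).

y <- rnorm(500)
sample_mmd(y, rnorm(500))    # same distribution: expect a value near 0
sample_mmd(y, rnorm(500, 5)) # well-separated distributions: expect a larger value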

Kullback-Leibler divergence

Now the big one. This was the reason I started writing this post: I wanted to see if there were any helpful sample approximations to the Kullback-Leibler (KL) divergence for some work on prior choice I was doing. This is what I found.

Denote the densities of \(\mu\) and \(\nu\) by \(p\) and \(q\) respectively; then the KL divergence of \(\nu\) from \(\mu\)6 is

\[ \mathrm{KL}(\mu, \nu) = \int p(x) \log\left\{\frac{p(x)}{q(x)}\right\}\,\mathrm{d}x. \]

Perez-Cruz (2008) proposed to approximate this with a nearest-neighbour approach:

\[ \widehat{\mathrm{KL}}(y_{1:n}, z_{1:n}) = n^{-1}\sum_{i=1}^n \log\left( \frac{\min_j \lVert y_i - z_j \rVert}{\min_{j\ne i} \lVert y_i - y_j \rVert} \right) + \log\left(\frac{n}{n - 1}\right), \]

that is, for each \(y_i\) we compare its nearest-neighbour distance among the \(z\) samples to its nearest-neighbour distance among the remaining \(y\) samples.

Implementation in R
# compute the sample KL divergence
sample_kl <- function(y, z) {
  n <- length(y)
  
  # compute pair-wise distances from each y to every z, and between the ys
  distances_yz <- outer(y, z, function(a, b) abs(a - b))
  distances_yy <- outer(y, y, function(a, b) abs(a - b))
  
  # compute the minimum distance for each point in y to all points in z
  min_distances_yz <- apply(distances_yz, 1, min)
  
  # compute the minimum distance for each point in y to all other points in y, excluding itself
  min_distances_yy <- sapply(1:n, function(i) {
    min(distances_yy[i, -i])
  })
  
  # compute the log ratio of the minimum distances
  log_ratio <- log(min_distances_yz / min_distances_yy)
  
  # compute the sample KL divergence
  dist <- mean(log_ratio) + log(n / (n - 1))
  return(dist)
}
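
As a final quick check, we can compare against the closed-form KL divergence between two unit-variance Gaussians, which is \((m_1 - m_2)^2 / 2\) (the sample size below is arbitrary, and the nearest-neighbour estimate will be noisy).

y <- rnorm(5e3, 0, 1)
z <- rnorm(5e3, 1, 1)
sample_kl(y, z) # closed form gives (0 - 1)^2 / 2 = 0.5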

OK, so how good are these approximations?

Let’s take two univariate Gaussians

\[ \mu = \mathcal{N}(m_1, s_1^2), \qquad \nu = \mathcal{N}(m_2, s_2^2), \]

and sample \(y_{1:n}\sim\mu,\,z_{1:n}\sim\nu\).

Let’s now consider \(m_1 = 0, m_2 = 5, s_1 = 0.5, s_2 = 2\), so that our two distributions look as follows.

Show code
# moments
m_1 <- 0
m_2 <- 5
s_1 <- 0.5
s_2 <- 2

# plot densities
ggplot() +
  xlim(-5, 10) + 
  geom_function(fun = dnorm, 
                args = list(mean = m_1, sd = s_1),
                aes(colour = "mu")) +
  geom_function(fun = dnorm, 
                args = list(mean = m_2, sd = s_2),
                aes(colour = "nu")) +
  scale_colour_discrete(labels=c(expression(mu), expression(nu))) +
  theme_classic() +
  theme(legend.title=element_blank(),
        axis.line.y=element_blank(),
        axis.title.y=element_blank(),
        axis.text.y=element_blank(),
        axis.ticks.y=element_blank())

Estimator convergence and variance: loosely speaking

Now we’re going to sample \(25\) data replicates of size \(n\) from \(\mu\) and \(\nu\). For each data replicate we will compute the empirical approximation of each of the distances presented above and record the computation time. We do this over a logarithmic grid of \(n\), and show the empirical average over data replicates, along with the \(5\%\) and \(95\%\) quantiles, in the following plot.

Show code
# experiment settings
num_iters <- 25
num_ns <- 12
iters <- 1:num_iters
ns <- 2^(seq(1, num_ns, length.out = num_ns))
distances <- c("energy", "kl", "mmd", "wass")

# calculate some distances between two Gaussians
gauss_dist <- function(iter, n, dist_name, m_1, m_2, s_1, s_2) {
  # simulate data from the two distributions
  y <- rnorm(n, m_1, s_1)
  z <- rnorm(n, m_2, s_2)
  
  # calculate distance
  fun_name <- paste("sample", dist_name, sep="_")
  start_time <- Sys.time()
  dist <- do.call(fun_name, list(y, z))
  end_time <- Sys.time()
  
  # and log the time taken in seconds
  time_taken <- as.numeric(difftime(end_time, start_time, units = "secs"))
  
  # return the distances
  return(list(iter = iter,
              n = n,
              dist_name = dist_name,
              dist = dist,
              comp_time = time_taken))
}

run_experiment <- function(iters, ns, distances, read_from_file = T,
                           save_to_file = F) {
  # define file name
  file_name <- "distances.csv"
  
  # check whether to rerun or just read
  if (!read_from_file) {
    # evaluate each distance across all combinations
    combis <- expand.grid(iter = iters, n = ns, dist_name = distances)
    df <- combis |>
      pmap(\(iter, n, dist_name) gauss_dist(iter = iter, 
                                            n = n,
                                            dist_name = dist_name,
                                            m_1 = m_1,
                                            m_2 = m_2,
                                            s_1 = s_1,
                                            s_2 = s_2),
           .progress = TRUE) |>
      bind_rows()
    
    if (save_to_file) {
      # save results to csv to avoid rerunning
      write_csv(df, file = file_name)
    }
  } else {
    # read from file
    df <- read_csv(file_name)
  }
  # return experiment results
  return(df)
}

# run the experiment
df <- run_experiment(iters, ns, distances)

# produce ribbons for figures
gdf <- df |>
  group_by(n, dist_name) |>
  summarize(mean = mean(dist),
            lwr = quantile(dist, probs = 0.05),
            upr = quantile(dist, probs = 0.95),
            time = mean(comp_time))

# plot figures
ggplot() +
  geom_ribbon(data = gdf,
              aes(ymin = lwr,
                  ymax = upr,
                  x = n),
              fill = "grey",
              colour = "black",
              alpha = 0.3,
              linetype = "dotted") +
  geom_line(data = gdf,
            aes(y = mean,
                x = n),
            colour = "black") +
  facet_wrap(~ dist_name, scales = "free_y", ncol = 2) +
  scale_x_continuous(trans = "log2") +
  xlab(expression(n)) +
  ylab("distance") +
  theme_bw() +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        strip.background = element_blank(),
        panel.background = element_blank(),
        legend.position="none")

The first thing I notice is that the sample approximation to the KL divergence goes below zero for some data replicates when \(n\) is small, which isn’t ideal.7 It also doesn’t seem to converge as quickly as the three other distances. In this case, the other approximations look like they have converged by as little as \(n \approx 50\), while the variation in the KL approximation only becomes comparable after \(n \approx 4,000\).

Approximation error

OK, so there is some notion of convergence, but how accurate are these approximations? Let’s just look at the \(1\)-Wasserstein distance and the KL divergence, since those are the ones we can easily compute for the Gaussian case. For instance, we can compute the \(1\)-Wasserstein distance with quadrature8 to find that \(\mathcal{W}_1(\mu,\nu) \approx 5\), and likewise, \[ \mathrm{KL}(\mu, \nu) = \log\left(\frac{s_2}{s_1}\right) + \frac{s_1^2 + (m_1 - m_2)^2}{2s_2^2} - \frac{1}{2} \approx 4. \]

Show code
# compute the true distances where possible
true_kl <- log(s_2 / s_1) + (s_1^2 + (m_1 - m_2)^2) / (2 * s_2^2) - 0.5
integrand <- function(x) abs(pnorm(x, m_1, s_1) - pnorm(x, m_2, s_2))
integral <- integrate(integrand, lower = -Inf, upper = Inf)
true_wass <- integral$value

# compute the error over iterations
kl_error_df <- df |>
  filter(dist_name == "kl") |>
  mutate(error = dist - true_kl)
wass_error_df <- df |>
  filter(dist_name == "wass") |>
  mutate(error = dist - true_wass)
error_df <- rbind(kl_error_df, wass_error_df)

# produce ribbons for figures
error_gdf <- error_df |>
  group_by(n, dist_name) |>
  summarize(mean = mean(error),
            lwr = quantile(error, probs = 0.05),
            upr = quantile(error, probs = 0.95))

# plot figures
ggplot() +
  geom_hline(yintercept = 0, colour = "black",
             linetype = "dashed") +
  geom_ribbon(data = error_gdf,
              aes(ymin = lwr,
                  ymax = upr,
                  x = n),
              fill = "grey",
              colour = "black",
              alpha = 0.3,
              linetype = "dotted") +
  geom_line(data = error_gdf,
            aes(y = mean,
                x = n),
            colour = "black") +
  facet_wrap(~ dist_name, scales = "fixed") +
  scale_x_continuous(trans = "log2") +
  xlab(expression(n)) +
  ylab("error") +
  theme_bw() +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        strip.background = element_blank(),
        panel.background = element_blank(),
        legend.position="none")

The \(1\)-Wasserstein approximation seems empirically unbiased, while the KL approximation consistently under-estimates the true value in this experiment. The references above include proofs of consistency for these estimators, for those interested in their theoretical guarantees.

Computational cost

Another thing people care about is time. So how much of it did I waste running these calculations? And which distance measure might save me the most going forward?9

Show code
# ensure that time is in a usable format
gdf$time <- as.double(gdf$time)

# plot the figure
ggplot() +
  geom_line(data = gdf,
            aes(y = time,
                x = n,
                colour = dist_name)) +
  scale_x_continuous(trans = "log",
                     breaks = c(2, 10, 100, 1000, 4000)) +
  scale_y_continuous(trans = "log",
                     breaks = 10^seq(-5, 2, by = 1),
                     labels = scales::label_scientific(digits = 1)) +
  scale_colour_discrete(name = "Distance") +
  labs(title = "Average time per distance evaluation (log-log scale)") +
  xlab(expression(n)) +
  ylab("time (seconds)") +
  theme_bw() +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        strip.background = element_blank(),
        panel.background = element_blank())

The \(1\)-Wasserstein distance looks super efficient next to the others, with the MMD taking \(\approx 100\) seconds to compute for \(n \approx 4,000\)! This isn’t too surprising: the sample \(1\)-Wasserstein distance only requires sorting the two samples, while the energy, MMD, and KL approximations all build \(n \times n\) pairwise distance or kernel matrices, so their cost grows roughly quadratically in \(n\).

References

Dellaporta, Charita, Jeremias Knoblauch, Theodoros Damoulas, and Francois-Xavier Briol. 2022. “Robust Bayesian Inference for Simulator-Based Models via the MMD Posterior Bootstrap.” In Proceedings of The 25th International Conference on Artificial Intelligence and Statistics, edited by Gustau Camps-Valls, Francisco J. R. Ruiz, and Isabel Valera, 151:943–70. Proceedings of Machine Learning Research. PMLR. https://proceedings.mlr.press/v151/dellaporta22a.html.
Drovandi, Christopher, and David T. Frazier. 2022. “A Comparison of Likelihood-Free Methods with and Without Summary Statistics.” Statistics and Computing 32 (3): 42. https://doi.org/10.1007/s11222-022-10092-4.
Nguyen, Hien Duy, Julyan Arbel, Hongliang Lü, and Florence Forbes. 2020. “Approximate Bayesian Computation via the Energy Statistic.” IEEE Access 8: 131683–98. https://doi.org/10.1109/ACCESS.2020.3009878.
Perez-Cruz, Fernando. 2008. “Kullback-Leibler Divergence Estimation of Continuous Distributions.” In 2008 IEEE International Symposium on Information Theory, 1666–70. https://doi.org/10.1109/ISIT.2008.4595271.
Villani, Cédric. 2009. Optimal Transport. Edited by M. Berger, B. Eckmann, P. De La Harpe, F. Hirzebruch, N. Hitchin, L. Hörmander, A. Kupiainen, et al. Vol. 338. Grundlehren Der Mathematischen Wissenschaften. Berlin, Heidelberg: Springer Berlin Heidelberg. https://doi.org/10.1007/978-3-540-71050-9.

Footnotes

  1. “might” being the operative word.↩︎

  2. “iid sampling”.↩︎

  3. Going forward we’ll assume that the number of samples from \(\mu\) and \(\nu\) is the same for simplicity, denoted \(n\), but this can be loosened in some cases.↩︎

  4. Read my previous blog post to find out more about generalised Bayes!↩︎

  5. There might be more efficient ways of estimating this quantity, but I just wanted to do something quick for the purposes of this post.↩︎

  6. Since this isn’t a symmetric distance.↩︎

  7. Since we know that for all probability measures \(\mu, \nu\), the KL divergence \(\mathrm{KL}(\mu, \nu) \geq 0\).↩︎

  8. Just to keep things simple.↩︎

  9. I’m not going to claim that my R code is even close to optimal, so these comparisons should be taken with a centurion’s wage of salt.↩︎

Citation

BibTeX citation:
@online{mclatchie2024,
  author = {Yann McLatchie},
  title = {Can Your Samples Go the Distance?},
  date = {2024-08-15},
  url = {https://yannmclatchie.github.io/blog/posts/distances},
  langid = {en}
}
For attribution, please cite this work as:
Yann McLatchie. 2024. “Can Your Samples Go the Distance?” August 15, 2024. https://yannmclatchie.github.io/blog/posts/distances.