Markov Chain Monte Carlo With Turing

Overview

This tutorial will give some examples of using Turing.jl and Markov Chain Monte Carlo to sample from posterior distributions.

Setup

using Turing
using Distributions
using Plots
default(fmt = :png) # the tide gauge data is long, this keeps images a manageable size
using LaTeXStrings
using StatsPlots
using Measures
using StatsBase
using Optim
using Random
using DataFrames
using DataFramesMeta
using Dates
using CSV

As this tutorial involves random number generation, we will set a random seed to ensure reproducibility.


Random.seed!(1);

Fitting A Linear Regression Model

Let’s start with a simple example: fitting a linear regression model to simulated data.

Positive Control Tests

Simulating data with a known data-generating process and then trying to obtain the parameters for that process is an important step in any workflow.

Simulating Data

The data-generating process for this example will be: \[ \begin{gather} y = 5 + 2x + \varepsilon \\ \varepsilon \sim \text{Normal}(0, 3), \end{gather} \] where \(\varepsilon\) is so-called “white noise”, which adds stochasticity to the data set. The generated dataset is shown in Figure 1.

Figure 1: Scatterplot of our generated data.
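A dataset like this can be simulated directly from the process above. The following is a minimal sketch, assuming 20 evenly spaced points x = 1:20 (an assumption for illustration; the exact values and random draw behind Figure 1 may differ).

# simulate data from the data-generating process above
# (x = 1:20 is an assumption; the data behind Figure 1 may differ)
x = collect(1:20)
y = 5 .+ 2 .* x .+ rand(Normal(0, 3), length(x))

# plot the simulated data
scatter(x, y; label="Simulated Data", xlabel=L"$x$", ylabel=L"$y$")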

Model Specification

The statistical model for a standard linear regression problem is \[ \begin{gather} y = a + bx + \varepsilon \\ \varepsilon \sim \text{Normal}(0, \sigma). \end{gather} \]

Rearranging, we can rewrite the likelihood function as \[y \sim \text{Normal}(\mu, \sigma),\] where \(\mu = a + bx\). This means that we have three parameters to fit: \(a\), \(b\), and \(\sigma\).

Next, we need to select priors on our parameters. We’ll use relatively generic distributions to avoid using the information we have (since we generated the data ourselves), but in practice, we’d want to use any relevant information from our knowledge of the problem. Let’s use relatively diffuse normal distributions for the trend parameters \(a\) and \(b\) and a half-normal distribution (a normal distribution truncated at 0, to allow only positive values) for the standard deviation \(\sigma\), as recommended by Gelman (2006).

Gelman, A. (2006). Prior distributions for variance parameters in hierarchical models (comment on article by Browne and Draper). Bayesian Anal., 1(3), 515–533. https://doi.org/10.1214/06-BA117A

\[ \begin{gather} a \sim \text{Normal}(0, 10) \\ b \sim \text{Normal}(0, 10) \\ \sigma \sim \text{Half-Normal}(0, 25) \end{gather} \]
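To get a feel for these priors before fitting, we can plot their densities; StatsPlots.jl can plot Distributions.jl objects directly. A quick sketch:

# plot the prior densities for the trend parameters and the standard deviation
plot(Normal(0, 10); label=L"$a$, $b$ prior", xlabel="Parameter value", ylabel="Prior density")
plot!(truncated(Normal(0, 25); lower=0); label=L"$\sigma$ prior")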

Using Turing

Coding the Model

Turing.jl uses the @model macro to specify the model function. We’ll follow the setup in the Turing documentation.

To specify distributions on parameters (and the data, which can be thought of as uncertain parameters in Bayesian statistics), use a tilde ~, and use equals = for transformations (which we don’t have in this case).


@model function linear_regression(x, y)
    # set priors
    σ ~ truncated(Normal(0, 25); lower=0)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)

    # compute the likelihood
    for i = 1:length(y)
        # compute the mean value for the data point
        μ = a + b * x[i]
        y[i] ~ Normal(μ, σ)
    end
end
1. Standard deviations must be positive, so we use a normal distribution truncated at zero.
2. We’ll keep both of these relatively uninformative to reflect a more “realistic” modeling scenario.
3. In this case, we specify the likelihood with a loop. We could also rewrite this as a joint likelihood over all of the data using linear algebra (see the sketch below the model output), which might be more efficient for large and/or complex models or datasets, but the loop is more readable in this simple case.
linear_regression (generic function with 2 methods)
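As noted in annotation 3, the likelihood could also be written jointly using linear algebra. A minimal sketch (equivalent to the loop above, using MvNormal with an isotropic covariance; it requires the LinearAlgebra standard library for I and is not used in the rest of this tutorial):

using LinearAlgebra: I

@model function linear_regression_vectorized(x, y)
    # same priors as above
    σ ~ truncated(Normal(0, 25); lower=0)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)

    # joint likelihood: independent normals with common variance σ²
    μ = a .+ b .* x
    y ~ MvNormal(μ, σ^2 * I)
end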

Fitting The Model

Now we can call the sampler to draw from the posterior. We’ll use the No-U-Turn sampler (Hoffman & Gelman, 2014), which is a Hamiltonian Monte Carlo algorithm (a different category of MCMC sampler than the Metropolis-Hastings algorithm discussed in class). We’ll also use 4 chains so we can test that the chains are well-mixed, and each chain will be run for 5,000 iterations.1

Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. J. Mach. Learn. Res., 15(47), 1593–1623.

1 Hamiltonian Monte Carlo samplers often need to be run for fewer iterations than Metropolis-Hastings samplers, as the exploratory step uses information about the gradient of the statistical model, versus the random walk of Metropolis-Hastings. The disadvantage is that this gradient information must be available, which is not always the case for external simulation models. Simulation models coded in Julia can usually be automatically differentiated by Turing’s tools, however.

# set up the sampler
model = linear_regression(x, y)
n_chains = 4
n_per_chain = 5000
chain = sample(model, NUTS(), MCMCThreads(), n_per_chain, n_chains, drop_warmup=true)
@show chain
1. Initialize the model with the data.
2. We use multiple chains to help diagnose convergence.
3. This sets the number of iterations for each chain.
4. Sample from the posterior using NUTS and drop the iterations used to warm up the sampler. The MCMCThreads() call tells the sampler to use available processor threads for the multiple chains; it will sample them in serial if only one thread is available.
5. The @show macro makes the display of the output a bit cleaner.
Warning: Only a single thread available: MCMC chains are not sampled in parallel
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:382
Sampling (1 threads)   0%|                              |  ETA: N/A
Info: Found initial step size
  ϵ = 0.000390625
Info: Found initial step size
  ϵ = 0.00625
Sampling (1 threads)  25%|███████▌                      |  ETA: 0:00:24
Sampling (1 threads)  50%|███████████████               |  ETA: 0:00:08
Info: Found initial step size
  ϵ = 0.00625
Info: Found initial step size
  ϵ = 0.00078125
Sampling (1 threads)  75%|██████████████████████▌       |  ETA: 0:00:03
Sampling (1 threads) 100%|██████████████████████████████| Time: 0:00:09
chain = MCMC chain (5000×15×4 Array{Float64, 3})
Chains MCMC chain (5000×15×4 Array{Float64, 3}):

Iterations        = 1001:1:6000
Number of chains  = 4
Samples per chain = 5000
Wall duration     = 6.86 seconds
Compute duration  = 5.7 seconds
parameters        = σ, a, b
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

           σ    5.3858    0.9845    0.0107   8497.1994   7872.3181    1.0008   ⋯
           a    7.3223    2.1355    0.0234   8409.3893   9696.6954    1.0002   ⋯
           b    1.8009    0.1893    0.0021   8024.7764   9307.0114    1.0004   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           σ    3.8766    4.6898    5.2508    5.9228    7.6846
           a    3.0432    5.9478    7.3204    8.7243   11.4744
           b    1.4259    1.6789    1.8005    1.9222    2.1771

How can we interpret the output? The first parts of the summary statistics are straightforward: we get the mean, standard deviation, and Monte Carlo standard error (mcse) of each parameter. We also get information about the effective sample size (ESS)2 and \(\hat{R}\), which compares the across-chain variance to the within-chain variance as a check for convergence3.

2 The ESS reflects the efficiency of the sampler: this is an estimate of the equivalent number of independent samples; the more correlated the samples, the lower the ESS.

3 The closer \(\hat{R}\) is to 1, the better.

In this case, we can see that we were generally able to recover the “true” slope \(b = 2\), while the posterior for the noise standard deviation \(\sigma\) sits somewhat above the value used to generate the data. The intercept \(a\) is also off: its posterior mean is about 7.3, compared to the data-generating value of 5, and there is substantial uncertainty about it, with a 95% credible interval of \((3.0, 11.5)\) (compared to \((1.4, 2.2)\) for \(b\)). This isn’t surprising: given the spread of the noise, there are many different intercepts which could fit within that scatter.
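These credible intervals can also be computed directly from the posterior samples. For example, for the intercept (a small sketch, pooling the samples across chains):

# pool the posterior samples for a across all chains
a_samples = vec(Array(group(chain, :a)))
# 95% credible interval
quantile(a_samples, [0.025, 0.975])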

Let’s now plot the chains for visual inspection.

plot(chain)
Figure 2: Output from the MCMC sampler. Each row corresponds to a different parameter: \(\sigma\), \(a\), and \(b\). Each chain is shown in a different color. The left column shows the sampler traceplots, and the right column the resulting posterior distributions.

We can see from Figure 2 that our chains mixed well and seem to have converged to similar distributions! The traceplots have a “hairy caterpillar” appearance, suggesting relatively little autocorrelation. We can also see how much more uncertainty there is in the intercept \(a\), while the slope \(b\) is much more constrained.
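If we want to look at the autocorrelation more directly than eyeballing the traceplots, MCMCChains provides an autocorrelation plot recipe (a quick sketch; add using MCMCChains if autocorplot is not already in scope):

# autocorrelation of the chains for each parameter
autocorplot(chain)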

Another interesting comparison we can make is with the maximum-likelihood estimate (MLE), which we can obtain through optimization.

mle_model = linear_regression(x, y)
mle = optimize(mle_model, MLE())
coef(mle)
1. This is where we use the Optim.jl package in this tutorial.
3-element Named Vector{Float64}
A  │ 
───┼────────
σ  │ 4.75545
a  │ 7.65636
b  │ 1.77736

We could also get the maximum a posteriori (MAP) estimate, which includes the prior density, by replacing MLE() with MAP().
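For example (a small sketch reusing the model object from above):

# maximum a posteriori estimate, which also accounts for the prior densities
map_estimate = optimize(mle_model, MAP())
coef(map_estimate)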

Model Diagnostics and Posterior Predictive Checks

One advantage of the Bayesian modeling approach here is that we have access to a generative model, or a model which we can use to generate datasets. This means that we can now use Monte Carlo simulation, sampling from our posteriors, to look at how uncertainty in the parameter estimates propagates through the model. Let’s write a function which gets samples from the MCMC chains and generates datasets.

function mc_predict_regression(x, chain)
    # get the posterior samples
    a = Array(group(chain, :a))
    b = Array(group(chain, :b))
    σ = Array(group(chain, :σ))

    # loop and generate alternative realizations
    μ = a' .+ x * b'
    y = zeros((length(x), length(a)))
    for i = 1:length(a)
        y[:, i] = rand.(Normal.(μ[:, i], σ[i]))
    end
    return y
end
1. The Array(group()) syntax is more general than we need, but it is useful if we have multiple variables that were sampled as a group, for example multiple regression coefficients. Otherwise, we can just use e.g. Array(chain, :a).
mc_predict_regression (generic function with 1 method)

Now we can generate a predictive interval and median and compare to the data.

x_pred = 0:20
y_pred = mc_predict_regression(x_pred, chain)
21×20000 Matrix{Float64}:
  7.79596   20.1959   7.9299    5.13623  …  19.4299   6.4165    9.12683
 -0.302011  10.7975  -2.18332   3.36265     14.574    6.2009    7.13161
 23.3908    14.7209   4.38494  15.7897      12.3529   6.74688  13.0819
  7.89445   13.0295  12.1418   10.6287      25.908   16.7987    6.9323
 27.3525    14.476   12.3627   20.7909      18.6918   3.03541  24.996
 13.7448    17.3157  24.9079   28.326    …  20.7976   9.42544  21.1482
 25.883     36.3196  13.8359   14.1202      24.7914  21.1872   26.3175
 13.6528    12.2218  18.4544   17.9854      11.0375  17.106    17.9101
 17.3555    25.5037  21.7402   20.2539      17.8633  20.3939   21.9678
 17.2484    17.8676  31.8723   23.8969      29.9335  20.512    19.4059
  ⋮                                      ⋱                     
 32.4797    39.9771  38.7365   30.4004      30.4432  31.0244   36.9779
 24.3432    35.2726  29.5895   31.1638      21.0732  18.8629   22.5798
 39.4759    23.0576  25.1498   35.2212      35.5964  26.7561   35.0184
 23.6893    44.3171  36.2901   33.5608   …  38.3512  42.1442   24.07
 31.7503    35.3636  26.7346   35.6234      42.3403  39.5029   35.1565
 34.3006    26.2123  36.1949   42.5345      30.0027  38.9232   36.9146
 30.3765    49.8715  39.5922   42.082       36.674   49.074    34.307
 28.4471    42.6885  40.8327   46.952       34.0547  45.5708   42.6106
 45.1023    39.2439  41.1137   30.6755   …  39.5462  38.2909   31.3747

Notice the dimension of y_pred: we have 20,000 columns, because we have 4 chains with 5,000 samples each. If we had wanted to subsample (which might be necessary if we had hundreds of thousands or millions of samples), we could have done that within mc_predict_regression before simulation.
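For example, one way to thin after the fact is to randomly subsample columns (a sketch using StatsBase.sample; the rest of the tutorial keeps all 20,000 samples):

# randomly keep 1,000 of the 20,000 posterior realizations
idx = sample(1:size(y_pred, 2), 1_000; replace=false)
y_pred_sub = y_pred[:, idx]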

# get the boundaries for the 95% prediction interval and the median
y_ci_low = quantile.(eachrow(y_pred), 0.025)
y_ci_hi = quantile.(eachrow(y_pred), 0.975)
y_med = quantile.(eachrow(y_pred), 0.5)

Now, let’s plot the prediction interval and median, and compare to the original data.

# plot prediction interval
plot(x_pred, y_ci_low, fillrange=y_ci_hi, xlabel=L"$x$", ylabel=L"$y$", fillalpha=0.3, fillcolor=:blue, label="95% Prediction Interval", legend=:topleft, linealpha=0)
plot!(x_pred, y_med, color=:blue, label="Prediction Median")
scatter!(x, y, color=:red, label="Data")
1. Plot the 95% posterior prediction interval as a shaded blue ribbon.
2. Plot the posterior prediction median as a blue line.
3. Plot the data as discrete red points.
Figure 3: Posterior 95% predictive interval and median for the linear regression model. The data is plotted in red for comparison.

From Figure 3, it looks like our model might be slightly under-confident: with 20 data points, we would expect about 5% of them (one data point) to fall outside the 95% prediction interval. It’s hard to tell with only 20 data points, though! We could resolve this by tightening our priors, but whether that is justified depends on how much information we used to specify them in the first place. The goal shouldn’t be to hit a specific level of uncertainty, but if there is a sound reason to tighten the priors, we could do so.
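One way to check this more quantitatively is to compute the empirical coverage of the prediction interval at the observed x values (a sketch; the next code block recomputes these predictions for the residual analysis):

# fraction of observations falling inside the 95% prediction interval
y_pred_obs = mc_predict_regression(x, chain)
pi_low = quantile.(eachrow(y_pred_obs), 0.025)
pi_high = quantile.(eachrow(y_pred_obs), 0.975)
mean(pi_low .<= y .<= pi_high)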

Now let’s look at the residuals between the posterior median predictions and the data. The partial autocorrelations plotted in Figure 4 are not fully convincing, as there are relatively large partial autocorrelation coefficients at longer lags, but the dataset is quite small, so it’s hard to draw strong conclusions. We won’t go further down this rabbit hole, since we know our data-generating process involved independent noise, but for a real dataset we might want to try a model specification with autocorrelated errors for comparison (see the sketch after Figure 4).

# calculate the median predictions and residuals
y_pred_data = mc_predict_regression(x, chain)
y_med_data = quantile.(eachrow(y_pred_data), 0.5)
residuals = y_med_data .- y

# plot the residuals and a line to show the zero
plot(pacf(residuals, 1:4), line=:stem, marker=:circle, legend=:false, grid=:false, linewidth=2, xlabel="Lag", ylabel="Partial Autocorrelation", markersize=8, tickfontsize=14, guidefontsize=16, legendfontsize=16)
hline!([0], linestyle=:dot, color=:red)
Figure 4: Partial autocorrelation function of model residuals, relative to the predictive median.
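If we did want to entertain autocorrelated errors, an AR(1)-error version of the regression might look like the following sketch (not fit here; the AR(1) coefficient ρ and its uniform prior are assumptions chosen for illustration):

@model function linear_regression_ar1(x, y)
    # priors as before, plus an AR(1) coefficient for the errors
    σ ~ truncated(Normal(0, 25); lower=0)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    ρ ~ Uniform(-1, 1)

    μ = a .+ b .* x
    # first observation uses the stationary error standard deviation σ / sqrt(1 - ρ²)
    y[1] ~ Normal(μ[1], σ / sqrt(1 - ρ^2))
    for i in 2:length(y)
        # each observation's error depends on the previous residual
        y[i] ~ Normal(μ[i] + ρ * (y[i-1] - μ[i-1]), σ)
    end
end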

Fitting Extreme Value Models to Tide Gauge Data

Let’s now look at an example of fitting an extreme value distribution (namely, a generalized extreme value distribution, or GEV) to tide gauge data. GEV distributions have three parameters:

  • \(\mu\), the location parameter, which reflects the positioning of the bulk of the GEV distribution;
  • \(\sigma\), the scale parameter, which reflects the width of the bulk;
  • \(\xi\), the shape parameter, which reflects the thickness and boundedness of the tail.

The shape parameter \(\xi\) is often of interest, as there are three classes of GEV distributions corresponding to its sign (see the short example after this list):

  • \(\xi < 0\) means that the distribution has a finite upper bound;
  • \(\xi = 0\) means that the distribution has a thinner (exponentially decaying) upper tail, so the “extreme extremes” are less likely;
  • \(\xi > 0\) means that the distribution has a thicker (polynomially decaying) upper tail.
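To build some intuition for this tail behavior, the short sketch below compares upper-tail quantiles of a GEV with μ = 0 and σ = 1 (values assumed only for illustration) across the three regimes; the ξ > 0 case produces noticeably larger extreme quantiles:

# compare upper-tail quantiles for different shape parameters
for ξ in (-0.2, 0.0, 0.2)
    d = GeneralizedExtremeValue(0.0, 1.0, ξ)
    println("ξ = $ξ: 99th percentile = ", round(quantile(d, 0.99); digits=2),
            ", 99.9th percentile = ", round(quantile(d, 0.999); digits=2))
end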

Load Data

First, let’s load the data. We’ll use data from the University of Hawaii Sea Level Center (Caldwell et al., 2015) for San Francisco, from 1897-2023. If you don’t have this data and are working with the notebook, download it here. We’ll assume it’s in a data/ subdirectory, but change the path as needed.

Caldwell, P. C., Merrifield, M. A., & Thompson, P. R. (2015). Sea level measured by tide gauges from global oceans — the joint archive for sea level holdings (NCEI accession 0019568). NOAA National Centers for Environmental Information (NCEI). https://doi.org/10.7289/V5V40S7W

The dataset consists of dates and hours and the tide-gauge measurement, in mm. We’ll load the dataset into a DataFrame.

function load_data(fname)
    date_format = DateFormat("yyyy-mm-dd HH:MM:SS")
    df = @chain fname begin
        CSV.File(; delim=',', header=false)
        DataFrame
        rename("Column1" => "year",
                "Column2" => "month",
                "Column3" => "day",
                "Column4" => "hour",
                "Column5" => "gauge")
        # combine the year/month/day/hour columns into a DateTime
        @transform :datetime = DateTime.(:year, :month, :day, :hour)
        # replace -99999 with missing
        @transform :gauge = ifelse.(abs.(:gauge) .>= 9999, missing, :gauge)
        select(:datetime, :gauge)
    end
    return df
end
1. This uses the DataFramesMeta.jl package, which makes it easy to string together commands to load and process data.
2. Load the file, assuming there is no header.
3. Convert to a DataFrame.
4. Rename columns for ease of access.
5. Combine the separate year, month, day, and hour columns into a Julia DateTime.
6. Replace the missing-data flag (-99999) with missing.
7. Select only the :datetime and :gauge columns.
load_data (generic function with 1 method)
dat = load_data("data/h551a.csv")
first(dat, 6)
Table 1: Processed hourly tide gauge data from San Francisco, from 8/1/1897-1/31/2023.
6×2 DataFrame
 Row │ datetime             gauge
     │ DateTime             Int64?
─────┼─────────────────────────────
   1 │ 1897-08-01T08:00:00   3292
   2 │ 1897-08-01T09:00:00   3322
   3 │ 1897-08-01T10:00:00   3139
   4 │ 1897-08-01T11:00:00   2835
   5 │ 1897-08-01T12:00:00   2377
   6 │ 1897-08-01T13:00:00   2012
@df dat plot(:datetime, :gauge, label="Observations", bottom_margin=9mm)
xaxis!("Date", xrot=30)
yaxis!("Mean Water Level")
1. This uses the DataFrame plotting recipe with the @df macro from StatsPlots.jl. This is not needed (you could replace e.g. :datetime with dat.datetime), but it cleans things up slightly.
Figure 5: Hourly mean water level at the San Francisco tide gauge from 1897-2023.

Next, we need to detrend the data to remove the impacts of sea-level rise. We do this by removing a one-year moving average, centered on the data point, per the recommendation of Arns et al. (2013).

# calculate the moving average and subtract it off
ma_length = 366
ma_offset = Int(floor(ma_length/2))
moving_average(series,n) = [mean(@view series[i-n:i+n]) for i in n+1:length(series)-n]
dat_ma = DataFrame(datetime=dat.datetime[ma_offset+1:end-ma_offset], residual=dat.gauge[ma_offset+1:end-ma_offset] .- moving_average(dat.gauge, ma_offset))

# plot
@df dat_ma plot(:datetime, :residual, label="Detrended Observations", bottom_margin=9mm)
xaxis!("Date", xrot=30)
yaxis!("Mean Water Level")
Figure 6: Mean water level from the San Francisco tide gauge, detrended using a 1-year moving average centered on the data point, per the recommendation of Arns et al. (2013).
Arns, A., Wahl, T., Haigh, I. D., Jensen, J., & Pattiaratchi, C. (2013). Estimating extreme water level probabilities: A comparison of the direct methods and recommendations for best practise. Coast. Eng., 81, 51–66. https://doi.org/10.1016/j.coastaleng.2013.07.003

The last step in preparing the data is to find the annual maxima. We can do this using the groupby, transform, and combine functions from DataFrames.jl, as below.

# calculate the annual maxima
dat_ma = dropmissing(dat_ma)
dat_annmax = combine(dat_ma -> dat_ma[argmax(dat_ma.residual), :],
                groupby(DataFrames.transform(dat_ma, :datetime => x->year.(x)), :datetime_function))
delete!(dat_annmax, nrow(dat_annmax))

# make a histogram of the maxima to see the distribution
histogram(dat_annmax.residual, label=false)
ylabel!("Count")
xlabel!("Mean Water Level (mm)")
1. If we don’t drop the missing values, they will affect the next call to argmax.
2. This first groups the data by year (with groupby, using Dates.year() to get the year of each data point), then pulls the row corresponding to the maximum for each year (using argmax).
3. This deletes the last row, for 2023: the dataset only goes until March 2023, so that year’s “annual maximum” is based on only a few months of data.
Figure 7: Histogram of annual block maxima from 1898-2022 from the San Francisco tide gauge dataset.

Fit The Model

@model function gev_annmax(y)               
    μ ~ Normal(1000, 100)
    σ ~ truncated(Normal(0, 100); lower=0)
    ξ ~ Normal(0, 0.5)

    y ~ GeneralizedExtremeValue(μ, σ, ξ)
end

gev_model = gev_annmax(dat_annmax.residual)
n_chains = 4
n_per_chain = 5000
gev_chain = sample(gev_model, NUTS(), MCMCThreads(), n_per_chain, n_chains; drop_warmup=true)
@show gev_chain
1. Location parameter prior: we know that this is roughly on the order of 1000 mm, but we want to keep this relatively broad.
2. Scale parameter prior: this parameter must be positive, so we use a normal distribution truncated at zero.
3. Shape parameter prior: shape parameters are usually small and hard to constrain, so we use a more informative prior.
4. The data is treated as independently GEV-distributed, since we’ve removed the long-term trend and are using long (annual) blocks.
5. Initialize the model.
6. We use multiple chains to help diagnose convergence.
7. This sets the number of iterations for each chain.
8. Sample from the posterior using NUTS and drop the iterations used to warm up the sampler.
Warning: Only a single thread available: MCMC chains are not sampled in parallel
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:382
Sampling (1 threads)   0%|                              |  ETA: N/A
Info: Found initial step size
  ϵ = 0.2
Info: Found initial step size
  ϵ = 0.05
Sampling (1 threads)  25%|███████▌                      |  ETA: 0:00:10
Sampling (1 threads)  50%|███████████████               |  ETA: 0:00:04
Info: Found initial step size
  ϵ = 0.05
Info: Found initial step size
  ϵ = 0.025
Sampling (1 threads)  75%|██████████████████████▌       |  ETA: 0:00:02
Sampling (1 threads) 100%|██████████████████████████████| Time: 0:00:05
gev_chain = MCMC chain (5000×15×4 Array{Float64, 3})
Chains MCMC chain (5000×15×4 Array{Float64, 3}):

Iterations        = 1001:1:6000
Number of chains  = 4
Samples per chain = 5000
Wall duration     = 4.85 seconds
Compute duration  = 4.42 seconds
parameters        = μ, σ, ξ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters        mean       std      mcse     ess_bulk     ess_tail      rhat ⋯
      Symbol     Float64   Float64   Float64      Float64      Float64   Float64 ⋯

           μ   1257.7511    5.6792    0.0517   12111.6749   13632.2573    1.00 ⋯
           σ     57.1490    4.2032    0.0368   13190.6502   13407.5135    1.00 ⋯
           ξ      0.0297    0.0624    0.0006   11233.6422   10660.9175    1.00 ⋯
                                                               2 columns omitted

Quantiles
  parameters        2.5%       25.0%       50.0%       75.0%       97.5% 
      Symbol     Float64     Float64     Float64     Float64     Float64 

           μ   1246.9623   1253.9177   1257.6376   1261.5320   1269.1341
           σ     49.6305     54.1836     56.9100     59.8732     65.9955
           ξ     -0.0811     -0.0138      0.0258      0.0697      0.1607
plot(gev_chain)
Figure 8: Traceplots (left) and marginal distributions (right) from the MCMC sampler for the GEV model.

From Figure 8, it looks like all of the chains have converged to the same distribution; the Gelman-Rubin diagnostic is also close to 1 for all parameters. Next, we can look at a corner plot to see how the parameters are correlated.

corner(gev_chain)
Figure 9: Corner plot for the GEV model.

Figure 9 suggests that the location and scale parameters \(\mu\) and \(\sigma\) are positively correlated. This makes some intuitive sense: increasing the location parameter shifts the bulk of the distribution upward, so the scale parameter must also increase for the distribution to remain consistent with the lower observations. The shape parameter \(\xi\), in turn, decreases when \(\mu\) and \(\sigma\) increase: the bulk of the distribution is then already closer to the largest observations, so the tail of the GEV does not need to be as thick.
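To put rough numbers on these correlations, we can compute the pairwise correlations of the pooled posterior samples (a sketch; Array(gev_chain) pools the chains and returns only the model parameters, in the order shown in the summary):

# pairwise posterior correlations among (μ, σ, ξ)
# cor comes from the Statistics standard library (add `using Statistics` if needed)
post_samples = Array(gev_chain)
cor(post_samples)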