Keywords: Bayesian, brms, looic, model selection, multiple regression, posterior predictive check, weighted model averaging

Introduction

There are many good reasons to analyse your data using Bayesian methods. Historically, however, these methods have been computationally intensive and difficult to implement, requiring knowledge of sometimes challenging coding platforms and languages, like WinBUGS, JAGS, or Stan. Newer R packages, including R2jags, rstanarm, and brms, have made building Bayesian regression models in R relatively straightforward. For some background on Bayesian statistics, there is a PowerPoint presentation here.

Here I will introduce code to run some simple regression models using the brms package. This package offers a little more flexibility than rstanarm, although both offer much of the same functionality. I encourage you to check out the extremely helpful vignettes written by Paul Buerkner. Paul's GitHub page is also a useful resource. I won't go into too much detail on prior selection or on demonstrating the full flexibility of the brms package (for that, check out the vignettes), but I will try to add useful links where possible. I will also go a bit beyond the models themselves to talk about model selection using loo, and model averaging.

Packages

First, let's load the packages, the most important being brms.
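A minimal sketch of the package-loading step. The post only names brms, so the exact package list here is an assumption: brms for the models, ggplot2 for the diamonds data and the plots, and parallel (part of base R) for counting cores.

```r
# Assumed package list; install missing ones with install.packages("brms"), etc.
library(brms)      # Bayesian regression models fitted with Stan
library(ggplot2)   # provides the diamonds dataset and the plotting functions
library(parallel)  # base-R package that provides detectCores()
```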

The Data

For this analysis, I am going to use the diamonds dataset from ggplot2. Because it is pretty large (53,940 rows), I am going to subset it.
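The post does not show how the subset was drawn, so the following is only one way to do it. The seed, the subset size and the name `diamonds.test` are hypothetical; the training size of 1680 is chosen only to match the number of observations reported in the brm.4 summary, and the remaining rows are held out for the prediction step at the end.

```r
library(ggplot2)  # the diamonds dataset ships with ggplot2

set.seed(123)  # hypothetical seed, for reproducibility only

# Hypothetical sizes: draw 2100 distinct rows, keep 1680 for training
# and hold out the remaining 420 for testing predictions later
idx <- sample(nrow(diamonds), 2100)
diamonds.train <- diamonds[idx[1:1680], ]
diamonds.test  <- diamonds[idx[1681:2100], ]
```

Because `sample()` draws without replacement by default, the training and test rows never overlap.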

Setting the number of cores

Because these analyses can sometimes be a little sluggish, it is recommended to set the number of cores you use to the maximum number available. You can check how many cores you have available with the following code.
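Something like the following reports the core count and stores it in the `mc.cores` option, which the brm.4 call below reads through `getOption("mc.cores", 1)`. `detectCores()` comes from the parallel package; using every core is a choice rather than a requirement.

```r
library(parallel)

# Number of logical cores on this machine (detectCores() can return NA on some platforms)
n_cores <- detectCores()
if (is.na(n_cores)) n_cores <- 1
print(n_cores)

# Store it so that model calls can pick it up with cores = getOption("mc.cores", 1)
options(mc.cores = n_cores)
```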

We’ll use this bit of code again when we are running our models and doing model selection.

Examining and visualizing data

What I am interested in is how well the properties of a diamond predict its price. Does the size of the diamond matter? What is the relative importance of color vs clarity? Given that size, color and clarity almost certainly all matter, let's see if the models tell us the same thing. To get a description of the data, let's use the help function.
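For example, the help page and a compact overview of the columns can be pulled up like this:

```r
library(ggplot2)

# Documentation page for the diamonds dataset (equivalent to ?diamonds)
help("diamonds", package = "ggplot2")

# Type and first few values of every column
str(diamonds)
```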

Let's take a look at the data. First, let's plot price as a function of carat, a well-known metric of diamond quality. Here I plot the raw data and then both variables log-transformed.
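A sketch of the two scatterplots with ggplot2. So that the block runs on its own, it draws a hypothetical stand-in for the training subset; point transparency is an arbitrary styling choice.

```r
library(ggplot2)

set.seed(123)  # hypothetical stand-in for the training subset drawn earlier
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]

# Price against carat on the raw scale
p_raw <- ggplot(diamonds.train, aes(x = carat, y = price)) +
  geom_point(alpha = 0.3)

# The same relationship with both variables log-transformed
p_log <- ggplot(diamonds.train, aes(x = log(carat), y = log(price))) +
  geom_point(alpha = 0.3)

print(p_raw)
print(p_log)
```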

We might consider log-transforming price before running our models with a Gaussian family, or consider using a different link function (e.g. log). Next, let's visualize how clarity and color influence price. Here I will first plot boxplots of price by level for clarity and color, and then price vs carat, with colors representing levels of clarity and color.
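One way to draw these plots is sketched below, again with a hypothetical stand-in subset. The scatterplots use the log-log scale and add a separate least-squares line per level, which is an assumption about how to make differences in intercepts and slopes visible.

```r
library(ggplot2)

set.seed(123)  # hypothetical stand-in for the training subset drawn earlier
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]

# Boxplots of price for each level of clarity and of color
p_box_clarity <- ggplot(diamonds.train, aes(x = clarity, y = price)) + geom_boxplot()
p_box_color   <- ggplot(diamonds.train, aes(x = color, y = price)) + geom_boxplot()

# log(price) vs log(carat), coloured by level, with one linear fit per level
p_sc_clarity <- ggplot(diamonds.train, aes(log(carat), log(price), colour = clarity)) +
  geom_point(alpha = 0.3) +
  geom_smooth(method = "lm", formula = y ~ x, se = FALSE)
p_sc_color <- ggplot(diamonds.train, aes(log(carat), log(price), colour = color)) +
  geom_point(alpha = 0.3) +
  geom_smooth(method = "lm", formula = y ~ x, se = FALSE)

plots <- list(p_box_clarity, p_box_color, p_sc_clarity, p_sc_color)
for (p in plots) print(p)
```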

From these plots, it looks as if there may be differences in the intercepts and slopes (especially for clarity) between color and clarity classes. We can model this using a mixed effects model. But let’s start with simple multiple regression.

Fitting models

For this first model, we will look at how well diamond ‘carat’ correlates with price.

This might take a few minutes to run, depending on the speed of your machine.
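The call for brm.1 is not shown in the post. The sketch below assumes it regresses log(price) on log(carat) and reuses the priors and sampler settings of the brm.4 call shown further down; the training subset is a hypothetical stand-in so the block runs on its own.

```r
library(brms)
library(ggplot2)  # for the diamonds data

options(mc.cores = max(1, parallel::detectCores(), na.rm = TRUE))
set.seed(123)  # hypothetical stand-in for the training subset drawn earlier
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]

# Assumed specification for brm.1: a single population-level predictor
brm.1 <- brm(log(price) ~ log(carat),
             family = gaussian(),
             data = na.omit(diamonds.train),
             chains = 4,                         # number of Markov chains
             cores = getOption("mc.cores", 1),   # run chains in parallel
             iter = 3000, warmup = 1500, thin = 5,
             prior = c(prior(normal(0, 3), class = "b"),          # slope: mean 0, SD 3
                       prior(normal(0, 3), class = "Intercept")))  # intercept: mean 0, SD 3
```

With these settings each chain keeps (3000 - 1500) / 5 = 300 draws, so 4 chains give 1200 post-warmup draws, matching the totals reported for brm.4 below.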

A really fantastic tool for interrogating your model is the 'launch_shinystan' function, which you can call as:

‘launch_shinystan(brm.1)’.

For now, we will look at a summary of the model in R, as well as plots of the posterior distributions and the Markov chains. We can see from the summary that our chains have converged sufficiently (Rhat = 1).
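The summary, posterior and trace plots, and an explicit Rhat check might look like this. To keep the block self-contained it refits a stand-in for brm.1 (assumed formula) with brms default priors and settings, so the numbers will differ somewhat from the post; the 1.05 cut-off in the test is a common rule of thumb, not a brms default.

```r
library(brms)
library(ggplot2)  # for the diamonds data

set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]
brm.1 <- brm(log(price) ~ log(carat), data = diamonds.train, seed = 1)

summary(brm.1)           # estimates, credible intervals, effective sample sizes and Rhat
plot(brm.1, ask = FALSE) # posterior densities alongside trace plots of each chain

# Largest Rhat across all parameters; values close to 1 indicate convergence
rhats <- rhat(brm.1)
max(rhats, na.rm = TRUE)
```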

We can also get an R-squared estimate for our model, thanks to a newly-developed method from Andrew Gelman, Ben Goodrich, Jonah Gabry and Imad Ali, with an explanation here.

Note that log(carat) clearly explains a lot of the variation in diamond price (as we'd expect), with a credibly positive slope (1.52 ± 0.01).
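The Bayesian R-squared and the slope estimate can be extracted directly, sketched here with a stand-in for brm.1 refit using brms defaults (assumed formula). brms renames the `log(carat)` term to `logcarat`, as the summaries in this post show.

```r
library(brms)
library(ggplot2)  # for the diamonds data

set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]
brm.1 <- brm(log(price) ~ log(carat), data = diamonds.train, seed = 1)

# Bayesian R-squared: posterior mean, error and 95% interval
r2 <- bayes_R2(brm.1)
r2

# Posterior summary of the slope on log(carat)
slope <- fixef(brm.1)["logcarat", ]
slope
```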

Model validation using approximate leave-one-out cross-validation

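Another way to assess model fit is approximate leave-one-out cross-validation, via the loo package, developed by Vehtari, Gelman, and Gabry (2017a, 2017b). Using loo, we can compute the LOO information criterion (LOOIC), which is similar to the AIC that some readers may be familiar with: it is defined as -2 times the expected log pointwise predictive density, so lower values indicate better expected out-of-sample prediction.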

The plot of the loo shows the Pareto shape parameter k for each data point. This parameter is used to assess the reliability and convergence rate of the Pareto smoothed importance sampling (PSIS) estimates that loo relies on. For more details, check out the help and the references above. For our purposes, we want to ensure that no data points have too high a value of this parameter. The default threshold for a high value is k > 0.7.
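A sketch of computing the LOOIC and checking the Pareto k values, using a stand-in for brm.1 refit with brms defaults (assumed formula). `pareto_k_table()` and `pareto_k_values()` come from the loo package, which brms depends on; `loo.1` is a hypothetical name.

```r
library(brms)
library(ggplot2)  # for the diamonds data

set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]
brm.1 <- brm(log(price) ~ log(carat), data = diamonds.train, seed = 1)

loo.1 <- loo(brm.1)
loo.1                                # elpd_loo, p_loo and looic with standard errors
loo.1$estimates["looic", ]           # LOOIC is -2 * elpd_loo

plot(loo.1)                          # Pareto k for each observation
loo::pareto_k_table(loo.1)           # observations counted by k range
n_high <- sum(loo::pareto_k_values(loo.1) > 0.7)  # points above the 0.7 threshold
n_high
```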

Mixed effects models

We can also run models including group-level effects (also called random effects). Here I will run models with clarity and color as grouping levels, first separately and then together in an ‘overall’ model.

Here's the model with clarity as the group-level effect. I have also run the function 'loo', so that we can compare models. Lower LOOIC means better estimated out-of-sample predictive accuracy, so the model with the lowest LOOIC is preferred.
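The post does not show the call for brm.2; the sketch below assumes it keeps log(carat) as the population-level effect and lets only the intercept vary by clarity, mirroring the brm.4 formula. It uses brms default settings and a hypothetical stand-in subset so it runs on its own; `loo.2` is a hypothetical name.

```r
library(brms)
library(ggplot2)  # for the diamonds data

set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]

# Assumed specification: common slope, intercept varying by clarity
brm.2 <- brm(log(price) ~ log(carat) + (1 | clarity),
             data = diamonds.train, seed = 1)

loo.2 <- loo(brm.2)
loo.2$estimates["looic", "Estimate"]
```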

And now with color.
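The same sketch with color as the grouping factor; the formula for brm.3 is again an assumption, as are the default settings, the stand-in subset and the name `loo.3`.

```r
library(brms)
library(ggplot2)  # for the diamonds data

set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]

# Assumed specification: common slope, intercept varying by color
brm.3 <- brm(log(price) ~ log(carat) + (1 | color),
             data = diamonds.train, seed = 1)

loo.3 <- loo(brm.3)
loo.3$estimates["looic", "Estimate"]
```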

And here’s a model with the log of carat as the fixed effect and color and clarity as group-level effects.

brm.4 <- brm(log(price) ~ log(carat) + (1|color) + (1|clarity),
             family = brmsfamily("gaussian"),
             data = na.omit(diamonds.train),
             chains = 4, # number of Markov chains
             cores = getOption("mc.cores", 1), # cores used to run the chains in parallel
             iter = 3000, warmup = 1500, thin = 5, # iterations per chain, warmup discarded, keep every 5th draw
             prior = c(prior(normal(0, 3), "b"), # normal prior on regression coefficients (mean 0, SD 3)
                       prior(normal(0, 3), "Intercept"))) # normal prior on the intercept (mean 0, SD 3)

loo(brm.1, brm.2, brm.3, brm.4)
##                  LOOIC    SE
## brm.1          1337.85 73.28
## brm.2            38.45 66.18
## brm.3          1074.62 68.59
## brm.4         -1342.73 71.67
## brm.1 - brm.2  1299.40 52.07
## brm.1 - brm.3   263.23 30.99
## brm.1 - brm.4  2680.58 67.48
## brm.2 - brm.3 -1036.17 64.23
## brm.2 - brm.4  1381.18 62.57
## brm.3 - brm.4  2417.35 63.27

bayes_R2(brm.4)
##     Estimate    Est.Error     Q2.5     Q97.5
## R2 0.9750782 0.0002039838 0.974631 0.9754266

summary(brm.4)
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: log(price) ~ log(carat) + (1 | color) + (1 | clarity) 
##    Data: na.omit(diamonds.train) (Number of observations: 1680) 
## Samples: 4 chains, each with iter = 3000; warmup = 1500; thin = 5;
##          total post-warmup samples = 1200
## 
## Group-Level Effects: 
## ~clarity (Number of levels: 8) 
##               Estimate Est.Error l-95% CI u-95% CI Eff.Sample Rhat
## sd(Intercept)     0.45      0.16     0.25     0.83        965 1.00
## 
## ~color (Number of levels: 7) 
##               Estimate Est.Error l-95% CI u-95% CI Eff.Sample Rhat
## sd(Intercept)     0.26      0.11     0.14     0.55       1044 1.00
## 
## Population-Level Effects: 
##           Estimate Est.Error l-95% CI u-95% CI Eff.Sample Rhat
## Intercept     8.45      0.20     8.03     8.83        982 1.00
## logcarat      1.86      0.01     1.84     1.87       1200 1.00
## 
## Family Specific Parameters: 
##       Estimate Est.Error l-95% CI u-95% CI Eff.Sample Rhat
## sigma     0.16      0.00     0.16     0.17       1200 1.00
## 
## Samples were drawn using sampling(NUTS). For each parameter, Eff.Sample 
## is a crude measure of effective sample size, and Rhat is the potential 
## scale reduction factor on split chains (at convergence, Rhat = 1).
coef(brm.4)
## $clarity
## , , Intercept
## 
##      Estimate Est.Error     Q2.5    Q97.5
## I1   7.757952 0.1116812 7.534508 7.972229
## IF   8.896737 0.1113759 8.666471 9.119115
## SI1  8.364881 0.1118541 8.138917 8.585221
## SI2  8.208712 0.1116475 7.976549 8.424202
## VS1  8.564924 0.1114861 8.338425 8.780385
## VS2  8.500922 0.1119241 8.267040 8.715973
## VVS1 8.762394 0.1112272 8.528874 8.978609
## VVS2 8.691808 0.1113552 8.458141 8.909012
## 
## , , logcarat
## 
##      Estimate  Est.Error     Q2.5   Q97.5
## I1   1.857542 0.00766643 1.842588 1.87245
## IF   1.857542 0.00766643 1.842588 1.87245
## SI1  1.857542 0.00766643 1.842588 1.87245
## SI2  1.857542 0.00766643 1.842588 1.87245
## VS1  1.857542 0.00766643 1.842588 1.87245
## VS2  1.857542 0.00766643 1.842588 1.87245
## VVS1 1.857542 0.00766643 1.842588 1.87245
## VVS2 1.857542 0.00766643 1.842588 1.87245
## 
## 
## $color
## , , Intercept
## 
##   Estimate Est.Error     Q2.5    Q97.5
## D 8.717499 0.1646875 8.379620 9.044789
## E 8.628844 0.1640905 8.294615 8.957632
## F 8.569998 0.1645341 8.235241 8.891485
## G 8.489433 0.1644847 8.155874 8.814277
## H 8.414576 0.1642564 8.081458 8.739100
## I 8.273718 0.1639215 7.940648 8.590550
## J 8.123996 0.1638187 7.791308 8.444856
## 
## , , logcarat
## 
##   Estimate  Est.Error     Q2.5   Q97.5
## D 1.857542 0.00766643 1.842588 1.87245
## E 1.857542 0.00766643 1.842588 1.87245
## F 1.857542 0.00766643 1.842588 1.87245
## G 1.857542 0.00766643 1.842588 1.87245
## H 1.857542 0.00766643 1.842588 1.87245
## I 1.857542 0.00766643 1.842588 1.87245
## J 1.857542 0.00766643 1.842588 1.87245

All of the mixed effects models we have looked at so far have only allowed the intercepts of the groups to vary, but, as we saw when we were looking at the data, different levels of our groups may have different slopes too. We can specify a model that allows the slope of the price~carat relationship to vary by both color and clarity.
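A sketch of such a varying-slopes model. The name brm.5 and the exact specification are assumptions: each grouping factor gets its own intercept and log(carat) slope (correlated by default in brms), with the same priors and sampler settings as brm.4. Raising `adapt_delta` above its 0.8 default is an optional precaution against divergent transitions, not something the post specifies.

```r
library(brms)
library(ggplot2)  # for the diamonds data

options(mc.cores = max(1, parallel::detectCores(), na.rm = TRUE))
set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]

# Hypothetical brm.5: intercepts and log(carat) slopes vary by color and by clarity
brm.5 <- brm(log(price) ~ log(carat) + (1 + log(carat) | color) + (1 + log(carat) | clarity),
             family = gaussian(),
             data = na.omit(diamonds.train),
             chains = 4,
             cores = getOption("mc.cores", 1),
             iter = 3000, warmup = 1500, thin = 5,
             prior = c(prior(normal(0, 3), class = "b"),
                       prior(normal(0, 3), class = "Intercept")),
             control = list(adapt_delta = 0.95))  # smaller sampler steps, fewer divergences
```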

We can now compare our models using ‘loo’.
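The post compares models with `loo(brm.1, brm.2, ...)`, which printed LOOIC values and differences in the brms version used. In more recent brms versions, `add_criterion()` plus `loo_compare()` is the usual route; the sketch below compares stand-ins for brm.4 and the hypothetical varying-slopes brm.5, both refit with default settings. `loo_compare()` ranks models by elpd (higher is better), which is equivalent to ranking by lowest LOOIC; the best model always sits in the first row with an `elpd_diff` of 0.

```r
library(brms)
library(ggplot2)  # for the diamonds data

set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]

brm.4 <- brm(log(price) ~ log(carat) + (1 | color) + (1 | clarity),
             data = diamonds.train, seed = 1)
brm.5 <- brm(log(price) ~ log(carat) + (1 + log(carat) | color) + (1 + log(carat) | clarity),
             data = diamonds.train, seed = 1, control = list(adapt_delta = 0.95))

# Store the LOO results inside each fitted model, then compare them
brm.4 <- add_criterion(brm.4, "loo")
brm.5 <- add_criterion(brm.5, "loo")
comp <- loo_compare(brm.4, brm.5, criterion = "loo")
print(comp, simplify = FALSE)  # elpd_diff, se_diff, elpd_loo, looic, ...
```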

It looks like the final model we ran is the best model. Let's look at the Bayesian R-squared value and the model summary for this model. We can also get more details on the coefficients using the 'coef' function. Similarly, we could use 'fixef' for population-level effects and 'ranef' for group-level effects.
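A sketch of these extraction functions applied to the hypothetical varying-slopes model brm.5, refit with default settings on a stand-in subset. `coef()` combines population-level and group-level terms, so each level of clarity gets its own intercept and log(carat) slope; `ranef()` returns only the group-level deviations.

```r
library(brms)
library(ggplot2)  # for the diamonds data

set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]
brm.5 <- brm(log(price) ~ log(carat) + (1 + log(carat) | color) + (1 + log(carat) | clarity),
             data = diamonds.train, seed = 1, control = list(adapt_delta = 0.95))

r2.5 <- bayes_R2(brm.5)
r2.5
summary(brm.5)

fixef(brm.5)                  # population-level effects
ranef(brm.5)$clarity          # deviations of each clarity level from those effects
slopes <- coef(brm.5)$clarity[, , "logcarat"]  # clarity-specific slopes
slopes
```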

Clearly, the variables we have included have a really strong influence on diamond price!

Posterior predictive checking using pp_check

The pp_check function allows for graphical posterior predictive checking: we can generate figures that compare the observed data to data simulated from the posterior predictive distribution. This is a great graphical way to evaluate your model.

There are many different plot types to choose from. In the first plot I use density plots, where the density of the observed y values is overlaid on the densities of replicated datasets (yrep) simulated from the posterior predictive distribution.

It is good to see that our model is doing a fairly good job of capturing the slight bimodality in logged diamond prices, although specifying a different model family might help to improve this.

Here, ‘nsamples’ refers to the number of draws from the posterior distribution to use to calculate yrep values.
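A sketch of the density overlay, using brm.4's formula refit with default settings on a stand-in subset. The post passes the number of draws as `nsamples`; recent brms versions call this argument `ndraws` and treat `nsamples` as deprecated, so check which one your installed version expects. The value 100 is an arbitrary choice.

```r
library(brms)
library(ggplot2)  # for the diamonds data

set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]
brm.4 <- brm(log(price) ~ log(carat) + (1 | color) + (1 | clarity),
             data = diamonds.train, seed = 1)

# Observed density of log(price) overlaid on 100 simulated (yrep) densities
p_dens <- pp_check(brm.4, type = "dens_overlay", ndraws = 100)
print(p_dens)
```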

We can also look at the fit by group. Here, for example, are scatterplots with the observed prices (log scale) on the y-axis and the average of the simulated values (across all posterior samples) on the x-axis.
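One way to produce grouped scatterplots like these is the `scatter_avg_grouped` type, which pp_check passes on to bayesplot's `ppc_scatter_avg_grouped`. Grouping by clarity is an assumption about which panels the post shows; the model is again a default-settings stand-in for brm.4.

```r
library(brms)
library(ggplot2)  # for the diamonds data

set.seed(123)  # hypothetical stand-in training subset
diamonds.train <- diamonds[sample(nrow(diamonds), 1680), ]
brm.4 <- brm(log(price) ~ log(carat) + (1 | color) + (1 | clarity),
             data = diamonds.train, seed = 1)

# One panel per clarity level: observed log(price) against the average simulated value
p_grp <- pp_check(brm.4, type = "scatter_avg_grouped", group = "clarity")
print(p_grp)
```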

Model prediction

Finally, we can evaluate how well our model predicts the diamond data that we held out. We can use the 'predict' function (as we would with a more standard model), which also returns an estimated error and a 95% prediction interval for each data point. We can then plot the predictions using ggplot2.
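A sketch of predicting the held-out rows and plotting predicted against observed values. The train/test split, the seed and the name `diamonds.test` are hypothetical, and brm.4 is refit with default settings. Because the response is log(price), predictions are on the log scale. `allow_new_levels = TRUE` is a precaution in case a color or clarity level appears only in the test rows.

```r
library(brms)
library(ggplot2)

set.seed(123)  # hypothetical split into training and held-out rows
idx <- sample(nrow(diamonds), 2100)
diamonds.train <- diamonds[idx[1:1680], ]
diamonds.test  <- diamonds[idx[1681:2100], ]

brm.4 <- brm(log(price) ~ log(carat) + (1 | color) + (1 | clarity),
             data = diamonds.train, seed = 1)

# Posterior predictive mean, error and 95% interval (Q2.5, Q97.5) per held-out diamond
pred <- predict(brm.4, newdata = diamonds.test, allow_new_levels = TRUE)
pred.df <- data.frame(as.data.frame(diamonds.test), pred)

# Predicted vs observed log(price); points on the dashed line are predicted exactly
p_pred <- ggplot(pred.df, aes(x = log(price), y = Estimate)) +
  geom_pointrange(aes(ymin = Q2.5, ymax = Q97.5), alpha = 0.4) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  labs(x = "Observed log(price)", y = "Predicted log(price)")
print(p_pred)
```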

That’s all for this post.

Thanks for reading!

Contact Me

Please check out my personal website at timothyemoore.com