Commit
Merge pull request #209 from drbenvincent/dev
more WAIC progress
drbenvincent authored Mar 8, 2018
2 parents fade821 + 6a57668 commit 57cd871
Showing 10 changed files with 249 additions and 28 deletions.
87 changes: 64 additions & 23 deletions ddToolbox/WAIC.m
@@ -1,30 +1,62 @@
classdef WAIC
%WAIC WAIC object
% Extended description here
% The WAIC object is intended to help conduct Bayesian model
% comparison.
%
% Step 1: Create a WAIC object for each model we have. We do this by
% creating a WAIC instance, calling it with a table of log likelihood
% values. Each column corresponds to an MCMC sample, and each row
% corresponds to an observation. Creating a single WAIC object will
% result in various stats being calculated, but the intention is to
% compare multiple models.
%
% We do this by creating an array of WAIC objects, one for each model.
% For example, assuming we already have our log likelihood tables
% produced by 3 models...
% >> waic_stats = [WAIC(ll1,'m1'), WAIC(ll2,'m2'), WAIC(ll3,'m3')]
%
% Step 2: compare
% We now have an object array, and we can call the compare or plot
% methods on this. For example
% >> comparison_table = waic_stats.compare()
% will produce a table of WAIC comparison stats.
%
% and
% >> waic_stats.plot()
% will produce a nicely formatted model comparison plot.
%
% References
% Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A.,
% & Rubin, D. B. (2013). Bayesian Data Analysis, Third Edition.
% CRC Press.
%
% McElreath, R. (2016). Statistical Rethinking: A Bayesian Course with
% Examples in R and Stan. CRC Press.

properties (SetAccess = protected)
lppd, pWAIC, WAIC_value, WAIC_standard_error
nSamples, nCases
end

properties (Hidden = true)
properties (Hidden = true, SetAccess = protected)
log_lik
lppd_vec, pWAIC_vec, WAIC_vec
modelName
end

properties
model_name
end

methods

function obj = WAIC(log_lik)
function obj = WAIC(log_lik, model_name)
% WAIC constructor. The input log_lik should be a table of log
% likelihood values: each column corresponds to an MCMC sample,
% and each row corresponds to an observation. The model_name
% string is used to label the model in comparison tables and plots.

[obj.nCases, obj.nSamples] = size(log_lik);

obj.log_lik = log_lik;
clear log_lik
obj.model_name = model_name;
obj.log_lik = log_lik; clear log_lik

% Calculate lppd
% Equation 7.5 from Gelman et al (2013)
@@ -37,20 +69,19 @@
obj.pWAIC = sum(obj.pWAIC_vec);

% Calculate WAIC
obj.WAIC_value = -2 * obj.lppd + 2 * obj.pWAIC;

% Calculate WAIC standard error
obj.WAIC_vec = -2 * obj.lppd_vec + 2 * obj.pWAIC_vec;
obj.WAIC_standard_error = sqrt(obj.nCases)*std(obj.WAIC_vec);
obj.WAIC_value = calc_waic(obj.lppd, obj.pWAIC);

% Calculate WAIC standard error
obj.WAIC_vec = calc_waic(obj.lppd_vec, obj.pWAIC_vec);
obj.WAIC_standard_error = standard_error(obj.WAIC_vec);
end

function comparisonTable = compare(obj)
% Compare WAIC info from multiple models
assert(numel(obj)>1, 'expecting an array of >1 WAIC object')

% Build a table of values
model = {obj.modelName}';
model = {obj.model_name}';
WAIC = [obj.WAIC_value]';
pWAIC = [obj.pWAIC]';
lppd = [obj.lppd]';
@@ -68,7 +99,7 @@
% Calculate SE of difference (of WAIC values) between
% model m and i_best_model
WAIC_diff = obj(i_best_model).WAIC_vec - obj(m).WAIC_vec;
dSE(m,1) = sqrt(obj(m).nCases)*std(WAIC_diff);
dSE(m,1) = standard_error(WAIC_diff);
end
end
% create table
@@ -85,35 +116,37 @@ function plot(obj)
% define y-value positions for each model
y = [1:1:size(comparisonTable,1)];

ms = 6;
% set plot options
marker_size = 7;
grey_col = [0.5, 0.5, 0.5];

clf
hold on

% in-sample deviance as solid circles
in_sample_deviance = -2*comparisonTable.lppd;
isd = plot(in_sample_deviance, y, 'ko',...
'MarkerFaceColor','k',...
'MarkerSize', ms);
'MarkerSize', marker_size);

% WAIC as empty circles, with SE errorbars
%waic = plot(comparisonTable.WAIC, y, 'ko');
waic_eb = errorbar(comparisonTable.WAIC,y,comparisonTable.SE,...
'horizontal',...
'o',...
'LineStyle', 'none',...
'Color', 'k',...
'MarkerFaceColor','w',...
'MarkerSize', ms);
'MarkerSize', marker_size);

% plot dSE models
% plot WAIC as compared to best model, in a different colour
waic_diff = errorbar(comparisonTable.dWAIC([2:end])+min(comparisonTable.WAIC),...
y([2:end])-0.2, comparisonTable.dSE([2:end]),...
'horizontal',...
'^',...
'LineStyle', 'none',...
'Color', [0.5 0.5 0.5],...
'Color', grey_col,...
'MarkerFaceColor','w',...
'MarkerSize', ms);
'MarkerSize', marker_size);

% formatting
xlabel('deviance');
Expand All @@ -123,10 +156,10 @@ function plot(obj)
'YDir','reverse');
ylim([min(y)-1, max(y)+1]);

vline(min(comparisonTable.WAIC), 'Color',[0.5 0.5 0.5]);
vline(min(comparisonTable.WAIC), 'Color', grey_col);

legend([isd, waic_eb, waic_diff],...
{'in-sample deviance', 'WAIC (+/- SE)', 'SE of WAIC difference (+/- SE)'},...
{'in-sample deviance', 'WAIC (+/- SE)', 'WAIC difference (+/- SE)'},...
'location', 'eastoutside');

title('WAIC Model Comparison')
@@ -135,4 +168,12 @@ function plot(obj)

end

end

function SE = standard_error(x)
SE = sqrt(numel(x))*std(x);
end

function waic = calc_waic(lppd, pWAIC)
waic = -2 * lppd + 2 * pWAIC;
end
5 changes: 2 additions & 3 deletions ddToolbox/models/Model.m
@@ -268,9 +268,8 @@ function export(obj)
[chains, samples_per_chain, N] = size(samples.log_lik);
log_lik = reshape(samples.log_lik, chains*samples_per_chain, N)';
% second, create WAIC object
obj.WAIC_stats = WAIC(log_lik);

obj.WAIC_stats.modelName = class(obj);
model_name = class(obj);
obj.WAIC_stats = WAIC(log_lik, model_name);
end

function auc = calcAreaUnderCurveForAll(obj, MAX_DELAY)
68 changes: 68 additions & 0 deletions demo/demo_pooling_WAIC.m
@@ -0,0 +1,68 @@
% demo_pooling_WAIC
% This example examines how level of pooling (hierarchical inference)
% affects model complexity. We will fit the Kirby dataset with various
% models and examine what we get from the WAIC model comparison.

%% setup
toolbox_path = '~/git-local/delay-discounting-analysis/ddToolbox';
addpath(toolbox_path)
datapath = '~/git-local/delay-discounting-analysis/demo/datasets/kirby';

addpath(toolbox_path)
ddAnalysisSetUp();

% Running this multiple times will result in slightly different model
% comparison results. For real model comparison contexts, e.g. for research
% papers, you'd want to ensure your posteriors are good approximations. The
% best way to start is by increasing the number of MCMC samples.
mcmc_params = struct('nsamples', 10000,...
'nchains', 4,...
'nburnin', 2000);

%% Set up the data to be analysed
myData = Data(datapath, 'files', allFilesInFolder(datapath, 'txt'));

%% Fit multiple models to the dataset

modelA = ModelSeparateLogK(...
myData,...
'savePath', fullfile(pwd,'output','modelA'),...
'mcmcParams', mcmc_params);

modelB = ModelMixedLogK(...
myData,...
'savePath', fullfile(pwd,'output','modelB'),...
'mcmcParams', mcmc_params);

modelC = ModelHierarchicalLogK(...
myData,...
'savePath', fullfile(pwd,'output','modelC'),...
'mcmcParams', mcmc_params);

modelD = ModelHierarchicalMEUpdated(...
myData,...
'savePath', fullfile(pwd,'output','modelD'),...
'mcmcParams', mcmc_params);


%% Examine WAIC stats for models
waic = [modelA.WAIC_stats,...
modelB.WAIC_stats,...
modelC.WAIC_stats,...
modelD.WAIC_stats];

waic_comparison_table = waic.compare()

waic.plot()

figure(2), clf

%% Examine WAIC stats for the logK models only
waic = [modelA.WAIC_stats,...
modelB.WAIC_stats,...
modelC.WAIC_stats];

waic_comparison_table = waic.compare()

waic.plot()
Binary file added docs/discussion/model_complexity.png
104 changes: 104 additions & 0 deletions docs/discussion/pooling_and_complexity.md
@@ -0,0 +1,104 @@
# Counter-intuitive aspects of model complexity

The initial assumption would be that the 'separate' models, which treat each participant independently, are the simplest, and that the 'full hierarchical' models, which model participant-level parameters as coming from a group-level distribution, are the most complex. In some ways this is right: we are adding hyperparameters and estimating more things, and the Bayesian network diagrams look more complex.

However, a more accurate view of model complexity focuses not on the number of parameters, or on the number of levels in the hierarchy. Instead, we can equate model complexity with the breadth of data space that is consistent with the model (see figure below, from Pitt _et al_, 2002).

![](model_complexity.png)


Under this approach, complex models are ones which make a broad range of predictions in data space, while simple models are ones which make more specific predictions in data space. This focus on data space does not rest on counting parameters. You can imagine situations where more model parameters do mean a more complex model, in that it has a greater ability to fit more data.

However, you can also imagine a situation where adding parameters might actually _decrease_ the model complexity. An example of this, relevant to our particular context, would be the addition of a hyper-parameter, i.e. adding a group-level distribution of discount rates. Despite the addition of another 'level' of inference and more parameters to estimate, what we have done is to actually _decrease_ the range of plausible discount rates at the participant level. We can think of this not as adding parameters, but as adding prior information, which could actually decrease the breadth of plausible observations in data space under the model.
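To make this more concrete, here is a small prior-predictive sketch. It is purely illustrative and not part of the toolbox: the amounts, delay, priors, and choice rule below are all made up for the example. The point is that a vague prior over log(k) makes a much broader region of data space plausible than a tighter, group-informed prior does.

```matlab
% Hypothetical illustration: breadth of data space under a vague vs an
% informed prior over log(k). Not toolbox code; all values are made up.
A = 50; B = 100; D = 30;    % immediate amount, delayed amount, delay (days)
n = 5000;

logk_vague    = -3 + 3.0 * randn(n, 1);   % broad prior over log(k)
logk_informed = -3 + 0.5 * randn(n, 1);   % tighter, group-informed prior

% hyperbolic discounting: present subjective value of the delayed reward
V = @(logk) B ./ (1 + exp(logk) .* D);

% logistic choice rule (the scale of 10 is an arbitrary choice sensitivity)
p_delayed = @(logk) 1 ./ (1 + exp(-(V(logk) - A) ./ 10));

figure, hold on
histogram(p_delayed(logk_vague))
histogram(p_delayed(logk_informed))
legend('vague prior', 'informed prior')
xlabel('P(choose delayed reward)')
% The vague prior spreads its prior predictions over much more of data
% space, i.e. it corresponds to a more complex model in the sense above.
```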

## An example

For this example (see `demo_pooling_WAIC.m`) we will fit the example dataset collected with the Kirby (2009) method. We will conduct parameter estimation for the following models:
- ModelSeparateLogK (simplest in the number of parameters sense)
- ModelMixedLogK
- ModelHierarchicalLogK
- ModelHierarchicalMEUpdated (most complex in the number of parameters sense)

First, some setup:
```matlab
toolbox_path = '~/git-local/delay-discounting-analysis/ddToolbox';
addpath(toolbox_path)
datapath = '~/git-local/delay-discounting-analysis/demo/datasets/kirby';
addpath(toolbox_path)
ddAnalysisSetUp();
mcmc_params = struct('nsamples', 10000,...
'nchains', 4,...
'nburnin', 2000);
myData = Data(datapath, 'files', allFilesInFolder(datapath, 'txt'));
```

Then we conduct parameter estimation with the models

```matlab
modelA = ModelSeparateLogK(...
myData,...
'savePath', fullfile(pwd,'pooling_demo','modelA'),...
'mcmcParams', mcmc_params);
modelB = ModelMixedLogK(...
myData,...
'savePath', fullfile(pwd,'pooling_demo','modelB'),...
'mcmcParams', mcmc_params);
modelC = ModelHierarchicalLogK(...
myData,...
'savePath', fullfile(pwd,'pooling_demo','modelC'),...
'mcmcParams', mcmc_params);
modelD = ModelHierarchicalMEUpdated(...
myData,...
'savePath', fullfile(pwd,'pooling_demo','modelD'),...
'mcmcParams', mcmc_params);
```

Then we can do some model comparison with WAIC.

```matlab
waic = [modelA.WAIC_stats,...
modelB.WAIC_stats,...
modelC.WAIC_stats,...
modelD.WAIC_stats];
waic_comparison_table = waic.compare()
```

![](pooling_waic_table.png)

and get a corresponding plot

```matlab
waic.plot()
```

![](pooling_waic_fig.png)

So what is going on here? We might initially think that the `ModelHierarchicalMEUpdated` model is the most complex, because it has both slope and intercept parameters for the magnitude effect _and_ group-level estimates of the mean and variance of both of these parameters. Yet it comes out as the best model. This could be because the data show a strong magnitude effect, and the vanilla hyperbolic models (which don't capture the magnitude effect) are simply less capable of fitting the data. We can see that the in-sample deviance is lowest for this model, so it does seem to account for the observed data better than the other models. However, the model also wins by WAIC, so we can probably set aside any concerns about overfitting here.

Looking at the WAIC values for the hyperbolic (`*LogK`) models, we see that the hierarchical model is the next best, followed by the partial pooling model, with the separate model worst. This runs against the naive prediction that model complexity increases with the number of parameters.

Let's eliminate the magnitude effect model for the moment to take a closer look.

```matlab
waic = [modelA.WAIC_stats,...
modelB.WAIC_stats,...
modelC.WAIC_stats];
waic.plot()
```

![](pooling_waic_fig2.png)

Now we have a better look at the WAIC differences (gray triangles) relative to the best model under consideration (`ModelHierarchicalLogK`). The in-sample deviance of the partial pooling model (`ModelMixedLogK`) looks lower, but we are less interested in that. In terms of WAIC it is worse, although the standard error of its WAIC difference overlaps with the WAIC of the winning model. However, the standard error of the WAIC difference for the `ModelSeparateLogK` model does not overlap with the winning model, so we have some grounds to suspect that this model is genuinely worse.

This is pretty interesting. What we have seen here is that the _apparent_ increase in model complexity from adding (group-level) parameters actually makes the model simpler. The group-level hyperparameters act as shrinkage priors, making a _smaller_ region of data space plausible under the model (and priors). Most interesting of all, adding even more parameters and encoding our knowledge of the magnitude effect made the model even better, as measured by WAIC.
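To see the shrinkage mechanism numerically, here is a deliberately simplified sketch. It is not how the toolbox models are actually fitted (they use MCMC over the full hierarchical model); it just assumes a conjugate normal-normal setup with made-up numbers to show the precision-weighted pull towards the group mean.

```matlab
% Hypothetical normal-normal shrinkage sketch (not toolbox code).
% Assume each participant's log(k) point estimate has known sampling
% variance sigma2, and participant-level log(k) values come from a
% group-level Normal(mu, tau2). The posterior mean for each participant
% is then a precision-weighted average of their estimate and the group mean.
logk_hat = [-5.1, -3.2, -6.8, -2.0];   % made-up per-participant estimates
sigma2   = 0.8;                        % assumed sampling variance
mu       = mean(logk_hat);             % plug-in group-level mean
tau2     = var(logk_hat);              % plug-in group-level variance

w = (1/sigma2) / (1/sigma2 + 1/tau2);  % weight given to each participant's own data
logk_shrunk = w .* logk_hat + (1 - w) .* mu;

disp([logk_hat; logk_shrunk])          % shrunken values sit closer to mu
```

The less informative each participant's own data (larger `sigma2`), or the tighter the group-level distribution (smaller `tau2`), the stronger the pull towards the group mean, and the narrower the region of data space the model treats as plausible.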

This is not to say that adding more parameters and group-level inferences will _always_ make models simpler. It depends on what the priors and hyperpriors are, but in general we should remember that adding more knowledge can actually _reduce_ model complexity, and make the model better.


## References

Pitt, M. A., Myung, I. J., & Zhang, S. (2002). Toward a method of selecting among computational models of cognition. Psychological Review, 109(3), 472–491. http://doi.org/10.1037//0033-295X.109.3.472
Binary file added docs/discussion/pooling_waic_fig.png
Binary file added docs/discussion/pooling_waic_fig2.png
Binary file added docs/discussion/pooling_waic_table.png
1 change: 1 addition & 0 deletions docs/index.md
@@ -31,4 +31,5 @@ Vincent, B. T. (2016) **[Hierarchical Bayesian estimation and hypothesis testing**
- [Psychometric link function](discussion/psychometric_link_function.md)
- [Hyperpriors / parameter pooling](discussion/hyperpriors.md)
- [Level of parameter pooling](discussion/level_of_pooling.md)
- [Counterintuitive aspects of model complexity](discussion/pooling_and_complexity.md)
- [Hypothesis testing](discussion/hypothesis_testing.md)
12 changes: 10 additions & 2 deletions docs/tutorial/model_comparison.md
@@ -31,7 +31,7 @@ Then create a `Data` object that we'll analyse with multiple models.
myData = Data(datapath, 'files', allFilesInFolder(datapath, 'txt'));
```

And conduct parameter estimation on multiple models. In a real context you'd perhaps want to use mode descriptive names for the the models and output folders, but this will be enough to get the right idea. This will take a little while to compute.
And conduct parameter estimation on multiple models. In a real context you'd perhaps want to use more descriptive names for the models and output folders, but this will be enough to get the right idea. This will take a little while to compute.

```matlab
modelA = ModelHierarchicalLogK(...
@@ -103,7 +103,15 @@ waic.plot()

![](WAICplot.png)

Again, readers are referred to McElreath (2016) for a thorough explanation of this plot, the style of which is directly copied from that book.
I will not attempt a thorough description of how to interpret this WAIC comparison; others have already done that job well. Readers are referred to McElreath (2016) for an explanation of this plot, the style of which is directly copied from that book.

In case you don't have a copy of McElreath's book, here is a video lecture he gave on model comparison to get you started.

<iframe width="560" height="315" src="https://www.youtube.com/embed/vSjL2Zc-gEQ" frameborder="0" allow="autoplay; encrypted-media" allowfullscreen></iframe>

Briefly though: in this example the winning model is the hierarchical hyperboloid model, as it has the lowest WAIC. The models are ordered from best (top) to worst (bottom). The black open circles (+/- SE) show each model's WAIC, while the gray triangles (+/- SE) show the WAIC difference between that model and the winning model. So, as a first-pass analysis, one can ask whether the non-winning models are credibly worse than the winning model by seeing how far these WAIC differences (and their standard errors) lie from the winning model's WAIC.
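For a rough first-pass numeric check, you can compare each WAIC difference to its standard error. This is only a sketch: it assumes the comparison table returned by `compare()` above is stored in a variable called `waic_comparison_table`, with the winning model in the first row, and it is a heuristic rather than a formal test.

```matlab
% How many standard errors of the WAIC difference separate each
% non-winning model from the winner? (Recall WAIC = -2*lppd + 2*pWAIC,
% so lower is better.) Ratios comfortably above ~2 suggest a model is
% credibly worse than the winner.
rough_ratio = waic_comparison_table.dWAIC(2:end) ./ waic_comparison_table.dSE(2:end)
```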



## Reference

