
support group-level model comparison (ilabcode#217)
* ensure that distribution parameters are vectors prior to broadcasting

* update tutorial

* api

* github action
LegrandNico committed Aug 9, 2024
1 parent 8896c5b commit a86f38d
Showing 5 changed files with 361 additions and 162 deletions.
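
Note: the headline change is HGFPointwise, a companion to HGFDistribution that returns trial-wise log-likelihoods, so group-level HGF models fitted with PyMC can be compared with LOO or WAIC. A minimal sketch of the workflow this commit enables, assuming the binary two-level setup from the updated tutorial; the data, priors, and dimensions below are hypothetical:

import numpy as np
import pymc as pm
from pyhgf.distribution import HGFDistribution
from pyhgf.response import first_level_binary_surprise

# hypothetical group data: one row of binary observations per participant
input_data = np.random.binomial(n=1, p=0.7, size=(5, 200)).astype(float)

hgf_logp_op = HGFDistribution(
    n_levels=2,
    model_type="binary",
    input_data=input_data,
    response_function=first_level_binary_surprise,
)

with pm.Model():
    # one tonic volatility per participant, shrunk towards a group mean
    mu_volatility = pm.Normal("mu_volatility", -3.0, 2.0)
    tonic_volatility_2 = pm.Normal("tonic_volatility_2", mu_volatility, 1.0, shape=5)
    pm.Potential("hgf_loglike", hgf_logp_op(tonic_volatility_2=tonic_volatility_2))
    idata = pm.sample()

The vector-valued tonic_volatility_2 works because, as the distribution.py diff below shows, parameters are now promoted to vectors and broadcast to one entry per data row.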
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
@@ -4,8 +4,8 @@ on:
  release:
    types: [published]
  pull_request:
-   types:
-     - opened
+   types: [opened, synchronize, reopened]

permissions:
  contents: write
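
In effect, the documentation build now also triggers when a pull request receives new commits (synchronize) or is reopened, instead of only when it is first opened.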

1 change: 1 addition & 0 deletions docs/source/api.rst
@@ -201,6 +201,7 @@ embedded in models using PyMC>=5.0.0.
hgf_logp
HGFLogpGradOp
HGFDistribution
+ HGFPointwise

Model
*****
432 changes: 306 additions & 126 deletions docs/source/notebooks/3-Multilevel_HGF.ipynb

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions docs/source/refs.bib
@@ -211,3 +211,14 @@ @InProceedings{mathys:2020
abstract="Active inference relies on state-space models to describe the environments that agents sample with their actions. These actions lead to state changes intended to minimize future surprise. We show that surprise minimization relying on Bayesian inference can be achieved by filtering of the sufficient statistic time series of exponential family input distributions, and we propose the hierarchical Gaussian filter (HGF) as an appropriate, efficient, and scalable tool for active inference agents to achieve this.",
isbn="978-3-030-64919-7"
}

@article{Vehtari:2015,
doi = {10.48550/ARXIV.1507.04544},
url = {https://arxiv.org/abs/1507.04544},
author = {Vehtari, Aki and Gelman, Andrew and Gabry, Jonah},
keywords = {Computation (stat.CO), Methodology (stat.ME), FOS: Computer and information sciences},
title = {Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC},
publisher = {arXiv},
year = {2015},
copyright = {arXiv.org perpetual, non-exclusive license}
}
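
This reference supports the new comparison step: once HGFPointwise has filled the log_likelihood group of the inference data, ArviZ can rank models by PSIS-LOO (Vehtari et al., 2015). A hedged sketch, where idata_two_level and idata_three_level are assumed to be InferenceData objects already carrying pointwise log-likelihoods:

import arviz as az

# leave-one-out comparison of two hypothetical group-level fits
comparison = az.compare(
    {"two-level": idata_two_level, "three-level": idata_three_level},
    ic="loo",
)
print(comparison)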
75 changes: 41 additions & 34 deletions src/pyhgf/distribution.py
@@ -168,7 +168,7 @@ def hgf_logp(
    volatility_coupling_1: ArrayLike = 1.0,
    volatility_coupling_2: ArrayLike = 1.0,
    input_precision: ArrayLike = np.inf,
-   response_function_parameters: ArrayLike = 1.0,
+   response_function_parameters: ArrayLike = np.ones(1),
    vectorized_logp: Callable = logp,
    input_data: ArrayLike = np.nan,
    response_function_inputs: ArrayLike = np.nan,
@@ -260,6 +260,11 @@
    # number of models
    n = input_data.shape[0]

+   # ensure that the response parameters have the number of models as their first dimension
+   response_function_parameters = jnp.broadcast_to(
+       response_function_parameters, (n,) + response_function_parameters.shape[1:]
+   )
+
    # Broadcast inputs to an array with length n>=1
    (
        _tonic_volatility_1,
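
The added broadcast_to call pins the leading axis of the response parameters to the number of models n, so one parameter set can be shared across participants or supplied per participant. A standalone illustration with hypothetical shapes:

import jax.numpy as jnp

n = 3                      # three participants / models
params = jnp.ones((1, 2))  # a single pair of response parameters

# the same expression as in the diff above
params = jnp.broadcast_to(params, (n,) + params.shape[1:])
print(params.shape)        # (3, 2): one copy of the pair per model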
@@ -409,26 +414,26 @@ def make_node(
        volatility_coupling_1: ArrayLike = np.array(1.0),
        volatility_coupling_2: ArrayLike = np.array(1.0),
        input_precision: ArrayLike = np.inf,
-       response_function_parameters: ArrayLike = np.array([1.0]),
+       response_function_parameters: ArrayLike = np.array(1.0),
    ):
        """Initialize node structure."""
        # Convert our inputs to symbolic variables
        inputs = [
-           pt.as_tensor_variable(mean_1),
-           pt.as_tensor_variable(mean_2),
-           pt.as_tensor_variable(mean_3),
-           pt.as_tensor_variable(precision_1),
-           pt.as_tensor_variable(precision_2),
-           pt.as_tensor_variable(precision_3),
-           pt.as_tensor_variable(tonic_volatility_1),
-           pt.as_tensor_variable(tonic_volatility_2),
-           pt.as_tensor_variable(tonic_volatility_3),
-           pt.as_tensor_variable(tonic_drift_1),
-           pt.as_tensor_variable(tonic_drift_2),
-           pt.as_tensor_variable(tonic_drift_3),
-           pt.as_tensor_variable(volatility_coupling_1),
-           pt.as_tensor_variable(volatility_coupling_2),
-           pt.as_tensor_variable(input_precision),
+           pt.as_tensor_variable(mean_1, ndim=1),
+           pt.as_tensor_variable(mean_2, ndim=1),
+           pt.as_tensor_variable(mean_3, ndim=1),
+           pt.as_tensor_variable(precision_1, ndim=1),
+           pt.as_tensor_variable(precision_2, ndim=1),
+           pt.as_tensor_variable(precision_3, ndim=1),
+           pt.as_tensor_variable(tonic_volatility_1, ndim=1),
+           pt.as_tensor_variable(tonic_volatility_2, ndim=1),
+           pt.as_tensor_variable(tonic_volatility_3, ndim=1),
+           pt.as_tensor_variable(tonic_drift_1, ndim=1),
+           pt.as_tensor_variable(tonic_drift_2, ndim=1),
+           pt.as_tensor_variable(tonic_drift_3, ndim=1),
+           pt.as_tensor_variable(volatility_coupling_1, ndim=1),
+           pt.as_tensor_variable(volatility_coupling_2, ndim=1),
+           pt.as_tensor_variable(input_precision, ndim=1),
            pt.as_tensor_variable(response_function_parameters),
        ]
        # This `Op` will return one gradient per input. For simplicity, we assume
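
Passing ndim=1 to pt.as_tensor_variable promotes scalar defaults to rank-1 tensors, so every parameter reaching the Op is consistently vector-shaped before broadcasting. A quick check of that behaviour (my reading of the PyTensor API; worth verifying against the installed version):

import numpy as np
import pytensor.tensor as pt

x = pt.as_tensor_variable(np.array(1.0), ndim=1)
print(x.ndim)    # 1: the 0-d input was promoted to a length-1 vector
print(x.eval())  # [1.]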
@@ -602,6 +607,8 @@ def __init__(
        if time_steps is None:
            time_steps = np.ones(shape=input_data.shape)

+       assert time_steps.shape == input_data.shape
+
        # create the default HGF template to be used by the logp function
        self.hgf = HGF(n_levels=n_levels, model_type=model_type)
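
The new assertion makes the contract explicit: time_steps must supply one duration per observation, element for element. With hypothetical data:

import numpy as np

input_data = np.random.randn(2, 100)          # 2 models, 100 trials each
time_steps = np.ones(shape=input_data.shape)  # the default: unit time steps
assert time_steps.shape == input_data.shape   # passes; a mismatch now fails fast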

@@ -652,25 +659,25 @@ def make_node(
        volatility_coupling_1: ArrayLike = np.array(1.0),
        volatility_coupling_2: ArrayLike = np.array(1.0),
        input_precision: ArrayLike = np.inf,
-       response_function_parameters: ArrayLike = np.array([1.0]),
+       response_function_parameters: ArrayLike = np.array(1.0),
    ):
        """Convert inputs to symbolic variables."""
        inputs = [
-           pt.as_tensor_variable(mean_1),
-           pt.as_tensor_variable(mean_2),
-           pt.as_tensor_variable(mean_3),
-           pt.as_tensor_variable(precision_1),
-           pt.as_tensor_variable(precision_2),
-           pt.as_tensor_variable(precision_3),
-           pt.as_tensor_variable(tonic_volatility_1),
-           pt.as_tensor_variable(tonic_volatility_2),
-           pt.as_tensor_variable(tonic_volatility_3),
-           pt.as_tensor_variable(tonic_drift_1),
-           pt.as_tensor_variable(tonic_drift_2),
-           pt.as_tensor_variable(tonic_drift_3),
-           pt.as_tensor_variable(volatility_coupling_1),
-           pt.as_tensor_variable(volatility_coupling_2),
-           pt.as_tensor_variable(input_precision),
+           pt.as_tensor_variable(mean_1, ndim=1),
+           pt.as_tensor_variable(mean_2, ndim=1),
+           pt.as_tensor_variable(mean_3, ndim=1),
+           pt.as_tensor_variable(precision_1, ndim=1),
+           pt.as_tensor_variable(precision_2, ndim=1),
+           pt.as_tensor_variable(precision_3, ndim=1),
+           pt.as_tensor_variable(tonic_volatility_1, ndim=1),
+           pt.as_tensor_variable(tonic_volatility_2, ndim=1),
+           pt.as_tensor_variable(tonic_volatility_3, ndim=1),
+           pt.as_tensor_variable(tonic_drift_1, ndim=1),
+           pt.as_tensor_variable(tonic_drift_2, ndim=1),
+           pt.as_tensor_variable(tonic_drift_3, ndim=1),
+           pt.as_tensor_variable(volatility_coupling_1, ndim=1),
+           pt.as_tensor_variable(volatility_coupling_2, ndim=1),
+           pt.as_tensor_variable(input_precision, ndim=1),
            pt.as_tensor_variable(response_function_parameters),
        ]
        # Define the type of output returned by the wrapped JAX function
@@ -818,7 +825,7 @@ def make_node(
        volatility_coupling_1: ArrayLike = np.array(1.0),
        volatility_coupling_2: ArrayLike = np.array(1.0),
        input_precision: ArrayLike = np.inf,
-       response_function_parameters: ArrayLike = np.array([1.0]),
+       response_function_parameters: ArrayLike = np.array(1.0),
    ):
        """Convert inputs to symbolic variables."""
        inputs = [
