add generalised Bayesian filter node (#191)
LegrandNico authored Jun 10, 2024
1 parent 646a0c3 commit 212d29f
Showing 17 changed files with 513 additions and 216 deletions.
20 changes: 20 additions & 0 deletions docs/source/api.rst
@@ -57,6 +57,16 @@ Continuous nodes
continuous_node_update
continuous_node_update_ehgf

Exponential family
------------------

.. currentmodule:: pyhgf.updates.posterior.exponential

.. autosummary::
:toctree: generated/pyhgf.updates.posterior.exponential

posterior_update_exponential_family

Prediction steps
================

@@ -115,6 +125,16 @@ Continuous inputs
continuous_input_value_prediction_error
continuous_input_prediction_error

Generic input
^^^^^^^^^^^^^

.. currentmodule:: pyhgf.updates.prediction_error.inputs.generic

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.inputs.generic

generic_input_prediction_error

State nodes
-----------

36 changes: 18 additions & 18 deletions docs/source/notebooks/0.2-Creating_networks.ipynb

Large diffs are not rendered by default.

240 changes: 142 additions & 98 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb

Large diffs are not rendered by default.

52 changes: 43 additions & 9 deletions src/pyhgf/model/network.py
@@ -288,7 +288,7 @@ def switching_propagation(attributes, scan_input):
# because some of the input nodes might not have been updated, here we manually
# insert the input data to the input node (without triggering updates)
for idx, inp in zip(self.inputs.idx, range(input_data.shape[1])):
self.node_trajectories[idx]["value"] = input_data[inp]
self.node_trajectories[idx]["values"] = input_data[inp]

return self

@@ -340,10 +340,15 @@ def add_nodes(
kind :
The kind of node to create. If `"continuous-state"` (default), the node will
be a regular state node that can have value and/or volatility
parents/children. The hierarchical dependencies are specified using the
corresponding parameters below. If `"binary-state"`, the node should be the
value parent of a binary input. To create an input node, three types of
inputs are supported:
parents/children. If `"binary-state"`, the node should be the
value parent of a binary input. State nodes that filter a distribution from the
exponential family can be created using the `"ef-"` prefix (e.g.
`"ef-normal"` for a univariate normal distribution). Note that only a few
distributions are implemented at the moment.
In addition to state nodes, four types of input nodes are supported:
- `generic-input`: receives a value or an array and passes it to the parent
nodes.
- `continuous-input`: receives a continuous observation as input.
- `binary-input`: receives a single boolean as observation. The parameters
provided to the binary input node contain: 1. `binary_precision`, the binary
@@ -460,14 +465,20 @@ def add_nodes(
"expected_precision_children": 0.0,
},
}
elif kind == "generic-input":
default_parameters = {
"values": 0.0,
"time_step": 0.0,
"observed": 0,
}
elif kind == "continuous-input":
default_parameters = {
"volatility_coupling_parents": None,
"value_coupling_parents": None,
"input_precision": 1e4,
"expected_precision": 1e4,
"time_step": 0.0,
"value": 0.0,
"values": 0.0,
"surprise": 0.0,
"observed": 0,
"temp": {
@@ -482,7 +493,7 @@ def add_nodes(
"eta0": 0.0,
"eta1": 1.0,
"time_step": 0.0,
"value": 0.0,
"values": 0.0,
"observed": 0,
"surprise": 0.0,
}
@@ -520,9 +531,15 @@ def add_nodes(
"pe": jnp.zeros(n_categories),
"xi": jnp.array([1.0 / n_categories] * n_categories),
"mean": jnp.array([1.0 / n_categories] * n_categories),
"value": jnp.zeros(n_categories),
"values": jnp.zeros(n_categories),
"binary_parameters": binary_parameters,
}
elif "ef-normal" in kind:
default_parameters = {
"nus": 0.0,
"xis": jnp.array([0.0, 0.0]),
"values": 0.0,
}

if bool(additional_parameters):
# ensure that all passed values are valid keys
@@ -550,10 +567,21 @@ def add_nodes(
node_parameters = default_parameters

if "input" in kind:
# "continuous": 0, "binary": 1, "categorical": 2, "generic": 3
input_type = input_types[kind.split("-")[0]]
else:
input_type = None

# define the type of node that is created
if "input" in kind:
node_type = 0
elif "binary-state" in kind:
node_type = 1
elif "continuous-state" in kind:
node_type = 2
elif "ef-normal" in kind:
node_type = 3

# convert the structure to a list to modify it
edges_as_list: List[AdjacencyLists] = list(self.edges)

@@ -563,7 +591,11 @@ def add_nodes(
# add a new edge
edges_as_list.append(
AdjacencyLists(
couplings[1][0], couplings[3][0], couplings[0][0], couplings[2][0]
node_type,
couplings[1][0],
couplings[3][0],
couplings[0][0],
couplings[2][0],
)
)

@@ -601,6 +633,7 @@ def add_nodes(
):
# unpack this node's edges
(
this_node_type,
value_parents,
volatility_parents,
value_children,
@@ -655,6 +688,7 @@ def add_nodes(

# save the updated edges back
edges_as_list[idx] = AdjacencyLists(
this_node_type,
value_parents,
volatility_parents,
value_children,
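The new `kind` strings documented above extend `add_nodes` with a generic input node and exponential family state nodes. Below is a minimal sketch of how such a network might be assembled; it assumes that `Network` is exposed under `pyhgf.model`, that `add_nodes` can be chained and accepts a `value_children` argument, and that `input_data` takes an array of observations (none of these call details are shown in full in this excerpt).

    from pyhgf.model import Network
    import numpy as np

    # node 0: generic input passing raw values to its parent;
    # node 1: exponential family normal node filtering the observations
    network = (
        Network()
        .add_nodes(kind="generic-input")
        .add_nodes(kind="ef-normal", value_children=0)
    )
    network.input_data(input_data=np.array([0.5, 0.7, 0.2]))  # hypothetical data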
2 changes: 2 additions & 0 deletions src/pyhgf/plots.py
@@ -264,6 +264,8 @@ def plot_network(network: "Network") -> "Source":
label, shape = f"Bi-{idx}", "box"
elif kind == 2:
label, shape = f"Ca-{idx}", "diamond"
elif kind == 3:
label, shape = f"Ge-{idx}", "point"
graphviz_structure.node(
f"x_{idx}",
label=label,
6 changes: 3 additions & 3 deletions src/pyhgf/response.py
@@ -40,7 +40,7 @@ def first_level_gaussian_surprise(
# the input value at time t is compared to the Gaussian prediction at t-1
surprise = jnp.sum(
gaussian_surprise(
x=hgf.node_trajectories[0]["value"],
x=hgf.node_trajectories[0]["values"],
expected_mean=hgf.node_trajectories[1]["expected_mean"],
expected_precision=hgf.node_trajectories[1]["expected_precision"],
)
@@ -85,7 +85,7 @@ def total_gaussian_surprise(
input_parents_list.append(va_pa)
surprise += jnp.sum(
gaussian_surprise(
x=hgf.node_trajectories[idx]["value"],
x=hgf.node_trajectories[idx]["values"],
expected_mean=hgf.node_trajectories[va_pa]["expected_mean"],
expected_precision=hgf.node_trajectories[va_pa]["expected_precision"],
)
@@ -136,7 +136,7 @@ def first_level_binary_surprise(
"""
surprise = binary_surprise(
expected_mean=hgf.node_trajectories[1]["expected_mean"],
x=hgf.node_trajectories[0]["value"],
x=hgf.node_trajectories[0]["values"],
)

# Return an infinite surprise if the model cannot fit
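For reference, the Gaussian surprise summed in these response functions is the negative log-density of the observed value under the prediction. A standalone sketch of the same quantity, written out explicitly rather than calling `pyhgf.math.gaussian_surprise`:

    import jax.numpy as jnp

    def gaussian_surprise_sketch(x, expected_mean, expected_precision):
        # -log N(x; expected_mean, 1/expected_precision)
        return 0.5 * (
            jnp.log(2 * jnp.pi)
            - jnp.log(expected_precision)
            + expected_precision * (x - expected_mean) ** 2
        )

    # e.g. an observation of 0.5 against a prediction of 0.0 with unit precision
    surprise = gaussian_surprise_sketch(x=0.5, expected_mean=0.0, expected_precision=1.0)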
14 changes: 12 additions & 2 deletions src/pyhgf/typing.py
@@ -4,8 +4,18 @@


class AdjacencyLists(NamedTuple):
"""Indexes to a node's value and volatility parents."""
"""Indexes to a node's value and volatility parents.
The variable `node_type` encodes the type of state node:
* 0: input node.
* 1: binary state node.
* 2: continuous state node.
* 3: exponential family state node - univariate Gaussian distribution with unknown
mean and unknown variance.
"""

node_type: int
value_parents: Optional[Tuple]
volatility_parents: Optional[Tuple]
value_children: Optional[Tuple]
@@ -35,4 +45,4 @@ class Inputs(NamedTuple):
NetworkParameters = Tuple[Attributes, Structure, UpdateSequence]

# encoding input types using integers
input_types = {"continuous": 0, "binary": 1, "categorical": 2}
input_types = {"continuous": 0, "binary": 1, "categorical": 2, "generic": 3}
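With `node_type` added as the first field, an edge entry can be constructed as below. This is a sketch: the trailing `volatility_children` field is not visible in this excerpt and is assumed from the `AdjacencyLists` constructor call in `network.py` above.

    from pyhgf.typing import AdjacencyLists

    # a continuous state node (node_type=2) with a value child at index 0,
    # a value parent at index 2 and a volatility parent at index 3
    edge = AdjacencyLists(
        node_type=2,
        value_parents=(2,),
        volatility_parents=(3,),
        value_children=(0,),
        volatility_children=None,
    )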
2 changes: 1 addition & 1 deletion src/pyhgf/updates/posterior/binary.py
@@ -42,7 +42,7 @@ def binary_node_update_infinite(
"""
value_child_idx = edges[node_idx].value_children[0] # type: ignore

attributes[node_idx]["mean"] = attributes[value_child_idx]["value"]
attributes[node_idx]["mean"] = attributes[value_child_idx]["values"]
attributes[node_idx]["precision"] = attributes[value_child_idx][
"expected_precision"
]
8 changes: 4 additions & 4 deletions src/pyhgf/updates/posterior/categorical.py
@@ -70,15 +70,15 @@ def categorical_input_update(
alpha = jnp.where(jnp.isnan(alpha), 1.0, alpha)

# now retrieve the values observed at time k
attributes[node_idx]["value"] = jnp.array(
attributes[node_idx]["values"] = jnp.array(
[
attributes[vapa]["value"]
attributes[vapa]["values"]
for vapa in edges[node_idx].value_parents # type: ignore
]
)

# compute the prediction error at time K
pe = attributes[node_idx]["value"] - new_xi
pe = attributes[node_idx]["values"] - new_xi
attributes[node_idx]["pe"] = pe # keep PE for later use at k+1
attributes[node_idx]["xi"] = new_xi # keep expectation for later use at k+1

@@ -90,7 +90,7 @@
)
attributes[node_idx]["surprise"] = jnp.sum(
binary_surprise(
x=attributes[node_idx]["value"], expected_mean=attributes[node_idx]["xi"]
x=attributes[node_idx]["values"], expected_mean=attributes[node_idx]["xi"]
)
)

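The renamed `values` field gathers the one-hot observation from the value parents; the prediction error and the summed binary surprise computed at the end of `categorical_input_update` then amount to the following standalone sketch (binary surprise written out instead of calling `pyhgf.math.binary_surprise`, with made-up numbers):

    import jax.numpy as jnp

    xi = jnp.array([0.2, 0.5, 0.3])        # expected category probabilities at time k
    observed = jnp.array([0.0, 1.0, 0.0])  # one-hot observation stored under "values"

    pe = observed - xi                     # prediction error, reused at time k + 1
    surprise = jnp.sum(
        -(observed * jnp.log(xi) + (1 - observed) * jnp.log(1 - xi))
    )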
59 changes: 59 additions & 0 deletions src/pyhgf/updates/posterior/exponential.py
@@ -0,0 +1,59 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import Callable, Dict

from jax import jit

from pyhgf.typing import Attributes, Edges


@partial(jit, static_argnames=("edges", "node_idx", "sufficient_stats_fn"))
def posterior_update_exponential_family(
attributes: Dict, edges: Edges, node_idx: int, sufficient_stats_fn: Callable, **args
) -> Attributes:
r"""Update the parameters of an exponential family distribution.
Assuming that :math:`\nu` is fixed, updating the hyperparameters of the distribution
is given by:
.. math::
\xi \leftarrow \xi + \frac{1}{\nu + 1}(t(x)-\xi)
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the node
number. For each node, the index lists the value and volatility parents and
children.
node_idx :
Pointer to the value parent node that will be updated.
sufficient_stats_fn :
Compute the sufficient statistics of the probability distribution. This should
be one of the methods implemented in a distribution class, e.g.
:py:class:`pyhgf.math.Normal` for a univariate normal.
Returns
-------
attributes :
The updated attributes of the probabilistic nodes.
References
----------
.. [1] Mathys, C., & Weber, L. (2020). Hierarchical Gaussian Filtering of Sufficient
Statistic Time Series for Active Inference. In Active Inference (pp. 52–58).
Springer International Publishing. https://doi.org/10.1007/978-3-030-64919-7_7
"""
# update the hyperparameter vectors
attributes[node_idx]["xis"] = attributes[node_idx]["xis"] + (
1 / (1 + attributes[node_idx]["nus"])
) * (
sufficient_stats_fn(attributes[node_idx]["values"])
- attributes[node_idx]["xis"]
)

return attributes
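To make the update concrete, here is a self-contained numerical sketch. It assumes the univariate normal sufficient statistics t(x) = (x, x**2); in practice the statistics come from the distribution class passed as `sufficient_stats_fn` (e.g. `pyhgf.math.Normal`).

    import jax.numpy as jnp

    def normal_sufficient_stats(x):
        # sufficient statistics of a univariate normal: t(x) = (x, x**2)
        return jnp.array([x, jnp.square(x)])

    xis = jnp.array([0.0, 1.0])  # running estimates of E[x] and E[x**2]
    nus = 3.0                    # weight given to past observations (held fixed)
    x = 0.5                      # new observation

    # xi <- xi + 1 / (nu + 1) * (t(x) - xi), matching the update above
    xis = xis + (1.0 / (1.0 + nus)) * (normal_sufficient_stats(x) - xis)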
4 changes: 2 additions & 2 deletions src/pyhgf/updates/prediction_error/inputs/binary.py
@@ -52,7 +52,7 @@ def binary_input_prediction_error_infinite_precision(
"""
# store value and time step in the node's parameters
attributes[node_idx]["value"] = value
attributes[node_idx]["values"] = value
attributes[node_idx]["observed"] = observed
attributes[node_idx]["time_step"] = time_step

@@ -97,7 +97,7 @@ def binary_input_prediction_error_finite_precision(
"""
# store value and time step in the node's parameters
attributes[node_idx]["value"] = value
attributes[node_idx]["values"] = value
attributes[node_idx]["time_step"] = time_step

# Read parameters from the binary input
6 changes: 3 additions & 3 deletions src/pyhgf/updates/prediction_error/inputs/continuous.py
@@ -145,7 +145,7 @@ def continuous_input_value_prediction_error(

# the prediction error is computed using the expected mean from the value parent
value_prediction_error = (
attributes[node_idx]["value"] - attributes[value_parent_idx]["expected_mean"]
attributes[node_idx]["values"] - attributes[value_parent_idx]["expected_mean"]
)

# expected precision from the input node
@@ -172,7 +172,7 @@

# compute the Gaussian surprise for the input node
attributes[node_idx]["surprise"] = gaussian_surprise(
x=attributes[node_idx]["value"],
x=attributes[node_idx]["values"],
expected_mean=attributes[value_parent_idx]["expected_mean"],
expected_precision=expected_precision,
)
@@ -230,7 +230,7 @@ def continuous_input_prediction_error(
"""
# store value and time step in the node's parameters
attributes[node_idx]["value"] = value
attributes[node_idx]["values"] = value
attributes[node_idx]["observed"] = observed
attributes[node_idx]["time_step"] = time_step
