
Commit

test distribution
LegrandNico committed Sep 26, 2024
1 parent 407d299 commit fc6d117
Showing 5 changed files with 64 additions and 74 deletions.
8 changes: 4 additions & 4 deletions src/pyhgf/response.py
@@ -167,7 +167,7 @@ def binary_softmax(
     """
     # the expected values at the first level of the HGF
-    beliefs = hgf.node_trajectories[1]["expected_mean"]
+    beliefs = hgf.node_trajectories[0]["expected_mean"]

     # the binary surprises
     surprise = binary_surprise(x=response_function_inputs, expected_mean=beliefs)
@@ -213,10 +213,10 @@ def binary_softmax_inverse_temperature(
     """
     # the expected values at the first level of the HGF
     beliefs = (
-        hgf.node_trajectories[1]["expected_mean"] ** response_function_parameters
+        hgf.node_trajectories[0]["expected_mean"] ** response_function_parameters
     ) / (
-        hgf.node_trajectories[1]["expected_mean"] ** response_function_parameters
-        + (1 - hgf.node_trajectories[1]["expected_mean"])
+        hgf.node_trajectories[0]["expected_mean"] ** response_function_parameters
+        + (1 - hgf.node_trajectories[0]["expected_mean"])
         ** response_function_parameters
     )

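The two hunks above move the belief read-out from node_trajectories[1] to node_trajectories[0]. For reference, binary_softmax_inverse_temperature turns the expected mean mu of the binary state into a choice probability p = mu^beta / (mu^beta + (1 - mu)^beta), where beta is the inverse temperature. The following is a minimal standalone sketch of that mapping on made-up values, written with plain jax.numpy rather than the pyhgf objects themselves:

import jax.numpy as jnp

# hypothetical expected means of the binary state and observed binary responses
beliefs = jnp.array([0.2, 0.7, 0.9])
responses = jnp.array([0.0, 1.0, 1.0])
temperature = 2.0  # inverse temperature (response_function_parameters)

# p = mu^beta / (mu^beta + (1 - mu)^beta), as in binary_softmax_inverse_temperature
p = beliefs**temperature / (beliefs**temperature + (1 - beliefs) ** temperature)

# binary surprise of the observed responses under these choice probabilities
surprise = -(responses * jnp.log(p) + (1 - responses) * jnp.log(1 - p))

With beta = 1 the choice probability reduces to the expected mean itself, which is the plain binary_softmax read-out; both response functions therefore now index node_trajectories[0].
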
63 changes: 0 additions & 63 deletions src/pyhgf/updates/posterior/binary.py

This file was deleted.

57 changes: 57 additions & 0 deletions src/pyhgf/updates/prediction_error/binary.py
@@ -3,8 +3,11 @@
 from functools import partial
 from typing import Dict

+import jax.numpy as jnp
 from jax import Array, jit

+from pyhgf.typing import Edges
+

 @partial(jit, static_argnames=("node_idx"))
 def binary_state_node_prediction_error(
@@ -41,3 +44,57 @@ def binary_state_node_prediction_error(
     attributes[node_idx]["temp"]["value_prediction_error"] = value_prediction_error

     return attributes
+
+
+@partial(jit, static_argnames=("edges", "node_idx"))
+def binary_finite_state_node_prediction_error(
+    attributes: Dict, node_idx: int, edges: Edges, **args
+) -> Dict:
+    """Update the posterior of a binary node given finite precision of the input.
+
+    Parameters
+    ----------
+    attributes :
+        The attributes of the probabilistic nodes.
+    node_idx :
+        Pointer to the node that needs to be updated. After continuous updates, the
+        parameters of value and volatility parents (if any) will be different.
+    edges :
+        The edges of the probabilistic nodes as a tuple of
+        :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the node
+        number. For each node, the index lists the value and volatility parents and
+        children.
+
+    Returns
+    -------
+    attributes :
+        The updated attributes of the probabilistic nodes.
+
+    References
+    ----------
+    .. [1] Weber, L. A., Waade, P. T., Legrand, N., Møller, A. H., Stephan, K. E., &
+       Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1).
+       arXiv. https://doi.org/10.48550/ARXIV.2305.10937
+
+    """
+    value_child_idx = edges[node_idx].value_children[0]  # type: ignore
+
+    delta0 = attributes[value_child_idx]["temp"]["value_prediction_error_0"]
+    delta1 = attributes[value_child_idx]["temp"]["value_prediction_error_1"]
+    expected_precision = attributes[value_child_idx]["expected_precision"]
+
+    # likelihood under eta1
+    und1 = jnp.exp(-expected_precision / 2 * delta1**2)
+
+    # likelihood under eta0
+    und0 = jnp.exp(-expected_precision / 2 * delta0**2)
+
+    # Eq. 39 in Mathys et al. (2014) (i.e., Bayes' rule)
+    expected_mean = attributes[node_idx]["expected_mean"]
+    mean = expected_mean * und1 / (expected_mean * und1 + (1 - expected_mean) * und0)
+    precision = 1 / (expected_mean * (1 - expected_mean))
+
+    attributes[node_idx]["mean"] = mean
+    attributes[node_idx]["precision"] = precision
+
+    return attributes
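To make the new update concrete, the following is a small numerical sketch of the same Eq. 39 computation on hypothetical values; it uses plain jax.numpy only, and none of the numbers come from the library or the tests:

import jax.numpy as jnp

# hypothetical prediction and prediction errors for one finite-precision binary input
expected_mean = 0.6        # prediction of the binary state before seeing the input
expected_precision = 4.0   # expected precision of the input node
delta0 = 0.9               # prediction error relative to eta0
delta1 = 0.1               # prediction error relative to eta1

# unnormalised likelihoods of the input under eta1 and eta0
und1 = jnp.exp(-expected_precision / 2 * delta1**2)
und0 = jnp.exp(-expected_precision / 2 * delta0**2)

# Bayesian update of the binary state (Eq. 39 in Mathys et al., 2014)
mean = expected_mean * und1 / (expected_mean * und1 + (1 - expected_mean) * und0)
precision = 1 / (expected_mean * (1 - expected_mean))

print(mean, precision)  # the posterior mean moves toward 1 because delta1 << delta0
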
2 changes: 1 addition & 1 deletion tests/test_binary.py
@@ -46,7 +46,7 @@ def test_update_binary_input_parents():
     # apply sequence
     new_attributes, _ = beliefs_propagation(
         attributes=binary_hgf.attributes,
-        input_data=(data, time_steps, observed),
+        inputs=(data, time_steps, observed),
         update_sequence=binary_hgf.update_sequence,
         edges=binary_hgf.edges,
         input_idxs=(0,),
8 changes: 2 additions & 6 deletions tests/test_distribution.py
@@ -264,14 +264,10 @@ def test_hgf_logp():
         tonic_volatility_1=np.nan,
         tonic_volatility_2=-2.0,
         tonic_volatility_3=-6.0,
-        input_precision=np.inf,
         tonic_drift_1=0.0,
         tonic_drift_2=0.0,
         tonic_drift_3=0.0,
-        precision_1=np.nan,
-        precision_2=1.0,
-        precision_3=1.0,
-        mean_1=0.0,
+        mean_1=0.5,
         mean_2=0.0,
         mean_3=0.0,
         volatility_coupling_1=1.0,
@@ -562,7 +558,7 @@ def logp(value, tonic_volatility_2):

     with pm.Model() as model:
         y_data = pm.Data("y_data", y)
-        tonic_volatility_2 = pm.Normal("tonic_volatility_2", -11.0, 2)
+        tonic_volatility_2 = pm.Normal("tonic_volatility_2", -3.0, 2)
         pm.CustomDist("likelihood", tonic_volatility_2, logp=logp, observed=y_data)

     initial_point = model.initial_point()
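For context, this test relies on PyMC's CustomDist to expose the HGF log-probability to the sampler, with a prior on tonic_volatility_2. The snippet below sketches the same pattern with a toy Gaussian log-probability standing in for the HGF one; everything in it (the toy logp, the simulated y, the sampler settings) is illustrative and not part of the test:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

def logp(value, tonic_volatility_2):
    # toy stand-in: a Gaussian likelihood whose log-scale plays the role of
    # tonic_volatility_2; the real test plugs the HGF log-probability in here
    sigma = pt.exp(tonic_volatility_2)
    return pm.logp(pm.Normal.dist(0.0, sigma), value)

y = np.random.normal(0.0, 0.05, size=100)

with pm.Model() as model:
    y_data = pm.Data("y_data", y)
    tonic_volatility_2 = pm.Normal("tonic_volatility_2", -3.0, 2)
    pm.CustomDist("likelihood", tonic_volatility_2, logp=logp, observed=y_data)

    idata = pm.sample(chains=2, cores=1)

The CustomDist call mirrors the one in the test: the prior variable is passed as a distribution parameter and forwarded to logp as its second argument, after the observed value.
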
