diff --git a/src/pyhgf/response.py b/src/pyhgf/response.py index 5db5f84d..892e7106 100644 --- a/src/pyhgf/response.py +++ b/src/pyhgf/response.py @@ -167,7 +167,7 @@ def binary_softmax( """ # the expected values at the first level of the HGF - beliefs = hgf.node_trajectories[1]["expected_mean"] + beliefs = hgf.node_trajectories[0]["expected_mean"] # the binary surprises surprise = binary_surprise(x=response_function_inputs, expected_mean=beliefs) @@ -213,10 +213,10 @@ def binary_softmax_inverse_temperature( """ # the expected values at the first level of the HGF beliefs = ( - hgf.node_trajectories[1]["expected_mean"] ** response_function_parameters + hgf.node_trajectories[0]["expected_mean"] ** response_function_parameters ) / ( - hgf.node_trajectories[1]["expected_mean"] ** response_function_parameters - + (1 - hgf.node_trajectories[1]["expected_mean"]) + hgf.node_trajectories[0]["expected_mean"] ** response_function_parameters + + (1 - hgf.node_trajectories[0]["expected_mean"]) ** response_function_parameters ) diff --git a/src/pyhgf/updates/posterior/binary.py b/src/pyhgf/updates/posterior/binary.py deleted file mode 100644 index 2e694ef3..00000000 --- a/src/pyhgf/updates/posterior/binary.py +++ /dev/null @@ -1,63 +0,0 @@ -# Author: Nicolas Legrand - -from functools import partial -from typing import Dict - -import jax.numpy as jnp -from jax import jit - -from pyhgf.typing import Edges - - -@partial(jit, static_argnames=("edges", "node_idx")) -def binary_node_posterior_update_finite( - attributes: Dict, node_idx: int, edges: Edges, **args -) -> Dict: - """Update the posterior of a binary node given finite precision of the input. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - node_idx : - Pointer to the node that needs to be updated. After continuous updates, the - parameters of value and volatility parents (if any) will be different. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the node - number. For each node, the index lists the value and volatility parents and - children. - - Returns - ------- - attributes : - The updated attributes of the probabilistic nodes. - - References - ---------- - .. [1] Weber, L. A., Waade, P. T., Legrand, N., Møller, A. H., Stephan, K. E., & - Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1). - arXiv. https://doi.org/10.48550/ARXIV.2305.10937 - - """ - value_child_idx = edges[node_idx].value_children[0] # type: ignore - - delata0 = attributes[value_child_idx]["temp"]["value_prediction_error_0"] - delata1 = attributes[value_child_idx]["temp"]["value_prediction_error_1"] - expected_precision = attributes[value_child_idx]["expected_precision"] - - # Likelihood under eta1 - und1 = jnp.exp(-expected_precision / 2 * delata1**2) - - # Likelihood under eta0 - und0 = jnp.exp(-expected_precision / 2 * delata0**2) - - # Eq. 39 in Mathys et al. (2014) (i.e., Bayes) - expected_mean = attributes[node_idx]["expected_mean"] - mean = expected_mean * und1 / (expected_mean * und1 + (1 - expected_mean) * und0) - precision = 1 / (expected_mean * (1 - expected_mean)) - - attributes[node_idx]["mean"] = mean - attributes[node_idx]["precision"] = precision - - return attributes diff --git a/src/pyhgf/updates/prediction_error/binary.py b/src/pyhgf/updates/prediction_error/binary.py index a4a77a24..8882b75d 100644 --- a/src/pyhgf/updates/prediction_error/binary.py +++ b/src/pyhgf/updates/prediction_error/binary.py @@ -3,8 +3,11 @@ from functools import partial from typing import Dict +import jax.numpy as jnp from jax import Array, jit +from pyhgf.typing import Edges + @partial(jit, static_argnames=("node_idx")) def binary_state_node_prediction_error( @@ -41,3 +44,57 @@ def binary_state_node_prediction_error( attributes[node_idx]["temp"]["value_prediction_error"] = value_prediction_error return attributes + + +@partial(jit, static_argnames=("edges", "node_idx")) +def binary_finite_state_node_prediction_error( + attributes: Dict, node_idx: int, edges: Edges, **args +) -> Dict: + """Update the posterior of a binary node given finite precision of the input. + + Parameters + ---------- + attributes : + The attributes of the probabilistic nodes. + node_idx : + Pointer to the node that needs to be updated. After continuous updates, the + parameters of value and volatility parents (if any) will be different. + edges : + The edges of the probabilistic nodes as a tuple of + :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the node + number. For each node, the index lists the value and volatility parents and + children. + + Returns + ------- + attributes : + The updated attributes of the probabilistic nodes. + + References + ---------- + .. [1] Weber, L. A., Waade, P. T., Legrand, N., Møller, A. H., Stephan, K. E., & + Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1). + arXiv. https://doi.org/10.48550/ARXIV.2305.10937 + + """ + value_child_idx = edges[node_idx].value_children[0] # type: ignore + + delata0 = attributes[value_child_idx]["temp"]["value_prediction_error_0"] + delata1 = attributes[value_child_idx]["temp"]["value_prediction_error_1"] + expected_precision = attributes[value_child_idx]["expected_precision"] + + # Likelihood under eta1 + und1 = jnp.exp(-expected_precision / 2 * delata1**2) + + # Likelihood under eta0 + und0 = jnp.exp(-expected_precision / 2 * delata0**2) + + # Eq. 39 in Mathys et al. (2014) (i.e., Bayes) + expected_mean = attributes[node_idx]["expected_mean"] + mean = expected_mean * und1 / (expected_mean * und1 + (1 - expected_mean) * und0) + precision = 1 / (expected_mean * (1 - expected_mean)) + + attributes[node_idx]["mean"] = mean + attributes[node_idx]["precision"] = precision + + return attributes diff --git a/tests/test_binary.py b/tests/test_binary.py index d1949cb3..eb057c51 100644 --- a/tests/test_binary.py +++ b/tests/test_binary.py @@ -46,7 +46,7 @@ def test_update_binary_input_parents(): # apply sequence new_attributes, _ = beliefs_propagation( attributes=binary_hgf.attributes, - input_data=(data, time_steps, observed), + inputs=(data, time_steps, observed), update_sequence=binary_hgf.update_sequence, edges=binary_hgf.edges, input_idxs=(0,), diff --git a/tests/test_distribution.py b/tests/test_distribution.py index 5a632468..3b0acb3f 100644 --- a/tests/test_distribution.py +++ b/tests/test_distribution.py @@ -264,14 +264,10 @@ def test_hgf_logp(): tonic_volatility_1=np.nan, tonic_volatility_2=-2.0, tonic_volatility_3=-6.0, - input_precision=np.inf, - tonic_drift_1=0.0, - tonic_drift_2=0.0, - tonic_drift_3=0.0, precision_1=np.nan, precision_2=1.0, precision_3=1.0, - mean_1=0.0, + mean_1=0.5, mean_2=0.0, mean_3=0.0, volatility_coupling_1=1.0, @@ -562,7 +558,7 @@ def logp(value, tonic_volatility_2): with pm.Model() as model: y_data = pm.Data("y_data", y) - tonic_volatility_2 = pm.Normal("tonic_volatility_2", -11.0, 2) + tonic_volatility_2 = pm.Normal("tonic_volatility_2", -3.0, 2) pm.CustomDist("likelihood", tonic_volatility_2, logp=logp, observed=y_data) initial_point = model.initial_point()