diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c883e617..48e5fe8d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,12 +9,12 @@ repos: hooks: - id: isort - repo: https://github.com/ambv/black - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black language_version: python3 - repo: https://github.com/pycqa/flake8 - rev: 7.1.0 + rev: 7.1.1 hooks: - id: flake8 - repo: https://github.com/pycqa/pydocstyle @@ -24,7 +24,7 @@ repos: args: ['--ignore', 'D213,D100,D203,D104'] files: ^src/ - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.10.1' + rev: 'v1.11.1' hooks: - id: mypy files: ^src/ diff --git a/tests/test_binary.py b/tests/test_binary.py index a70f813f..1505ec70 100644 --- a/tests/test_binary.py +++ b/tests/test_binary.py @@ -1,8 +1,5 @@ # Author: Nicolas Legrand -import unittest -from unittest import TestCase - import jax.numpy as jnp from jax.lax import scan from jax.tree_util import Partial @@ -26,193 +23,191 @@ from pyhgf.utils import beliefs_propagation -class Testbinary(TestCase): - def test_gaussian_density(self): - surprise = gaussian_density( - x=jnp.array([1.0, 1.0]), - mean=jnp.array([0.0, 0.0]), - precision=jnp.array([1.0, 1.0]), - ) - assert jnp.all(jnp.isclose(surprise, 0.24197073)) - - def test_sgm(self): - assert jnp.all(jnp.isclose(sigmoid(jnp.array([0.3, 0.3])), 0.5744425)) - - def test_binary_surprise(self): - surprise = binary_surprise( - x=jnp.array([1.0]), - expected_mean=jnp.array([0.2]), - ) - assert jnp.all(jnp.isclose(surprise, 1.609438)) - - def test_update_binary_input_parents(self): - ########################## - # three level binary HGF # - ########################## - input_node_parameters = { - "expected_precision": jnp.inf, - "eta0": 0.0, - "eta1": 1.0, - "surprise": 0.0, - "time_step": 0.0, - "values": 0.0, - "observed": 1, - "volatility_coupling_parents": None, - "value_coupling_parents": (1.0,), - } - node_parameters_1 = { - "expected_precision": 1.0, - "precision": 1.0, - "expected_mean": 1.0, - "value_coupling_children": (1.0,), - "value_coupling_parents": (1.0,), - "volatility_coupling_parents": None, - "volatility_coupling_children": None, - "autoconnection_strength": 1.0, - "mean": 1.0, - "observed": 1, - "tonic_volatility": 1.0, - "tonic_drift": 0.0, - "binary_expected_precision": jnp.nan, - "temp": { - "value_prediction_error": 0.0, - }, - } - node_parameters_2 = { - "expected_precision": 1.0, - "precision": 1.0, - "expected_mean": 1.0, - "value_coupling_children": (1.0,), - "value_coupling_parents": None, - "volatility_coupling_parents": (1.0,), - "volatility_coupling_children": None, - "autoconnection_strength": 1.0, - "mean": 1.0, - "observed": 1, - "tonic_volatility": 1.0, - "tonic_drift": 0.0, - "temp": { - "effective_precision": 1.0, - "value_prediction_error": 0.0, - "volatility_prediction_error": 0.0, - }, - } - node_parameters_3 = { - "expected_precision": 1.0, - "precision": 1.0, - "expected_mean": 1.0, - "value_coupling_children": None, - "value_coupling_parents": None, - "volatility_coupling_parents": None, - "volatility_coupling_children": (1.0,), - "autoconnection_strength": 1.0, - "mean": 1.0, - "observed": 1, - "tonic_volatility": 1.0, - "tonic_drift": 0.0, - "temp": { - "effective_precision": 1.0, - "value_prediction_error": 0.0, - "volatility_prediction_error": 0.0, - }, - } - - edges = ( - AdjacencyLists(0, (1,), None, None, None, (None,)), - AdjacencyLists(1, (2,), None, (0,), None, (None,)), - AdjacencyLists(2, None, (3,), (1,), None, (None,)), - AdjacencyLists(2, None, None, None, (2,), (None,)), - ) - attributes = { - 0: input_node_parameters, - 1: node_parameters_1, - 2: node_parameters_2, - 3: node_parameters_3, - } - - # create update sequence - sequence1 = 3, continuous_node_prediction - sequence2 = 2, continuous_node_prediction - sequence3 = 1, binary_state_node_prediction - sequence4 = 0, binary_input_prediction_error_infinite_precision - sequence5 = 1, binary_node_update_infinite - sequence6 = 1, binary_state_node_prediction_error - sequence7 = 2, continuous_node_update - sequence8 = 2, continuous_node_prediction_error - sequence9 = 3, continuous_node_update - update_sequence = ( - sequence1, - sequence2, - sequence3, - sequence4, - sequence5, - sequence6, - sequence7, - sequence8, - sequence9, - ) - data = jnp.ones(1) - time_steps = jnp.ones(1) - observed = jnp.ones(1) - inputs = Inputs(0, 1) - - # apply sequence - new_attributes, _ = beliefs_propagation( - structure=(inputs, edges), - attributes=attributes, - update_sequence=update_sequence, - input_data=(data, time_steps, observed), - ) - for idx, val in zip( - ["mean", "expected_mean", "binary_expected_precision"], - [1.0, 0.7310586, 5.0861616], - ): - assert jnp.isclose(new_attributes[1][idx], val) - for idx, val in zip( - ["mean", "expected_mean", "precision", "expected_precision"], - [1.8515793, 1.0, 0.31581485, 0.11920292], - ): - assert jnp.isclose(new_attributes[2][idx], val) - for idx, val in zip( - ["mean", "expected_mean", "precision", "expected_precision"], - [0.5050575, 1.0, 0.47702926, 0.26894143], - ): - assert jnp.isclose(new_attributes[3][idx], val) - - # use scan - u, _ = load_data("binary") - - # Create the data (value and time steps vectors) - only use the 5 first trials - # as the priors are ill defined here - data = jnp.array([u[:5]]).T - time_steps = jnp.ones((len(u[:5]), 1)) - observed = jnp.ones((len(u[:5]), 1)) - inputs = Inputs(0, 1) - - # create the function that will be scaned - scan_fn = Partial( - beliefs_propagation, - update_sequence=update_sequence, - structure=(inputs, edges), - ) - - # Run the entire for loop - last, _ = scan(scan_fn, attributes, (data, time_steps, observed)) - for idx, val in zip( - ["mean", "expected_mean", "binary_expected_precision"], - [0.0, 0.95616907, 23.860779], - ): - assert jnp.isclose(last[1][idx], val) - for idx, val in zip( - ["mean", "expected_mean", "precision", "expected_precision"], - [-2.1582031, 3.0825963, 0.18244718, 0.1405374], - ): - assert jnp.isclose(last[2][idx], val) - for idx, val in zip( - ["expected_mean", "expected_precision"], [-0.30260748, 0.14332297] - ): - assert jnp.isclose(last[3][idx], val) - - -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) +def test_gaussian_density(): + surprise = gaussian_density( + x=jnp.array([1.0, 1.0]), + mean=jnp.array([0.0, 0.0]), + precision=jnp.array([1.0, 1.0]), + ) + assert jnp.all(jnp.isclose(surprise, 0.24197073)) + + +def test_sgm(): + assert jnp.all(jnp.isclose(sigmoid(jnp.array([0.3, 0.3])), 0.5744425)) + + +def test_binary_surprise(): + surprise = binary_surprise( + x=jnp.array([1.0]), + expected_mean=jnp.array([0.2]), + ) + assert jnp.all(jnp.isclose(surprise, 1.609438)) + + +def test_update_binary_input_parents(): + ########################## + # three level binary HGF # + ########################## + input_node_parameters = { + "expected_precision": jnp.inf, + "eta0": 0.0, + "eta1": 1.0, + "surprise": 0.0, + "time_step": 0.0, + "values": 0.0, + "observed": 1, + "volatility_coupling_parents": None, + "value_coupling_parents": (1.0,), + } + node_parameters_1 = { + "expected_precision": 1.0, + "precision": 1.0, + "expected_mean": 1.0, + "value_coupling_children": (1.0,), + "value_coupling_parents": (1.0,), + "volatility_coupling_parents": None, + "volatility_coupling_children": None, + "autoconnection_strength": 1.0, + "mean": 1.0, + "observed": 1, + "tonic_volatility": 1.0, + "tonic_drift": 0.0, + "binary_expected_precision": jnp.nan, + "temp": { + "value_prediction_error": 0.0, + }, + } + node_parameters_2 = { + "expected_precision": 1.0, + "precision": 1.0, + "expected_mean": 1.0, + "value_coupling_children": (1.0,), + "value_coupling_parents": None, + "volatility_coupling_parents": (1.0,), + "volatility_coupling_children": None, + "autoconnection_strength": 1.0, + "mean": 1.0, + "observed": 1, + "tonic_volatility": 1.0, + "tonic_drift": 0.0, + "temp": { + "effective_precision": 1.0, + "value_prediction_error": 0.0, + "volatility_prediction_error": 0.0, + }, + } + node_parameters_3 = { + "expected_precision": 1.0, + "precision": 1.0, + "expected_mean": 1.0, + "value_coupling_children": None, + "value_coupling_parents": None, + "volatility_coupling_parents": None, + "volatility_coupling_children": (1.0,), + "autoconnection_strength": 1.0, + "mean": 1.0, + "observed": 1, + "tonic_volatility": 1.0, + "tonic_drift": 0.0, + "temp": { + "effective_precision": 1.0, + "value_prediction_error": 0.0, + "volatility_prediction_error": 0.0, + }, + } + + edges = ( + AdjacencyLists(0, (1,), None, None, None, (None,)), + AdjacencyLists(1, (2,), None, (0,), None, (None,)), + AdjacencyLists(2, None, (3,), (1,), None, (None,)), + AdjacencyLists(2, None, None, None, (2,), (None,)), + ) + attributes = { + 0: input_node_parameters, + 1: node_parameters_1, + 2: node_parameters_2, + 3: node_parameters_3, + } + + # create update sequence + sequence1 = 3, continuous_node_prediction + sequence2 = 2, continuous_node_prediction + sequence3 = 1, binary_state_node_prediction + sequence4 = 0, binary_input_prediction_error_infinite_precision + sequence5 = 1, binary_node_update_infinite + sequence6 = 1, binary_state_node_prediction_error + sequence7 = 2, continuous_node_update + sequence8 = 2, continuous_node_prediction_error + sequence9 = 3, continuous_node_update + update_sequence = ( + sequence1, + sequence2, + sequence3, + sequence4, + sequence5, + sequence6, + sequence7, + sequence8, + sequence9, + ) + data = jnp.ones(1) + time_steps = jnp.ones(1) + observed = jnp.ones(1) + inputs = Inputs(0, 1) + + # apply sequence + new_attributes, _ = beliefs_propagation( + structure=(inputs, edges), + attributes=attributes, + update_sequence=update_sequence, + input_data=(data, time_steps, observed), + ) + for idx, val in zip( + ["mean", "expected_mean", "binary_expected_precision"], + [1.0, 0.7310586, 5.0861616], + ): + assert jnp.isclose(new_attributes[1][idx], val) + for idx, val in zip( + ["mean", "expected_mean", "precision", "expected_precision"], + [1.8515793, 1.0, 0.31581485, 0.11920292], + ): + assert jnp.isclose(new_attributes[2][idx], val) + for idx, val in zip( + ["mean", "expected_mean", "precision", "expected_precision"], + [0.5050575, 1.0, 0.47702926, 0.26894143], + ): + assert jnp.isclose(new_attributes[3][idx], val) + + # use scan + u, _ = load_data("binary") + + # Create the data (value and time steps vectors) - only use the 5 first trials + # as the priors are ill defined here + data = jnp.array([u[:5]]).T + time_steps = jnp.ones((len(u[:5]), 1)) + observed = jnp.ones((len(u[:5]), 1)) + inputs = Inputs(0, 1) + + # create the function that will be scaned + scan_fn = Partial( + beliefs_propagation, + update_sequence=update_sequence, + structure=(inputs, edges), + ) + + # Run the entire for loop + last, _ = scan(scan_fn, attributes, (data, time_steps, observed)) + for idx, val in zip( + ["mean", "expected_mean", "binary_expected_precision"], + [0.0, 0.95616907, 23.860779], + ): + assert jnp.isclose(last[1][idx], val) + for idx, val in zip( + ["mean", "expected_mean", "precision", "expected_precision"], + [-2.1582031, 3.0825963, 0.18244718, 0.1405374], + ): + assert jnp.isclose(last[2][idx], val) + for idx, val in zip( + ["expected_mean", "expected_precision"], [-0.30260748, 0.14332297] + ): + assert jnp.isclose(last[3][idx], val) diff --git a/tests/test_categorical.py b/tests/test_categorical.py index 79ce228c..e08eafc0 100644 --- a/tests/test_categorical.py +++ b/tests/test_categorical.py @@ -1,36 +1,28 @@ # Author: Nicolas Legrand -import unittest -from unittest import TestCase - import numpy as np from pyhgf.model import HGF -class Testbinary(TestCase): - def test_categorical_state_node(self): - # generate some categorical inputs data - input_data = np.array( - [np.random.multinomial(n=1, pvals=[0.1, 0.2, 0.7]) for _ in range(3)] - ).T - input_data = np.vstack([[0.0] * input_data.shape[1], input_data]) - - # create the categorical HGF - categorical_hgf = HGF(model_type=None, verbose=False).add_nodes( - kind="categorical-input", - node_parameters={ - "n_categories": 3, - "binary_parameters": {"tonic_volatility_2": -2.0}, - }, - ) - - # fitting the model forwards - categorical_hgf.input_data(input_data=input_data.T) +def test_categorical_state_node(): + # generate some categorical inputs data + input_data = np.array( + [np.random.multinomial(n=1, pvals=[0.1, 0.2, 0.7]) for _ in range(3)] + ).T + input_data = np.vstack([[0.0] * input_data.shape[1], input_data]) - # export to pandas data frame - categorical_hgf.to_pandas() + # create the categorical HGF + categorical_hgf = HGF(model_type=None, verbose=False).add_nodes( + kind="categorical-input", + node_parameters={ + "n_categories": 3, + "binary_parameters": {"tonic_volatility_2": -2.0}, + }, + ) + # fitting the model forwards + categorical_hgf.input_data(input_data=input_data.T) -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) + # export to pandas data frame + categorical_hgf.to_pandas() diff --git a/tests/test_continuous.py b/tests/test_continuous.py index 19552696..647d0e01 100644 --- a/tests/test_continuous.py +++ b/tests/test_continuous.py @@ -1,7 +1,5 @@ # Author: Nicolas Legrand -import unittest - import jax.numpy as jnp import pytest from jax.lax import scan @@ -290,7 +288,3 @@ def coupling_fn(x): [10000.982, 0.98201376, 0.19998036, 0.0], ): assert jnp.isclose(test_HGF.node_trajectories[1][idx], val) - - -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) diff --git a/tests/test_distribution.py b/tests/test_distribution.py index 1f9f8d0d..5a632468 100644 --- a/tests/test_distribution.py +++ b/tests/test_distribution.py @@ -1,8 +1,5 @@ # Author: Nicolas Legrand -import unittest -from unittest import TestCase - import arviz as az import jax.numpy as jnp import numpy as np @@ -25,556 +22,556 @@ ) -class TestDistribution(TestCase): - - def test_logp(self): - """Test the log-probability function for single model fit.""" - timeseries = load_data("continuous") - hgf = HGF(n_levels=2, model_type="continuous") - - log_likelihood = logp( - tonic_volatility_1=-3.0, - tonic_volatility_2=-3.0, - tonic_volatility_3=jnp.nan, - input_precision=np.array(1e4), - tonic_drift_1=0.0, - tonic_drift_2=0.0, - tonic_drift_3=jnp.nan, - precision_1=1e4, - precision_2=1e1, - precision_3=jnp.nan, - mean_1=1.0, - mean_2=0.0, - mean_3=jnp.nan, - volatility_coupling_1=1.0, - volatility_coupling_2=jnp.nan, - response_function_parameters=[jnp.nan], - input_data=timeseries, - time_steps=np.ones(shape=timeseries.shape), - response_function_inputs=jnp.nan, - response_function=first_level_gaussian_surprise, +def test_logp(): + """Test the log-probability function for single model fit.""" + timeseries = load_data("continuous") + hgf = HGF(n_levels=2, model_type="continuous") + + log_likelihood = logp( + tonic_volatility_1=-3.0, + tonic_volatility_2=-3.0, + tonic_volatility_3=jnp.nan, + input_precision=np.array(1e4), + tonic_drift_1=0.0, + tonic_drift_2=0.0, + tonic_drift_3=jnp.nan, + precision_1=1e4, + precision_2=1e1, + precision_3=jnp.nan, + mean_1=1.0, + mean_2=0.0, + mean_3=jnp.nan, + volatility_coupling_1=1.0, + volatility_coupling_2=jnp.nan, + response_function_parameters=[jnp.nan], + input_data=timeseries, + time_steps=np.ones(shape=timeseries.shape), + response_function_inputs=jnp.nan, + response_function=first_level_gaussian_surprise, + hgf=hgf, + ) + + assert jnp.isclose(log_likelihood.sum(), 1141.0585) + + +def test_vectorized_logp(): + """Test the vectorized version of the log-probability function.""" + + timeseries = load_data("continuous") + hgf = HGF(n_levels=2, model_type="continuous") + + # generate data with 2 parameters set + input_data = np.array([timeseries] * 2) + time_steps = np.ones(shape=input_data.shape) + + tonic_volatility_1 = -3.0 + tonic_volatility_2 = -3.0 + tonic_volatility_3 = jnp.nan + input_precision = np.array(1e4) + tonic_drift_1 = 0.0 + tonic_drift_2 = 0.0 + tonic_drift_3 = jnp.nan + precision_1 = 1e4 + precision_2 = 1e1 + precision_3 = jnp.nan + mean_1 = 1.0 + mean_2 = 0.0 + mean_3 = jnp.nan + volatility_coupling_1 = 1.0 + volatility_coupling_2 = jnp.nan + response_function_parameters = jnp.ones(2) + + # Broadcast inputs to an array with length n>=1 + ( + _tonic_volatility_1, + _tonic_volatility_2, + _tonic_volatility_3, + _input_precision, + _tonic_drift_1, + _tonic_drift_2, + _tonic_drift_3, + _precision_1, + _precision_2, + _precision_3, + _mean_1, + _mean_2, + _mean_3, + _volatility_coupling_1, + _volatility_coupling_2, + _, + ) = jnp.broadcast_arrays( + tonic_volatility_1, + tonic_volatility_2, + tonic_volatility_3, + input_precision, + tonic_drift_1, + tonic_drift_2, + tonic_drift_3, + precision_1, + precision_2, + precision_3, + mean_1, + mean_2, + mean_3, + volatility_coupling_1, + volatility_coupling_2, + jnp.zeros(2), + ) + + # create the vectorized version of the function + vectorized_logp_two_levels = vmap( + Partial( + logp, hgf=hgf, + response_function=first_level_gaussian_surprise, ) - - assert jnp.isclose(log_likelihood.sum(), 1141.0585) - - def test_vectorized_logp(self): - """Test the vectorized version of the log-probability function.""" - - timeseries = load_data("continuous") - hgf = HGF(n_levels=2, model_type="continuous") - - # generate data with 2 parameters set - input_data = np.array([timeseries] * 2) - time_steps = np.ones(shape=input_data.shape) - - tonic_volatility_1 = -3.0 - tonic_volatility_2 = -3.0 - tonic_volatility_3 = jnp.nan - input_precision = np.array(1e4) - tonic_drift_1 = 0.0 - tonic_drift_2 = 0.0 - tonic_drift_3 = jnp.nan - precision_1 = 1e4 - precision_2 = 1e1 - precision_3 = jnp.nan - mean_1 = 1.0 - mean_2 = 0.0 - mean_3 = jnp.nan - volatility_coupling_1 = 1.0 - volatility_coupling_2 = jnp.nan - response_function_parameters = jnp.ones(2) - - # Broadcast inputs to an array with length n>=1 - ( - _tonic_volatility_1, - _tonic_volatility_2, - _tonic_volatility_3, - _input_precision, - _tonic_drift_1, - _tonic_drift_2, - _tonic_drift_3, - _precision_1, - _precision_2, - _precision_3, - _mean_1, - _mean_2, - _mean_3, - _volatility_coupling_1, - _volatility_coupling_2, - _, - ) = jnp.broadcast_arrays( - tonic_volatility_1, - tonic_volatility_2, - tonic_volatility_3, - input_precision, - tonic_drift_1, - tonic_drift_2, - tonic_drift_3, - precision_1, - precision_2, - precision_3, - mean_1, - mean_2, - mean_3, - volatility_coupling_1, - volatility_coupling_2, - jnp.zeros(2), - ) - - # create the vectorized version of the function - vectorized_logp_two_levels = vmap( - Partial( - logp, - hgf=hgf, - response_function=first_level_gaussian_surprise, - ) - ) - - # model fit - log_likelihoods = vectorized_logp_two_levels( - input_data=input_data, - response_function_inputs=jnp.ones(2), - response_function_parameters=response_function_parameters, - time_steps=time_steps, - mean_1=_mean_1, - mean_2=_mean_2, - mean_3=_mean_3, - precision_1=_precision_1, - precision_2=_precision_2, - precision_3=_precision_3, - tonic_volatility_1=_tonic_volatility_1, - tonic_volatility_2=_tonic_volatility_2, - tonic_volatility_3=_tonic_volatility_3, - tonic_drift_1=_tonic_drift_1, - tonic_drift_2=_tonic_drift_2, - tonic_drift_3=_tonic_drift_3, - volatility_coupling_1=_volatility_coupling_1, - volatility_coupling_2=_volatility_coupling_2, - input_precision=_input_precision, + ) + + # model fit + log_likelihoods = vectorized_logp_two_levels( + input_data=input_data, + response_function_inputs=jnp.ones(2), + response_function_parameters=response_function_parameters, + time_steps=time_steps, + mean_1=_mean_1, + mean_2=_mean_2, + mean_3=_mean_3, + precision_1=_precision_1, + precision_2=_precision_2, + precision_3=_precision_3, + tonic_volatility_1=_tonic_volatility_1, + tonic_volatility_2=_tonic_volatility_2, + tonic_volatility_3=_tonic_volatility_3, + tonic_drift_1=_tonic_drift_1, + tonic_drift_2=_tonic_drift_2, + tonic_drift_3=_tonic_drift_3, + volatility_coupling_1=_volatility_coupling_1, + volatility_coupling_2=_volatility_coupling_2, + input_precision=_input_precision, + ) + + assert jnp.isclose(log_likelihoods.sum(), 2282.1165).all() + + +def test_hgf_logp(): + """Test the hgf_logp function used by Distribution Ops on three level models""" + + ############################## + # Three-level continuous HGF # + ############################## + timeseries = load_data("continuous") + continuous_hgf = HGF(n_levels=3, model_type="continuous") + + # generate data with 2 parameters set + input_data = np.array([timeseries] * 2) + time_steps = np.ones(shape=input_data.shape) + + # create the vectorized version of the function + vectorized_logp_three_levels = vmap( + Partial( + logp, + hgf=continuous_hgf, + response_function=first_level_gaussian_surprise, ) - - assert jnp.isclose(log_likelihoods.sum(), 2282.1165).all() - - def test_hgf_logp(self): - """Test the hgf_logp function used by Distribution Ops on three level models""" - - ############################## - # Three-level continuous HGF # - ############################## - timeseries = load_data("continuous") - continuous_hgf = HGF(n_levels=3, model_type="continuous") - - # generate data with 2 parameters set - input_data = np.array([timeseries] * 2) - time_steps = np.ones(shape=input_data.shape) - - # create the vectorized version of the function - vectorized_logp_three_levels = vmap( + ) + + sum_log_likelihoods, log_likelihoods = hgf_logp( + tonic_volatility_1=-3.0, + tonic_volatility_2=-3.0, + tonic_volatility_3=-3.0, + input_precision=1e4, + tonic_drift_1=0.0, + tonic_drift_2=0.0, + tonic_drift_3=0.0, + precision_1=1e4, + precision_2=1e1, + precision_3=1.0, + mean_1=1.0, + mean_2=0.0, + mean_3=0.0, + volatility_coupling_1=1.0, + volatility_coupling_2=1.0, + response_function_parameters=np.ones(2), + response_function_inputs=np.ones(2), + vectorized_logp=vectorized_logp_three_levels, + input_data=input_data, + time_steps=time_steps, + ) + assert sum_log_likelihoods == log_likelihoods.sum() + assert jnp.isclose(sum_log_likelihoods, 2269.6929).all() + + # test the gradient + grad_logp = jit( + grad( Partial( - logp, - hgf=continuous_hgf, - response_function=first_level_gaussian_surprise, - ) - ) - - sum_log_likelihoods, log_likelihoods = hgf_logp( - tonic_volatility_1=-3.0, - tonic_volatility_2=-3.0, - tonic_volatility_3=-3.0, - input_precision=1e4, - tonic_drift_1=0.0, - tonic_drift_2=0.0, - tonic_drift_3=0.0, - precision_1=1e4, - precision_2=1e1, - precision_3=1.0, - mean_1=1.0, - mean_2=0.0, - mean_3=0.0, - volatility_coupling_1=1.0, - volatility_coupling_2=1.0, - response_function_parameters=np.ones(2), - response_function_inputs=np.ones(2), - vectorized_logp=vectorized_logp_three_levels, - input_data=input_data, - time_steps=time_steps, - ) - assert sum_log_likelihoods == log_likelihoods.sum() - assert jnp.isclose(sum_log_likelihoods, 2269.6929).all() - - # test the gradient - grad_logp = jit( - grad( - Partial( - hgf_logp, - vectorized_logp=vectorized_logp_three_levels, - input_data=input_data, - time_steps=time_steps, - response_function_inputs=np.ones(2), - ), - argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - has_aux=True, + hgf_logp, + vectorized_logp=vectorized_logp_three_levels, + input_data=input_data, + time_steps=time_steps, + response_function_inputs=np.ones(2), ), + argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + has_aux=True, + ), + ) + + gradients, _ = grad_logp( + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + 1.0, + -3.0, + -3.0, + -3.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 1e4, + np.ones(2), + ) + assert jnp.isclose(gradients[0], 0.09576362) + assert jnp.isclose(gradients[1], -8.531818) + assert jnp.isclose(gradients[2], -189.85936) + + ########################## + # Three-level binary HGF # + ########################## + + # Create the data (value and time vectors) + u, y = load_data("binary") + binary_hgf = HGF(n_levels=3, model_type="binary") + + # generate data with 2 parameters set + input_data = np.array([u] * 2) + response_function_inputs = np.array([y] * 2) + time_steps = np.ones(shape=input_data.shape) + + # create the vectorized version of the function + vectorized_logp_three_levels = vmap( + Partial( + logp, + hgf=binary_hgf, + response_function=binary_softmax_inverse_temperature, ) - - gradients, _ = grad_logp( - 1.0, - 0.0, - 0.0, - 1.0, - 1.0, - 1.0, - -3.0, - -3.0, - -3.0, - 0.0, - 0.0, - 0.0, - 1.0, - 1.0, - 1e4, - np.ones(2), - ) - assert jnp.isclose(gradients[0], 0.09576362) - assert jnp.isclose(gradients[1], -8.531818) - assert jnp.isclose(gradients[2], -189.85936) - - ########################## - # Three-level binary HGF # - ########################## - - # Create the data (value and time vectors) - u, y = load_data("binary") - binary_hgf = HGF(n_levels=3, model_type="binary") - - # generate data with 2 parameters set - input_data = np.array([u] * 2) - response_function_inputs = np.array([y] * 2) - time_steps = np.ones(shape=input_data.shape) - - # create the vectorized version of the function - vectorized_logp_three_levels = vmap( + ) + + sum_log_likelihoods, log_likelihoods = hgf_logp( + vectorized_logp=vectorized_logp_three_levels, + tonic_volatility_1=np.nan, + tonic_volatility_2=-2.0, + tonic_volatility_3=-6.0, + input_precision=np.inf, + tonic_drift_1=0.0, + tonic_drift_2=0.0, + tonic_drift_3=0.0, + precision_1=np.nan, + precision_2=1.0, + precision_3=1.0, + mean_1=0.0, + mean_2=0.0, + mean_3=0.0, + volatility_coupling_1=1.0, + volatility_coupling_2=1.0, + response_function_parameters=np.array([1.0, 1.0]), + input_data=input_data, + response_function_inputs=response_function_inputs, + time_steps=time_steps, + ) + + assert sum_log_likelihoods == log_likelihoods.sum() + assert jnp.isclose(sum_log_likelihoods, -248.07889) + + # test the gradient + grad_logp = jit( + grad( Partial( - logp, - hgf=binary_hgf, - response_function=binary_softmax_inverse_temperature, - ) - ) - - sum_log_likelihoods, log_likelihoods = hgf_logp( - vectorized_logp=vectorized_logp_three_levels, - tonic_volatility_1=np.nan, - tonic_volatility_2=-2.0, - tonic_volatility_3=-6.0, - input_precision=np.inf, - tonic_drift_1=0.0, - tonic_drift_2=0.0, - tonic_drift_3=0.0, - precision_1=np.nan, - precision_2=1.0, - precision_3=1.0, - mean_1=0.0, - mean_2=0.0, - mean_3=0.0, - volatility_coupling_1=1.0, - volatility_coupling_2=1.0, - response_function_parameters=np.array([1.0, 1.0]), - input_data=input_data, - response_function_inputs=response_function_inputs, - time_steps=time_steps, - ) - - assert sum_log_likelihoods == log_likelihoods.sum() - assert jnp.isclose(sum_log_likelihoods, -248.07889) - - # test the gradient - grad_logp = jit( - grad( - Partial( - hgf_logp, - vectorized_logp=vectorized_logp_three_levels, - input_data=input_data, - time_steps=time_steps, - response_function_inputs=response_function_inputs, - ), - argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - has_aux=True, + hgf_logp, + vectorized_logp=vectorized_logp_three_levels, + input_data=input_data, + time_steps=time_steps, + response_function_inputs=response_function_inputs, + ), + argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + has_aux=True, + ), + ) + + gradients, _ = grad_logp( + 0.5, + 0.0, + 0.0, + 1.0, + 1.0, + 1.0, + -3.0, + -2.0, + -6.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + np.inf, + np.ones(2), + ) + assert jnp.isclose(gradients[0], 0.0) + assert jnp.isclose(gradients[1], 1.9314771) + assert jnp.isclose(gradients[2], 30.185408) + + +def test_pytensor_pointwise_logp(): + """Test the pytensor HGFPointwise op.""" + + ############## + # Binary HGF # + ############## + + # Create the data (value and time vectors) + u, y = load_data("binary") + + hgf_logp_op = HGFPointwise( + input_data=u[np.newaxis, :], + model_type="binary", + n_levels=2, + response_function=binary_softmax_inverse_temperature, + response_function_inputs=y[np.newaxis, :], + ) + + logp = hgf_logp_op( + tonic_volatility_1=np.inf, + tonic_volatility_2=-6.0, + tonic_volatility_3=np.inf, + input_precision=np.inf, + tonic_drift_1=0.0, + tonic_drift_2=0.0, + tonic_drift_3=np.inf, + precision_1=0.0, + precision_2=1e4, + precision_3=np.inf, + mean_1=np.inf, + mean_2=0.5, + mean_3=np.inf, + volatility_coupling_1=1.0, + volatility_coupling_2=np.inf, + response_function_parameters=np.array([1.0]), + ).eval() + + assert jnp.isclose(logp.sum(), -200.2442167699337) + + +def test_pytensor_logp(): + """Test the pytensor hgf_logp op.""" + + ################## + # Continuous HGF # + ################## + + # Create the data (value and time vectors) + timeseries = load_data("continuous") + + hgf_logp_op = HGFDistribution( + input_data=timeseries[np.newaxis, :], + model_type="continuous", + n_levels=2, + response_function=first_level_gaussian_surprise, + response_function_inputs=None, + ) + + logp = hgf_logp_op( + mean_1=np.array(1.0), + mean_2=np.array(0.0), + mean_3=np.array(0.0), + precision_1=np.array(1e4), + precision_2=np.array(1e1), + precision_3=np.array(0.0), + tonic_volatility_1=np.array(-3.0), + tonic_volatility_2=np.array(-3.0), + tonic_volatility_3=np.array(0.0), + tonic_drift_1=np.array(0.0), + tonic_drift_2=np.array(0.0), + tonic_drift_3=np.array(0.0), + volatility_coupling_1=np.array(1.0), + volatility_coupling_2=np.array(0.0), + input_precision=np.array(1e4), + response_function_parameters=np.ones(1), + ).eval() + + assert jnp.isclose(logp, 1141.05847168) + + ############## + # Binary HGF # + ############## + + # Create the data (value and time vectors) + u, y = load_data("binary") + + hgf_logp_op = HGFDistribution( + input_data=u[np.newaxis, :], + model_type="binary", + n_levels=2, + response_function=binary_softmax_inverse_temperature, + response_function_inputs=y[np.newaxis, :], + ) + + logp = hgf_logp_op( + tonic_volatility_1=np.inf, + tonic_volatility_2=-6.0, + tonic_volatility_3=np.inf, + input_precision=np.inf, + tonic_drift_1=0.0, + tonic_drift_2=0.0, + tonic_drift_3=np.inf, + precision_1=0.0, + precision_2=1e4, + precision_3=np.inf, + mean_1=np.inf, + mean_2=0.5, + mean_3=np.inf, + volatility_coupling_1=1.0, + volatility_coupling_2=np.inf, + response_function_parameters=np.array([1.0]), + ).eval() + + assert jnp.isclose(logp, -200.2442167699337) + + +def test_pytensor_grad_logp(): + """Test the pytensor gradient hgf_logp op.""" + + ################## + # Continuous HGF # + ################## + + # Create the data (value and time vectors) + timeseries = load_data("continuous") + + hgf_logp_grad_op = HGFLogpGradOp( + model_type="continuous", + input_data=timeseries[np.newaxis, :], + n_levels=2, + response_function=first_level_gaussian_surprise, + response_function_inputs=None, + ) + + tonic_volatility_1 = hgf_logp_grad_op( + tonic_volatility_1=-3.0, + tonic_volatility_2=-3.0, + input_precision=1e4, + tonic_drift_1=0.0, + tonic_drift_2=0.0, + precision_1=1.0, + precision_2=1.0, + mean_1=1.0, + mean_2=0.0, + volatility_coupling_1=1.0, + )[6].eval() + + assert jnp.isclose(tonic_volatility_1, -3.3479857) + + ############## + # Binary HGF # + ############## + + # Create the data (value and time vectors) + u, y = load_data("binary") + + hgf_logp_grad_op = HGFLogpGradOp( + model_type="binary", + input_data=u[np.newaxis, :], + n_levels=2, + response_function=binary_softmax_inverse_temperature, + response_function_inputs=y[np.newaxis, :], + ) + + tonic_volatility_2 = hgf_logp_grad_op( + tonic_volatility_1=jnp.inf, + tonic_volatility_2=-6.0, + input_precision=jnp.inf, + tonic_drift_1=0.0, + tonic_drift_2=0.0, + precision_1=0.0, + precision_2=1e4, + mean_1=jnp.inf, + mean_2=0.5, + volatility_coupling_1=1.0, + )[7].eval() + + assert jnp.isclose(tonic_volatility_2, 10.866466) + + +def test_pymc_sampling(): + """Test the pytensor hgf_logp op.""" + + ############## + # Continuous # + ############## + + # Create the data (value and time vectors) + timeseries = load_data("continuous") + + hgf_logp_op = HGFDistribution( + n_levels=2, + input_data=timeseries[np.newaxis, :], + response_function=first_level_gaussian_surprise, + ) + + with pm.Model() as model: + tonic_volatility_2 = pm.Uniform("tonic_volatility_2", -10, 0) + + pm.Potential( + "hhgf_loglike", + hgf_logp_op( + tonic_volatility_2=tonic_volatility_2, + input_precision=np.array(1e4), ), ) - gradients, _ = grad_logp( - 0.5, - 0.0, - 0.0, - 1.0, - 1.0, - 1.0, - -3.0, - -2.0, - -6.0, - 0.0, - 0.0, - 0.0, - 1.0, - 1.0, - np.inf, - np.ones(2), - ) - assert jnp.isclose(gradients[0], 0.0) - assert jnp.isclose(gradients[1], 1.9314771) - assert jnp.isclose(gradients[2], 30.185408) - - def test_pytensor_pointwise_logp(self): - """Test the pytensor HGFPointwise op.""" - - ############## - # Binary HGF # - ############## - - # Create the data (value and time vectors) - u, y = load_data("binary") - - hgf_logp_op = HGFPointwise( - input_data=u[np.newaxis, :], - model_type="binary", - n_levels=2, - response_function=binary_softmax_inverse_temperature, - response_function_inputs=y[np.newaxis, :], - ) - - logp = hgf_logp_op( - tonic_volatility_1=np.inf, - tonic_volatility_2=-6.0, - tonic_volatility_3=np.inf, - input_precision=np.inf, - tonic_drift_1=0.0, - tonic_drift_2=0.0, - tonic_drift_3=np.inf, - precision_1=0.0, - precision_2=1e4, - precision_3=np.inf, - mean_1=np.inf, - mean_2=0.5, - mean_3=np.inf, - volatility_coupling_1=1.0, - volatility_coupling_2=np.inf, - response_function_parameters=np.array([1.0]), - ).eval() - - assert jnp.isclose(logp.sum(), -200.2442167699337) - - def test_pytensor_logp(self): - """Test the pytensor hgf_logp op.""" - - ################## - # Continuous HGF # - ################## - - # Create the data (value and time vectors) - timeseries = load_data("continuous") - - hgf_logp_op = HGFDistribution( - input_data=timeseries[np.newaxis, :], - model_type="continuous", - n_levels=2, - response_function=first_level_gaussian_surprise, - response_function_inputs=None, - ) - - logp = hgf_logp_op( - mean_1=np.array(1.0), - mean_2=np.array(0.0), - mean_3=np.array(0.0), - precision_1=np.array(1e4), - precision_2=np.array(1e1), - precision_3=np.array(0.0), - tonic_volatility_1=np.array(-3.0), - tonic_volatility_2=np.array(-3.0), - tonic_volatility_3=np.array(0.0), - tonic_drift_1=np.array(0.0), - tonic_drift_2=np.array(0.0), - tonic_drift_3=np.array(0.0), - volatility_coupling_1=np.array(1.0), - volatility_coupling_2=np.array(0.0), - input_precision=np.array(1e4), - response_function_parameters=np.ones(1), - ).eval() - - assert jnp.isclose(logp, 1141.05847168) - - ############## - # Binary HGF # - ############## - - # Create the data (value and time vectors) - u, y = load_data("binary") - - hgf_logp_op = HGFDistribution( - input_data=u[np.newaxis, :], - model_type="binary", - n_levels=2, - response_function=binary_softmax_inverse_temperature, - response_function_inputs=y[np.newaxis, :], - ) - - logp = hgf_logp_op( - tonic_volatility_1=np.inf, - tonic_volatility_2=-6.0, - tonic_volatility_3=np.inf, - input_precision=np.inf, - tonic_drift_1=0.0, - tonic_drift_2=0.0, - tonic_drift_3=np.inf, - precision_1=0.0, - precision_2=1e4, - precision_3=np.inf, - mean_1=np.inf, - mean_2=0.5, - mean_3=np.inf, - volatility_coupling_1=1.0, - volatility_coupling_2=np.inf, - response_function_parameters=np.array([1.0]), - ).eval() - - assert jnp.isclose(logp, -200.2442167699337) - - def test_pytensor_grad_logp(self): - """Test the pytensor gradient hgf_logp op.""" - - ################## - # Continuous HGF # - ################## - - # Create the data (value and time vectors) - timeseries = load_data("continuous") - - hgf_logp_grad_op = HGFLogpGradOp( - model_type="continuous", - input_data=timeseries[np.newaxis, :], - n_levels=2, - response_function=first_level_gaussian_surprise, - response_function_inputs=None, - ) - - tonic_volatility_1 = hgf_logp_grad_op( - tonic_volatility_1=-3.0, - tonic_volatility_2=-3.0, - input_precision=1e4, - tonic_drift_1=0.0, - tonic_drift_2=0.0, - precision_1=1.0, - precision_2=1.0, - mean_1=1.0, - mean_2=0.0, - volatility_coupling_1=1.0, - )[6].eval() - - assert jnp.isclose(tonic_volatility_1, -3.3479857) - - ############## - # Binary HGF # - ############## - - # Create the data (value and time vectors) - u, y = load_data("binary") - - hgf_logp_grad_op = HGFLogpGradOp( - model_type="binary", - input_data=u[np.newaxis, :], - n_levels=2, - response_function=binary_softmax_inverse_temperature, - response_function_inputs=y[np.newaxis, :], - ) - - tonic_volatility_2 = hgf_logp_grad_op( - tonic_volatility_1=jnp.inf, - tonic_volatility_2=-6.0, - input_precision=jnp.inf, - tonic_drift_1=0.0, - tonic_drift_2=0.0, - precision_1=0.0, - precision_2=1e4, - mean_1=jnp.inf, - mean_2=0.5, - volatility_coupling_1=1.0, - )[7].eval() - - assert jnp.isclose(tonic_volatility_2, 10.866466) - - def test_pymc_sampling(self): - """Test the pytensor hgf_logp op.""" - - ############## - # Continuous # - ############## - - # Create the data (value and time vectors) - timeseries = load_data("continuous") - - hgf_logp_op = HGFDistribution( - n_levels=2, - input_data=timeseries[np.newaxis, :], - response_function=first_level_gaussian_surprise, - ) - - with pm.Model() as model: - tonic_volatility_2 = pm.Uniform("tonic_volatility_2", -10, 0) - - pm.Potential( - "hhgf_loglike", - hgf_logp_op( - tonic_volatility_2=tonic_volatility_2, - input_precision=np.array(1e4), - ), - ) - - initial_point = model.initial_point() - - pointslogs = model.point_logps(initial_point) - assert pointslogs["tonic_volatility_2"] == -1.39 - assert pointslogs["hhgf_loglike"] == 1468.28 + initial_point = model.initial_point() - with model: - idata = pm.sample(chains=2, cores=1, tune=1000) + pointslogs = model.point_logps(initial_point) + assert pointslogs["tonic_volatility_2"] == -1.39 + assert pointslogs["hhgf_loglike"] == 1468.28 - assert -9.5 > az.summary(idata)["mean"].values[0] > -10.5 - assert az.summary(idata)["r_hat"].values[0] <= 1.02 + with model: + idata = pm.sample(chains=2, cores=1, tune=1000) - ########## - # Binary # - ########## + assert -9.5 > az.summary(idata)["mean"].values[0] > -10.5 + assert az.summary(idata)["r_hat"].values[0] <= 1.02 - # Create the data (value and time vectors) - u, y = load_data("binary") - - hgf_logp_op = HGFDistribution( - n_levels=2, - model_type="binary", - input_data=u[np.newaxis, :], - response_function=binary_softmax_inverse_temperature, - response_function_inputs=y[np.newaxis, :], - ) + ########## + # Binary # + ########## - def logp(value, tonic_volatility_2): - return hgf_logp_op(tonic_volatility_2=tonic_volatility_2) + # Create the data (value and time vectors) + u, y = load_data("binary") - with pm.Model() as model: - y_data = pm.Data("y_data", y) - tonic_volatility_2 = pm.Normal("tonic_volatility_2", -11.0, 2) - pm.CustomDist("likelihood", tonic_volatility_2, logp=logp, observed=y_data) + hgf_logp_op = HGFDistribution( + n_levels=2, + model_type="binary", + input_data=u[np.newaxis, :], + response_function=binary_softmax_inverse_temperature, + response_function_inputs=y[np.newaxis, :], + ) - initial_point = model.initial_point() + def logp(value, tonic_volatility_2): + return hgf_logp_op(tonic_volatility_2=tonic_volatility_2) - pointslogs = model.point_logps(initial_point) - assert pointslogs["tonic_volatility_2"] == -1.61 - assert pointslogs["likelihood"] == -212.59 + with pm.Model() as model: + y_data = pm.Data("y_data", y) + tonic_volatility_2 = pm.Normal("tonic_volatility_2", -11.0, 2) + pm.CustomDist("likelihood", tonic_volatility_2, logp=logp, observed=y_data) - with model: - idata = pm.sample(chains=2, cores=1, tune=1000) + initial_point = model.initial_point() - assert -2 < round(az.summary(idata)["mean"].values[0]) < 0 + pointslogs = model.point_logps(initial_point) + assert pointslogs["tonic_volatility_2"] == -1.61 + assert pointslogs["likelihood"] == -212.59 + with model: + idata = pm.sample(chains=2, cores=1, tune=1000) -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) + assert -2 < round(az.summary(idata)["mean"].values[0]) < 0 diff --git a/tests/test_math.py b/tests/test_math.py index bae221af..f25e4a99 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,8 +1,4 @@ # Author: Nicolas Legrand - -import unittest -from unittest import TestCase - import jax.numpy as jnp from pyhgf.math import ( @@ -14,51 +10,48 @@ ) -class TestMath(TestCase): - def test_multivariate_normal(self): +def test_multivariate_normal(): + + ss = MultivariateNormal.sufficient_statistics(jnp.array([1.0, 2.0])) + assert jnp.isclose(ss, jnp.array([1.0, 2.0, 1.0, 2.0, 4.0], dtype="float32")).all() + + bm = MultivariateNormal.base_measure(2) + assert bm == 0.15915494309189535 - ss = MultivariateNormal.sufficient_statistics(jnp.array([1.0, 2.0])) - assert jnp.isclose( - ss, jnp.array([1.0, 2.0, 1.0, 2.0, 4.0], dtype="float32") - ).all() - bm = MultivariateNormal.base_measure(2) - assert bm == 0.15915494309189535 +def test_normal(): - def test_normal(self): + ss = Normal.sufficient_statistics(jnp.array(1.0)) + assert jnp.isclose(ss, jnp.array([1.0, 1.0], dtype="float32")).all() - ss = Normal.sufficient_statistics(jnp.array(1.0)) - assert jnp.isclose(ss, jnp.array([1.0, 1.0], dtype="float32")).all() + bm = Normal.base_measure() + assert bm == 0.3989423 - bm = Normal.base_measure() - assert bm == 0.3989423 + ess = Normal.expected_sufficient_statistics(mu=0.0, sigma=1.0) + assert jnp.isclose(ess, jnp.array([0.0, 1.0], dtype="float32")).all() - ess = Normal.expected_sufficient_statistics(mu=0.0, sigma=1.0) - assert jnp.isclose(ess, jnp.array([0.0, 1.0], dtype="float32")).all() + par = Normal.parameters(xis=[5.0, 29.0]) + assert jnp.isclose(jnp.array(par), jnp.array([5.0, 4.0], dtype="float32")).all() - par = Normal.parameters(xis=[5.0, 29.0]) - assert jnp.isclose(jnp.array(par), jnp.array([5.0, 4.0], dtype="float32")).all() - def test_gaussian_predictive_distribution(self): +def test_gaussian_predictive_distribution(): - pdf = gaussian_predictive_distribution(x=1.5, xi=[0.0, 1 / 8], nu=5.0) - assert jnp.isclose(pdf, jnp.array(0.00845728, dtype="float32")) + pdf = gaussian_predictive_distribution(x=1.5, xi=[0.0, 1 / 8], nu=5.0) + assert jnp.isclose(pdf, jnp.array(0.00845728, dtype="float32")) - def test_binary_surprise_finite_precision(self): - surprise = binary_surprise_finite_precision( - value=1.0, - expected_mean=0.0, - expected_precision=1.0, - eta0=0.0, - eta1=1.0, - ) - assert surprise == 1.4189385 +def test_binary_surprise_finite_precision(): - def test_sigmoid_inverse_temperature(self): - s = sigmoid_inverse_temperature(x=0.4, temperature=6.0) - assert jnp.isclose(s, jnp.array(0.08070617906683485, dtype="float32")) + surprise = binary_surprise_finite_precision( + value=1.0, + expected_mean=0.0, + expected_precision=1.0, + eta0=0.0, + eta1=1.0, + ) + assert surprise == 1.4189385 -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) +def test_sigmoid_inverse_temperature(): + s = sigmoid_inverse_temperature(x=0.4, temperature=6.0) + assert jnp.isclose(s, jnp.array(0.08070617906683485, dtype="float32")) diff --git a/tests/test_model.py b/tests/test_model.py index 87ddb0ee..ee50f758 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,8 +1,5 @@ # Author: Nicolas Legrand -import unittest -from unittest import TestCase - import jax.numpy as jnp import numpy as np @@ -11,152 +8,145 @@ from pyhgf.response import total_gaussian_surprise -class Testmodel(TestCase): - def test_HGF(self): - """Test the model class""" - - ##################### - # Creating networks # - ##################### - - custom_hgf = ( - HGF(model_type=None) - .add_nodes(kind="continuous-input") - .add_nodes(kind="binary-input") - .add_nodes(value_children=0) - .add_nodes( - kind="binary-state", - value_children=1, - ) - .add_nodes(value_children=[2, 3]) - .add_nodes(value_children=4) - .add_nodes(volatility_children=[2, 3]) - .add_nodes(volatility_children=2) - .add_nodes(volatility_children=7) - ) - - custom_hgf.cache_belief_propagation_fn() - custom_hgf.create_belief_propagation_fn(overwrite=False) - custom_hgf.create_belief_propagation_fn(overwrite=True) - - custom_hgf.input_data(input_data=np.array([0.2, 1])) - - ############## - # Continuous # - ############## - timeserie = load_data("continuous") - - # two-level - # --------- - two_level_continuous_hgf = HGF( - n_levels=2, - model_type="continuous", - initial_mean={"1": timeserie[0], "2": 0.0}, - initial_precision={"1": 1e4, "2": 1e1}, - tonic_volatility={"1": -3.0, "2": -3.0}, - tonic_drift={"1": 0.0, "2": 0.0}, - volatility_coupling={"1": 1.0}, - ) - - two_level_continuous_hgf.input_data(input_data=timeserie) - - surprise = ( - two_level_continuous_hgf.surprise() - ) # Sum the surprise for this model - assert jnp.isclose(surprise.sum(), -1141.0911) - assert len(two_level_continuous_hgf.node_trajectories[1]["mean"]) == 614 - - # three-level - # ----------- - three_level_continuous_hgf = HGF( - n_levels=3, - model_type="continuous", - initial_mean={"1": 1.04, "2": 1.0, "3": 1.0}, - initial_precision={"1": 1e4, "2": 1e1, "3": 1e1}, - tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0}, - tonic_drift={"1": 0.0, "2": 0.0, "3": 0.0}, - volatility_coupling={"1": 1.0, "2": 1.0}, - ) - three_level_continuous_hgf.input_data(input_data=timeserie) - surprise = three_level_continuous_hgf.surprise() - assert jnp.isclose(surprise.sum(), -892.82227) - - # test an alternative response function - sp = total_gaussian_surprise(three_level_continuous_hgf) - assert jnp.isclose(sp.sum(), 1159.1089) - - ########## - # Binary # - ########## - u, _ = load_data("binary") - - # two-level - # --------- - two_level_binary_hgf = HGF( - n_levels=2, - model_type="binary", - initial_mean={"1": 0.0, "2": 0.5}, - initial_precision={"1": 0.0, "2": 1e4}, - tonic_volatility={"1": None, "2": -6.0}, - tonic_drift={"1": None, "2": 0.0}, - volatility_coupling={"1": None}, - eta0=0.0, - eta1=1.0, - binary_precision=jnp.inf, - ) - - # Provide new observations - two_level_binary_hgf = two_level_binary_hgf.input_data(u) - surprise = two_level_binary_hgf.surprise() - assert jnp.isclose(surprise.sum(), 215.58821) - - # three-level - # ----------- - three_level_binary_hgf = HGF( - n_levels=3, - model_type="binary", - initial_mean={"1": 0.0, "2": 0.5, "3": 0.0}, - initial_precision={"1": 0.0, "2": 1e4, "3": 1e1}, - tonic_volatility={"1": None, "2": -6.0, "3": -2.0}, - tonic_drift={"1": None, "2": 0.0, "3": 0.0}, - volatility_coupling={"1": None, "2": 1.0}, - eta0=0.0, - eta1=1.0, - binary_precision=jnp.inf, - ) - three_level_binary_hgf.input_data(input_data=u) - surprise = three_level_binary_hgf.surprise() - assert jnp.isclose(surprise.sum(), 215.59067) - - ############################ - # dynamic update sequences # - ############################ - - three_level_binary_hgf = HGF( - n_levels=3, - model_type="binary", - initial_mean={"1": 0.0, "2": 0.5, "3": 0.0}, - initial_precision={"1": 0.0, "2": 1e4, "3": 1e1}, - tonic_volatility={"1": None, "2": -6.0, "3": -2.0}, - tonic_drift={"1": None, "2": 0.0, "3": 0.0}, - volatility_coupling={"1": None, "2": 1.0}, - eta0=0.0, - eta1=1.0, - binary_precision=jnp.inf, - ) +def test_HGF(): + """Test the model class""" - # create a custom update series - update_sequence1 = three_level_binary_hgf.update_sequence - update_sequence2 = update_sequence1[:2] - update_branches = (update_sequence1, update_sequence2) - branches_idx = np.random.binomial(n=1, p=0.5, size=len(u)) + ##################### + # Creating networks # + ##################### - three_level_binary_hgf.input_custom_sequence( - update_branches=update_branches, - branches_idx=branches_idx, - input_data=u, + custom_hgf = ( + HGF(model_type=None) + .add_nodes(kind="continuous-input") + .add_nodes(kind="binary-input") + .add_nodes(value_children=0) + .add_nodes( + kind="binary-state", + value_children=1, ) - - -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) + .add_nodes(value_children=[2, 3]) + .add_nodes(value_children=4) + .add_nodes(volatility_children=[2, 3]) + .add_nodes(volatility_children=2) + .add_nodes(volatility_children=7) + ) + + custom_hgf.cache_belief_propagation_fn() + custom_hgf.create_belief_propagation_fn(overwrite=False) + custom_hgf.create_belief_propagation_fn(overwrite=True) + + custom_hgf.input_data(input_data=np.array([0.2, 1])) + + ############## + # Continuous # + ############## + timeserie = load_data("continuous") + + # two-level + # --------- + two_level_continuous_hgf = HGF( + n_levels=2, + model_type="continuous", + initial_mean={"1": timeserie[0], "2": 0.0}, + initial_precision={"1": 1e4, "2": 1e1}, + tonic_volatility={"1": -3.0, "2": -3.0}, + tonic_drift={"1": 0.0, "2": 0.0}, + volatility_coupling={"1": 1.0}, + ) + + two_level_continuous_hgf.input_data(input_data=timeserie) + + surprise = two_level_continuous_hgf.surprise() # Sum the surprise for this model + assert jnp.isclose(surprise.sum(), -1141.0911) + assert len(two_level_continuous_hgf.node_trajectories[1]["mean"]) == 614 + + # three-level + # ----------- + three_level_continuous_hgf = HGF( + n_levels=3, + model_type="continuous", + initial_mean={"1": 1.04, "2": 1.0, "3": 1.0}, + initial_precision={"1": 1e4, "2": 1e1, "3": 1e1}, + tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0}, + tonic_drift={"1": 0.0, "2": 0.0, "3": 0.0}, + volatility_coupling={"1": 1.0, "2": 1.0}, + ) + three_level_continuous_hgf.input_data(input_data=timeserie) + surprise = three_level_continuous_hgf.surprise() + assert jnp.isclose(surprise.sum(), -892.82227) + + # test an alternative response function + sp = total_gaussian_surprise(three_level_continuous_hgf) + assert jnp.isclose(sp.sum(), 1159.1089) + + ########## + # Binary # + ########## + u, _ = load_data("binary") + + # two-level + # --------- + two_level_binary_hgf = HGF( + n_levels=2, + model_type="binary", + initial_mean={"1": 0.0, "2": 0.5}, + initial_precision={"1": 0.0, "2": 1e4}, + tonic_volatility={"1": None, "2": -6.0}, + tonic_drift={"1": None, "2": 0.0}, + volatility_coupling={"1": None}, + eta0=0.0, + eta1=1.0, + binary_precision=jnp.inf, + ) + + # Provide new observations + two_level_binary_hgf = two_level_binary_hgf.input_data(u) + surprise = two_level_binary_hgf.surprise() + assert jnp.isclose(surprise.sum(), 215.58821) + + # three-level + # ----------- + three_level_binary_hgf = HGF( + n_levels=3, + model_type="binary", + initial_mean={"1": 0.0, "2": 0.5, "3": 0.0}, + initial_precision={"1": 0.0, "2": 1e4, "3": 1e1}, + tonic_volatility={"1": None, "2": -6.0, "3": -2.0}, + tonic_drift={"1": None, "2": 0.0, "3": 0.0}, + volatility_coupling={"1": None, "2": 1.0}, + eta0=0.0, + eta1=1.0, + binary_precision=jnp.inf, + ) + three_level_binary_hgf.input_data(input_data=u) + surprise = three_level_binary_hgf.surprise() + assert jnp.isclose(surprise.sum(), 215.59067) + + ############################ + # dynamic update sequences # + ############################ + + three_level_binary_hgf = HGF( + n_levels=3, + model_type="binary", + initial_mean={"1": 0.0, "2": 0.5, "3": 0.0}, + initial_precision={"1": 0.0, "2": 1e4, "3": 1e1}, + tonic_volatility={"1": None, "2": -6.0, "3": -2.0}, + tonic_drift={"1": None, "2": 0.0, "3": 0.0}, + volatility_coupling={"1": None, "2": 1.0}, + eta0=0.0, + eta1=1.0, + binary_precision=jnp.inf, + ) + + # create a custom update series + update_sequence1 = three_level_binary_hgf.update_sequence + update_sequence2 = update_sequence1[:2] + update_branches = (update_sequence1, update_sequence2) + branches_idx = np.random.binomial(n=1, p=0.5, size=len(u)) + + three_level_binary_hgf.input_custom_sequence( + update_branches=update_branches, + branches_idx=branches_idx, + input_data=u, + ) diff --git a/tests/test_plots.py b/tests/test_plots.py index 4465fe3d..24956207 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,8 +1,5 @@ # Author: Nicolas Legrand -import unittest -from unittest import TestCase - import jax.numpy as jnp import numpy as np @@ -10,151 +7,146 @@ from pyhgf.model import HGF -class Testplots(TestCase): - def test_plotting_functions(self): - # Read USD-CHF data - timeserie = load_data("continuous") - - ############## - # Continuous # - # ------------ - - # Set up standard 2-level HGF for continuous inputs - two_level_continuous = HGF( - n_levels=2, - model_type="continuous", - initial_mean={"1": 1.04, "2": 1.0}, - initial_precision={"1": 1e4, "2": 1e1}, - tonic_volatility={"1": -13.0, "2": -2.0}, - tonic_drift={"1": 0.0, "2": 0.0}, - volatility_coupling={"1": 1.0}, - ).input_data(input_data=timeserie) - - # plot trajectories - two_level_continuous.plot_trajectories() - - # plot correlations - two_level_continuous.plot_correlations() - - # plot node structures - two_level_continuous.plot_network() - - # plot nodes - two_level_continuous.plot_nodes( - node_idxs=2, show_current_state=True, show_observations=True - ) - - # Set up standard 3-level HGF for continuous inputs - three_level_continuous = HGF( - n_levels=3, - model_type="continuous", - initial_mean={"1": 1.04, "2": 1.0, "3": 1.0}, - initial_precision={"1": 1e4, "2": 1e1, "3": 1e1}, - tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0}, - tonic_drift={"1": 0.0, "2": 0.0, "3": 0.0}, - volatility_coupling={"1": 1.0, "2": 1.0}, - ).input_data(input_data=timeserie) - - # plot trajectories - three_level_continuous.plot_trajectories() - - # plot correlations - three_level_continuous.plot_correlations() - - # plot node structures - three_level_continuous.plot_network() - - # plot nodes - three_level_continuous.plot_nodes( - node_idxs=2, show_current_state=True, show_observations=True - ) - - ########## - # Binary # - # -------- - - # Read binary input - u, _ = load_data("binary") - - two_level_binary_hgf = HGF( - n_levels=2, - model_type="binary", - initial_mean={"1": 0.0, "2": 0.5}, - initial_precision={"1": 0.0, "2": 1e4}, - tonic_volatility={"1": None, "2": -6.0}, - tonic_drift={"1": None, "2": 0.0}, - volatility_coupling={"1": None}, - eta0=0.0, - eta1=1.0, - binary_precision=jnp.inf, - ).input_data(u) - - # plot trajectories - two_level_binary_hgf.plot_trajectories() - - # plot correlations - two_level_binary_hgf.plot_correlations() - - # plot node structures - two_level_binary_hgf.plot_network() - - # plot node structures - two_level_binary_hgf.plot_nodes( - node_idxs=2, show_current_state=True, show_observations=True - ) - - three_level_binary_hgf = HGF( - n_levels=3, - model_type="binary", - initial_mean={"1": 0.0, "2": 0.5, "3": 0.0}, - initial_precision={"1": 0.0, "2": 1e4, "3": 1e1}, - tonic_volatility={"1": None, "2": -6.0, "3": -2.0}, - tonic_drift={"1": None, "2": 0.0, "3": 0.0}, - volatility_coupling={"1": None, "2": 1.0}, - eta0=0.0, - eta1=1.0, - binary_precision=jnp.inf, - ).input_data(u) - - # plot trajectories - three_level_binary_hgf.plot_trajectories() - - # plot correlations - three_level_binary_hgf.plot_correlations() - - # plot node structures - three_level_binary_hgf.plot_network() - - # plot node structures - three_level_binary_hgf.plot_nodes( - node_idxs=2, show_current_state=True, show_observations=True - ) - - ############# - # Categorical - # ----------- - - # generate some categorical inputs data - input_data = np.array( - [np.random.multinomial(n=1, pvals=[0.1, 0.2, 0.7]) for _ in range(3)] - ).T - input_data = np.vstack([[0.0] * input_data.shape[1], input_data]) - - # create the categorical HGF - categorical_hgf = HGF(model_type=None, verbose=False).add_nodes( - kind="categorical-input", - node_parameters={ - "n_categories": 3, - "binary_parameters": {"tonic_volatility_2": -2.0}, - }, - ) - - # fitting the model forwards - categorical_hgf.input_data(input_data=input_data.T) - - # plot node structures - categorical_hgf.plot_network() - - -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) +def test_plotting_functions(): + # Read USD-CHF data + timeserie = load_data("continuous") + + ############## + # Continuous # + # ------------ + + # Set up standard 2-level HGF for continuous inputs + two_level_continuous = HGF( + n_levels=2, + model_type="continuous", + initial_mean={"1": 1.04, "2": 1.0}, + initial_precision={"1": 1e4, "2": 1e1}, + tonic_volatility={"1": -13.0, "2": -2.0}, + tonic_drift={"1": 0.0, "2": 0.0}, + volatility_coupling={"1": 1.0}, + ).input_data(input_data=timeserie) + + # plot trajectories + two_level_continuous.plot_trajectories() + + # plot correlations + two_level_continuous.plot_correlations() + + # plot node structures + two_level_continuous.plot_network() + + # plot nodes + two_level_continuous.plot_nodes( + node_idxs=2, show_current_state=True, show_observations=True + ) + + # Set up standard 3-level HGF for continuous inputs + three_level_continuous = HGF( + n_levels=3, + model_type="continuous", + initial_mean={"1": 1.04, "2": 1.0, "3": 1.0}, + initial_precision={"1": 1e4, "2": 1e1, "3": 1e1}, + tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0}, + tonic_drift={"1": 0.0, "2": 0.0, "3": 0.0}, + volatility_coupling={"1": 1.0, "2": 1.0}, + ).input_data(input_data=timeserie) + + # plot trajectories + three_level_continuous.plot_trajectories() + + # plot correlations + three_level_continuous.plot_correlations() + + # plot node structures + three_level_continuous.plot_network() + + # plot nodes + three_level_continuous.plot_nodes( + node_idxs=2, show_current_state=True, show_observations=True + ) + + ########## + # Binary # + # -------- + + # Read binary input + u, _ = load_data("binary") + + two_level_binary_hgf = HGF( + n_levels=2, + model_type="binary", + initial_mean={"1": 0.0, "2": 0.5}, + initial_precision={"1": 0.0, "2": 1e4}, + tonic_volatility={"1": None, "2": -6.0}, + tonic_drift={"1": None, "2": 0.0}, + volatility_coupling={"1": None}, + eta0=0.0, + eta1=1.0, + binary_precision=jnp.inf, + ).input_data(u) + + # plot trajectories + two_level_binary_hgf.plot_trajectories() + + # plot correlations + two_level_binary_hgf.plot_correlations() + + # plot node structures + two_level_binary_hgf.plot_network() + + # plot node structures + two_level_binary_hgf.plot_nodes( + node_idxs=2, show_current_state=True, show_observations=True + ) + + three_level_binary_hgf = HGF( + n_levels=3, + model_type="binary", + initial_mean={"1": 0.0, "2": 0.5, "3": 0.0}, + initial_precision={"1": 0.0, "2": 1e4, "3": 1e1}, + tonic_volatility={"1": None, "2": -6.0, "3": -2.0}, + tonic_drift={"1": None, "2": 0.0, "3": 0.0}, + volatility_coupling={"1": None, "2": 1.0}, + eta0=0.0, + eta1=1.0, + binary_precision=jnp.inf, + ).input_data(u) + + # plot trajectories + three_level_binary_hgf.plot_trajectories() + + # plot correlations + three_level_binary_hgf.plot_correlations() + + # plot node structures + three_level_binary_hgf.plot_network() + + # plot node structures + three_level_binary_hgf.plot_nodes( + node_idxs=2, show_current_state=True, show_observations=True + ) + + ############# + # Categorical + # ----------- + + # generate some categorical inputs data + input_data = np.array( + [np.random.multinomial(n=1, pvals=[0.1, 0.2, 0.7]) for _ in range(3)] + ).T + input_data = np.vstack([[0.0] * input_data.shape[1], input_data]) + + # create the categorical HGF + categorical_hgf = HGF(model_type=None, verbose=False).add_nodes( + kind="categorical-input", + node_parameters={ + "n_categories": 3, + "binary_parameters": {"tonic_volatility_2": -2.0}, + }, + ) + + # fitting the model forwards + categorical_hgf.input_data(input_data=input_data.T) + + # plot node structures + categorical_hgf.plot_network() diff --git a/tests/test_responses.py b/tests/test_responses.py index ceb07ae2..bfa8c807 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -1,8 +1,5 @@ # Author: Nicolas Legrand -import unittest -from unittest import TestCase - import numpy as np from pyhgf import load_data @@ -10,36 +7,31 @@ from pyhgf.response import binary_softmax, binary_softmax_inverse_temperature -class TestResponses(TestCase): - def test_binary_responses(self): - u, y = load_data("binary") - - # two-level binary HGF - # -------------------- - two_level_binary_hgf = HGF( - n_levels=2, - model_type="binary", - initial_mean={"1": 0.5, "2": 0.0}, - initial_precision={"1": 0.0, "2": 1.0}, - tonic_volatility={"2": -6.0}, - ).input_data(input_data=u) - - # binary sofmax - # ------------- - surprise = two_level_binary_hgf.surprise( - response_function=binary_softmax, response_function_inputs=y - ) - assert np.isclose(surprise.sum(), 195.81573) - - # binary sofmax with inverse temperature - # -------------------------------------- - surprise = two_level_binary_hgf.surprise( - response_function=binary_softmax_inverse_temperature, - response_function_inputs=y, - response_function_parameters=2.0, - ) - assert np.isclose(surprise.sum(), 188.77818) - - -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) +def test_binary_responses(): + u, y = load_data("binary") + + # two-level binary HGF + # -------------------- + two_level_binary_hgf = HGF( + n_levels=2, + model_type="binary", + initial_mean={"1": 0.5, "2": 0.0}, + initial_precision={"1": 0.0, "2": 1.0}, + tonic_volatility={"2": -6.0}, + ).input_data(input_data=u) + + # binary sofmax + # ------------- + surprise = two_level_binary_hgf.surprise( + response_function=binary_softmax, response_function_inputs=y + ) + assert np.isclose(surprise.sum(), 195.81573) + + # binary sofmax with inverse temperature + # -------------------------------------- + surprise = two_level_binary_hgf.surprise( + response_function=binary_softmax_inverse_temperature, + response_function_inputs=y, + response_function_parameters=2.0, + ) + assert np.isclose(surprise.sum(), 188.77818) diff --git a/tests/test_updates/prediction_errors/inputs/test_prediction_errors.py b/tests/test_updates/prediction_errors/inputs/test_prediction_errors.py index f8d4f63c..b3c28564 100644 --- a/tests/test_updates/prediction_errors/inputs/test_prediction_errors.py +++ b/tests/test_updates/prediction_errors/inputs/test_prediction_errors.py @@ -1,32 +1,24 @@ # Author: Nicolas Legrand -import unittest -from unittest import TestCase - from pyhgf.model import Network from pyhgf.updates.prediction_error.inputs.generic import generic_input_prediction_error -class TestPredictionErrors(TestCase): - def test_generic_input(self): - """Test the generic input nodes""" - - ############################################### - # one value parent with one volatility parent # - ############################################### - network = Network().add_nodes(kind="generic-input").add_nodes(value_children=0) - - attributes, (_, edges), _ = network.get_network() +def test_generic_input(): + """Test the generic input nodes""" - attributes = generic_input_prediction_error( - attributes=attributes, - time_step=1.0, - edges=edges, - node_idx=0, - value=10.0, - observed=True, - ) + ############################################### + # one value parent with one volatility parent # + ############################################### + network = Network().add_nodes(kind="generic-input").add_nodes(value_children=0) + attributes, (_, edges), _ = network.get_network() -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) + attributes = generic_input_prediction_error( + attributes=attributes, + time_step=1.0, + edges=edges, + node_idx=0, + value=10.0, + observed=True, + ) diff --git a/tests/test_updates/prediction_errors/nodes/test_dirichlet.py b/tests/test_updates/prediction_errors/nodes/test_dirichlet.py index c69c73c3..d69b7d8c 100644 --- a/tests/test_updates/prediction_errors/nodes/test_dirichlet.py +++ b/tests/test_updates/prediction_errors/nodes/test_dirichlet.py @@ -1,8 +1,5 @@ # Author: Nicolas Legrand -import unittest -from unittest import TestCase - import jax.numpy as jnp import numpy as np @@ -13,49 +10,45 @@ ) -class TestDirichletNode(TestCase): - def test_get_candidate(self): - mean, precision = get_candidate( - value=5.0, - sensory_precision=1.0, - expected_mean=jnp.array([0.0, -5.0]), - expected_sigma=jnp.array([1.0, 3.0]), - ) +def test_get_candidate(): + mean, precision = get_candidate( + value=5.0, + sensory_precision=1.0, + expected_mean=jnp.array([0.0, -5.0]), + expected_sigma=jnp.array([1.0, 3.0]), + ) - assert jnp.isclose(mean, 5.026636) - assert jnp.isclose(precision, 1.2752448) - - def test_dirichlet_node_prediction_error(self): - - network = ( - Network() - .add_nodes(kind="generic-input") - .add_nodes(kind="DP-state", value_children=0, batch_size=2) - .add_nodes( - kind="ef-normal", - n_nodes=2, - value_children=1, - xis=jnp.array([0.0, 1 / 8]), - nus=15.0, - ) - ) + assert jnp.isclose(mean, 5.026636) + assert jnp.isclose(precision, 1.2752448) - attributes, (_, edges), _ = network.get_network() - dirichlet_node_prediction_error( - edges=edges, - attributes=attributes, - node_idx=1, - ) - # test the plotting function - network.plot_network() +def test_dirichlet_node_prediction_error(): + + network = ( + Network() + .add_nodes(kind="generic-input") + .add_nodes(kind="DP-state", value_children=0, batch_size=2) + .add_nodes( + kind="ef-normal", + n_nodes=2, + value_children=1, + xis=jnp.array([0.0, 1 / 8]), + nus=15.0, + ) + ) - # add observations - network.input_data(input_data=np.random.normal(0, 1, 5)) + attributes, (_, edges), _ = network.get_network() + dirichlet_node_prediction_error( + edges=edges, + attributes=attributes, + node_idx=1, + ) - # export to pandas - network.to_pandas() + # test the plotting function + network.plot_network() + # add observations + network.input_data(input_data=np.random.normal(0, 1, 5)) -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) + # export to pandas + network.to_pandas() diff --git a/tests/test_utils.py b/tests/test_utils.py index 602d0d7e..e88fa785 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,5 @@ # Author: Nicolas Legrand -import unittest -from unittest import TestCase - import jax.numpy as jnp from pytest import raises @@ -19,178 +16,176 @@ from pyhgf.utils import beliefs_propagation, list_branches -class TestUtils(TestCase): - - def test_imports(self): - """Test the data import function""" - _ = load_data("continuous") - _, _ = load_data("binary") - - with raises(Exception): - load_data("error") - - def test_beliefs_propagation(self): - """Test the loop_inputs function""" - - ############################################### - # one value parent with one volatility parent # - ############################################### - input_node_parameters = { - "input_precision": 1e4, - "expected_precision": jnp.nan, - "surprise": 0.0, - "time_step": 0.0, - "values": 0.0, - "observed": 1, - "volatility_coupling_parents": None, - "value_coupling_parents": (1.0,), - "temp": { - "effective_precision": 1.0, - "value_prediction_error": 0.0, - "volatility_prediction_error": 0.0, - }, - } - node_parameters_1 = { - "expected_precision": 1.0, - "precision": 1.0, - "expected_mean": 1.0, - "value_coupling_children": (1.0,), - "value_coupling_parents": None, - "volatility_coupling_parents": (1.0,), - "volatility_coupling_children": None, - "mean": 1.0, - "observed": 1, - "tonic_volatility": -3.0, - "tonic_drift": 0.0, - "temp": { - "effective_precision": 1.0, - "value_prediction_error": 0.0, - "volatility_prediction_error": 0.0, - }, - } - node_parameters_2 = { - "expected_precision": 1.0, - "precision": 1.0, - "expected_mean": 1.0, - "value_coupling_children": None, - "value_coupling_parents": None, - "volatility_coupling_parents": None, - "volatility_coupling_children": (1.0,), - "mean": 1.0, - "observed": 1, - "tonic_volatility": -3.0, - "tonic_drift": 0.0, - "temp": { - "effective_precision": 1.0, - "value_prediction_error": 0.0, - "volatility_prediction_error": 0.0, - }, - } - edges = ( - AdjacencyLists(0, (1,), None, None, None, (None,)), - AdjacencyLists(2, None, (2,), (0,), None, (None,)), - AdjacencyLists(2, None, None, None, (1,), (None,)), - ) - attributes = ( - input_node_parameters, - node_parameters_1, - node_parameters_2, - ) - - # create update sequence - sequence1 = 0, continuous_input_prediction_error - sequence2 = 1, continuous_node_update - sequence3 = 2, continuous_node_update_ehgf - update_sequence = (sequence1, sequence2, sequence3) - - # one batch of new observations with time step - data = jnp.array([0.2]) - time_steps = jnp.ones(1) - observed = jnp.ones(1) - inputs = Inputs(0, 1) - - # apply sequence - new_attributes, _ = beliefs_propagation( - attributes=attributes, - input_data=(data, time_steps, observed), - update_sequence=update_sequence, - structure=(inputs, edges), - ) - - assert new_attributes[1]["mean"] == 0.20008 - assert new_attributes[2]["precision"] == 1.5 - - def test_add_edges(self): - """Test the add_edges function.""" - network = Network().add_nodes(kind="continuous-input").add_nodes(n_nodes=3) - with raises(Exception): - network.add_edges(kind="error") - - network.add_edges( - kind="volatility", parent_idxs=2, children_idxs=0, coupling_strengths=1 +def test_imports(): + """Test the data import function""" + _ = load_data("continuous") + _, _ = load_data("binary") + + with raises(Exception): + load_data("error") + + +def test_beliefs_propagation(): + """Test the loop_inputs function""" + + ############################################### + # one value parent with one volatility parent # + ############################################### + input_node_parameters = { + "input_precision": 1e4, + "expected_precision": jnp.nan, + "surprise": 0.0, + "time_step": 0.0, + "values": 0.0, + "observed": 1, + "volatility_coupling_parents": None, + "value_coupling_parents": (1.0,), + "temp": { + "effective_precision": 1.0, + "value_prediction_error": 0.0, + "volatility_prediction_error": 0.0, + }, + } + node_parameters_1 = { + "expected_precision": 1.0, + "precision": 1.0, + "expected_mean": 1.0, + "value_coupling_children": (1.0,), + "value_coupling_parents": None, + "volatility_coupling_parents": (1.0,), + "volatility_coupling_children": None, + "mean": 1.0, + "observed": 1, + "tonic_volatility": -3.0, + "tonic_drift": 0.0, + "temp": { + "effective_precision": 1.0, + "value_prediction_error": 0.0, + "volatility_prediction_error": 0.0, + }, + } + node_parameters_2 = { + "expected_precision": 1.0, + "precision": 1.0, + "expected_mean": 1.0, + "value_coupling_children": None, + "value_coupling_parents": None, + "volatility_coupling_parents": None, + "volatility_coupling_children": (1.0,), + "mean": 1.0, + "observed": 1, + "tonic_volatility": -3.0, + "tonic_drift": 0.0, + "temp": { + "effective_precision": 1.0, + "value_prediction_error": 0.0, + "volatility_prediction_error": 0.0, + }, + } + edges = ( + AdjacencyLists(0, (1,), None, None, None, (None,)), + AdjacencyLists(2, None, (2,), (0,), None, (None,)), + AdjacencyLists(2, None, None, None, (1,), (None,)), + ) + attributes = ( + input_node_parameters, + node_parameters_1, + node_parameters_2, + ) + + # create update sequence + sequence1 = 0, continuous_input_prediction_error + sequence2 = 1, continuous_node_update + sequence3 = 2, continuous_node_update_ehgf + update_sequence = (sequence1, sequence2, sequence3) + + # one batch of new observations with time step + data = jnp.array([0.2]) + time_steps = jnp.ones(1) + observed = jnp.ones(1) + inputs = Inputs(0, 1) + + # apply sequence + new_attributes, _ = beliefs_propagation( + attributes=attributes, + input_data=(data, time_steps, observed), + update_sequence=update_sequence, + structure=(inputs, edges), + ) + + assert new_attributes[1]["mean"] == 0.20008 + assert new_attributes[2]["precision"] == 1.5 + + +def test_add_edges(): + """Test the add_edges function.""" + network = Network().add_nodes(kind="continuous-input").add_nodes(n_nodes=3) + with raises(Exception): + network.add_edges(kind="error") + + network.add_edges( + kind="volatility", parent_idxs=2, children_idxs=0, coupling_strengths=1 + ) + network.add_edges(parent_idxs=1, children_idxs=0, coupling_strengths=1.0) + + +def test_find_branch(): + """Test the find_branch function.""" + edges = ( + AdjacencyLists(0, (1,), None, None, None, (None,)), + AdjacencyLists(2, None, (2,), (0,), None, (None,)), + AdjacencyLists(2, None, None, None, (1,), (None,)), + AdjacencyLists(2, (4,), None, None, None, (None,)), + AdjacencyLists(2, None, None, (3,), None, (None,)), + ) + branch_list = list_branches([0], edges, branch_list=[]) + assert branch_list == [0, 1, 2] + + +def test_set_update_sequence(): + """Test the set_update_sequence function.""" + + # a standard binary HGF + network1 = ( + Network() + .add_nodes(kind="binary-input") + .add_nodes(kind="binary-state", value_children=0) + .add_nodes(value_children=1) + .set_update_sequence() + ) + assert len(network1.update_sequence) == 6 + + # a standard continuous HGF + network2 = ( + Network() + .add_nodes(kind="continuous-input") + .add_nodes(value_children=0) + .add_nodes(volatility_children=1) + .set_update_sequence(update_type="standard") + ) + assert len(network2.update_sequence) == 6 + + # a generic input with a normal-EF node + network3 = ( + Network() + .add_nodes(kind="generic-input") + .add_nodes(kind="ef-normal") + .set_update_sequence() + ) + assert len(network3.update_sequence) == 2 + + # a Dirichlet node + network4 = ( + Network() + .add_nodes(kind="generic-input") + .add_nodes(kind="DP-state", value_children=0, alpha=0.1, batch_size=2) + .add_nodes( + kind="ef-normal", + n_nodes=2, + value_children=1, + xis=jnp.array([0.0, 1 / 8]), + nus=15.0, ) - network.add_edges(parent_idxs=1, children_idxs=0, coupling_strengths=1.0) - - def test_find_branch(self): - """Test the find_branch function.""" - edges = ( - AdjacencyLists(0, (1,), None, None, None, (None,)), - AdjacencyLists(2, None, (2,), (0,), None, (None,)), - AdjacencyLists(2, None, None, None, (1,), (None,)), - AdjacencyLists(2, (4,), None, None, None, (None,)), - AdjacencyLists(2, None, None, (3,), None, (None,)), - ) - branch_list = list_branches([0], edges, branch_list=[]) - assert branch_list == [0, 1, 2] - - def test_set_update_sequence(self): - """Test the set_update_sequence function.""" - - # a standard binary HGF - network1 = ( - Network() - .add_nodes(kind="binary-input") - .add_nodes(kind="binary-state", value_children=0) - .add_nodes(value_children=1) - .set_update_sequence() - ) - assert len(network1.update_sequence) == 6 - - # a standard continuous HGF - network2 = ( - Network() - .add_nodes(kind="continuous-input") - .add_nodes(value_children=0) - .add_nodes(volatility_children=1) - .set_update_sequence(update_type="standard") - ) - assert len(network2.update_sequence) == 6 - - # a generic input with a normal-EF node - network3 = ( - Network() - .add_nodes(kind="generic-input") - .add_nodes(kind="ef-normal") - .set_update_sequence() - ) - assert len(network3.update_sequence) == 2 - - # a Dirichlet node - network4 = ( - Network() - .add_nodes(kind="generic-input") - .add_nodes(kind="DP-state", value_children=0, alpha=0.1, batch_size=2) - .add_nodes( - kind="ef-normal", - n_nodes=2, - value_children=1, - xis=jnp.array([0.0, 1 / 8]), - nus=15.0, - ) - .set_update_sequence() - ) - assert len(network4.update_sequence) == 5 - - -if __name__ == "__main__": - unittest.main(argv=["first-arg-is-ignored"], exit=False) + .set_update_sequence() + ) + assert len(network4.update_sequence) == 5