Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for nonlinear value coupling #215

Merged
merged 32 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3b2096a
test_2
KoraTMontemagno Jun 5, 2024
6e1d4d9
Non-linear update functions
KoraTMontemagno Jul 4, 2024
a238466
Hi! I have updated the code. The non-linear function is inserted in t…
KoraTMontemagno Jul 23, 2024
19a862a
Multiple value parents
KoraTMontemagno Jul 24, 2024
9af1b82
Some errors corrected - tested on ReLu function
KoraTMontemagno Aug 5, 2024
bfed60c
Tutorial for non-linear value coupling
KoraTMontemagno Aug 5, 2024
7e80044
minor modifications
KoraTMontemagno Aug 5, 2024
dfd3864
sin function added to tutorial
KoraTMontemagno Aug 7, 2024
8e574fc
pre-commit errors
LegrandNico Aug 9, 2024
4219912
simplify if else conditions
LegrandNico Aug 9, 2024
c0e230e
all tests passing
LegrandNico Aug 9, 2024
51102d1
add double arrow to plot non linear value coupling in pålot_network
LegrandNico Aug 12, 2024
554c6af
multiple coupling functions for multiple children
KoraTMontemagno Aug 12, 2024
b3e8d24
add sin example in the tutorial
LegrandNico Aug 12, 2024
5ec55a2
add coupling_fn in add_edges
LegrandNico Aug 12, 2024
d29eb28
tutorial
LegrandNico Aug 12, 2024
a9ca71f
small fix
LegrandNico Aug 12, 2024
6e41c5f
length of coupling_fn tuple adjusted
KoraTMontemagno Aug 12, 2024
9c12404
adjusted tutorial
KoraTMontemagno Aug 12, 2024
0462381
non-linear test draft (+ minor modification)
KoraTMontemagno Aug 12, 2024
830c494
Non-linear tests added
KoraTMontemagno Aug 13, 2024
b1c977b
Tutorial: eliminated resonant frequencies
KoraTMontemagno Aug 13, 2024
8324bb3
Tutorial: frequency tracking
KoraTMontemagno Aug 14, 2024
f3789c1
All test passed
KoraTMontemagno Aug 14, 2024
3bc9519
test continuous_node_update_nonlinear
KoraTMontemagno Aug 14, 2024
b24fe74
minor fix
KoraTMontemagno Aug 14, 2024
723ac73
increased test coverage
KoraTMontemagno Aug 14, 2024
fc44a28
useless tests eliminated
KoraTMontemagno Aug 16, 2024
938fd93
documentation
LegrandNico Aug 22, 2024
3e11523
comment
LegrandNico Aug 22, 2024
4483fc3
ref
LegrandNico Aug 22, 2024
85e35e8
fixing github action
LegrandNico Aug 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
branch: gh-pages

- name: Deploy Dev 🚀
if: github.event_name == 'pull_request'
if: (github.event_name == 'pull_request') || (github.event_name == 'push')
uses: JamesIves/github-pages-deploy-action@v4
with:
folder: docs/build/html
Expand Down
Binary file added docs/source/images/non_linear_coupling.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 11 additions & 3 deletions docs/source/learn.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,27 +121,35 @@ Advanced customisation of predictive coding neural networks and Bayesian modelli
::::{grid} 1 1 2 3
:gutter: 1

:::{grid-item-card} Using custom response functions
:::{grid-item-card} Using custom response functions
:link: custom_response_functions
:link-type: ref
:img-top: ./images/response_models.png

How to adapt any model to specific behaviours and experimental design by using custom response functions.
:::

:::{grid-item-card} Embedding the Hierarchical Gaussian Filter in a Bayesian network for multilevel inference
:::{grid-item-card} Embedding the Hierarchical Gaussian Filter in a Bayesian network for multilevel inference
:link: multilevel_hgf
:link-type: ref
:img-top: ./images/multilevel-hgf.png

How to use any model as a distribution to perform hierarchical inference at the group level.
:::

:::{grid-item-card} Parameter recovery, prior and posterior predictive sampling
:::{grid-item-card} Parameter recovery, prior and posterior predictive sampling
:link: parameters_recovery
:link-type: ref
:img-top: ./images/parameter_recovery.png

Recovering parameters from the generative model and using the sampling functionalities to estimate prior and posterior uncertainties.
:::

:::{grid-item-card} Non-linear value coupling
:link: non_linear_coupling
:link-type: ref
:img-top: ./images/non_linear_coupling.png

Recovering parameters from the generative model and using the sampling functionalities to estimate prior and posterior uncertainties.
:::
::::
Expand Down
90 changes: 53 additions & 37 deletions docs/source/notebooks/0.2-Creating_networks.ipynb

Large diffs are not rendered by default.

1,114 changes: 1,114 additions & 0 deletions docs/source/notebooks/5-Non_linear_value_coupling.ipynb

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,14 @@ @InProceedings{mathys:2020
}

@article{Vehtari:2015,
doi = {10.48550/ARXIV.1507.04544},
url = {https://arxiv.org/abs/1507.04544},
author = {Vehtari, Aki and Gelman, Andrew and Gabry, Jonah},
keywords = {Computation (stat.CO), Methodology (stat.ME), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC},
publisher = {arXiv},
year = {2015},
copyright = {arXiv.org perpetual, non-exclusive license}
}
title={Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC},
volume={27},
ISSN={1573-1375},
url={http://dx.doi.org/10.1007/s11222-016-9696-4},
DOI={10.1007/s11222-016-9696-4},
number={5},
journal={Statistics and Computing},
publisher={Springer Science and Business Media LLC},
author={Vehtari, Aki and Gelman, Andrew and Gabry, Jonah},
year={2016},
month=aug, pages={1413–1432} }
45 changes: 40 additions & 5 deletions src/pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def add_nodes(
value_parents: Optional[Union[List, Tuple, int]] = None,
volatility_children: Optional[Union[List, Tuple, int]] = None,
volatility_parents: Optional[Union[List, Tuple, int]] = None,
coupling_fn: Tuple[Optional[Callable], ...] = (None,),
**additional_parameters,
):
"""Add new input/state node(s) to the neural network.
Expand Down Expand Up @@ -397,6 +398,14 @@ def add_nodes(
integer or a list of integers, in case of multiple children. The coupling
strength can be controlled by passing a tuple, where the first item is the
list of indexes, and the second item is the list of coupling strengths.
coupling_fn :
Coupling function(s) between the current node and its value children.
It has to be provided as a tuple. If multiple value children are specified,
the coupling functions must be stated in the same order of the children.
Note: if a node has multiple parents nodes with different coupling
functions, a coupling function should be indicated for all the parent nodes.
If no coupling function is stated, the relationship between nodes is assumed
linear.
**kwargs :
Additional keyword parameters will be passed and overwrite the node
attributes.
Expand All @@ -420,6 +429,16 @@ def add_nodes(
)
)

# assess children number
# this is required to ensure the coupling functions match
children_number = 1
if value_children is None:
children_number = 0
elif isinstance(value_children, int):
children_number = 1
elif isinstance(value_children, list):
children_number = len(value_children)

# transform coupling parameter into tuple of indexes and strenghts
couplings = []
for indexes in [
Expand Down Expand Up @@ -638,14 +657,19 @@ def add_nodes(

node_idx = len(self.attributes) # the index of the new node

# for mutiple value children, set a default tuple with corresponding length
if children_number != len(coupling_fn):
if coupling_fn == (None,):
coupling_fn = children_number * coupling_fn
else:
raise ValueError(
"The number of coupling fn and value children do not match"
)

# add a new edge
edges_as_list.append(
AdjacencyLists(
node_type,
None,
None,
None,
None,
node_type, None, None, None, None, coupling_fn=coupling_fn
)
)

Expand Down Expand Up @@ -684,6 +708,7 @@ def add_nodes(
parent_idxs=node_idx,
children_idxs=value_children[0],
coupling_strengths=value_children[1], # type: ignore
coupling_fn=coupling_fn,
)
if volatility_children[0] is not None:
self.add_edges(
Expand Down Expand Up @@ -788,6 +813,7 @@ def add_edges(
parent_idxs=Union[int, List[int]],
children_idxs=Union[int, List[int]],
coupling_strengths: Union[float, List[float], Tuple[float]] = 1.0,
coupling_fn: Tuple[Optional[Callable], ...] = (None,),
) -> "Network":
"""Add a value or volatility coupling link between a set of nodes.

Expand All @@ -801,6 +827,14 @@ def add_edges(
The index(es) of the children node(s).
coupling_strengths :
The coupling strength betwen the parents and children.
coupling_fn :
Coupling function(s) between the current node and its value children.
It has to be provided as a tuple. If multiple value children are specified,
the coupling functions must be stated in the same order of the children.
Note: if a node has multiple parents nodes with different coupling
functions, a coupling function should be indicated for all the parent nodes.
If no coupling function is stated, the relationship between nodes is assumed
linear.

"""
attributes, edges = add_edges(
Expand All @@ -810,6 +844,7 @@ def add_edges(
parent_idxs=parent_idxs,
children_idxs=children_idxs,
coupling_strengths=coupling_strengths,
coupling_fn=coupling_fn,
)

self.attributes = attributes
Expand Down
5 changes: 5 additions & 0 deletions src/pyhgf/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,14 @@ def plot_network(network: "Network") -> "Source":

if value_parents is not None:
for value_parents_idx in value_parents:

# get the coupling function from the value parent
child_idx = network.edges[value_parents_idx].value_children.index(i)
coupling_fn = network.edges[value_parents_idx].coupling_fn[child_idx]
graphviz_structure.edge(
f"x_{value_parents_idx}",
f"x_{i}",
color="black" if coupling_fn is None else "black:invis:black",
)

# connect volatility parents
Expand Down
4 changes: 4 additions & 0 deletions src/pyhgf/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ class AdjacencyLists(NamedTuple):
mean and unknown variance.
* 4: Dirichlet Process state node.

The variable `coupling_fn` list the coupling functions between this nodes and the
children nodes. If `None` is provided, a linear coupling is assumed.

"""

node_type: int
value_parents: Optional[Tuple]
volatility_parents: Optional[Tuple]
value_children: Optional[Tuple]
volatility_children: Optional[Tuple]
coupling_fn: Tuple[Optional[Callable], ...]


class Inputs(NamedTuple):
Expand Down
73 changes: 64 additions & 9 deletions src/pyhgf/updates/posterior/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from typing import Dict

import jax.numpy as jnp
from jax import jit
from jax import grad, jit

from pyhgf.typing import Edges


@partial(jit, static_argnames=("edges", "node_idx"))
def posterior_update_mean_continuous_node(
attributes: Dict, edges: Edges, node_idx: int, node_precision: float
attributes: Dict,
edges: Edges,
node_idx: int,
node_precision: float,
) -> float:
r"""Update the mean of a state node using the value prediction errors.

Expand All @@ -20,6 +23,8 @@ def posterior_update_mean_continuous_node(
The new mean of a state node :math:`b` value coupled with other input and/or state
nodes :math:`j` at time :math:`k` is given by:

For linear value coupling:

.. math::
\mu_b^{(k)} = \hat{\mu}_b^{(k)} + \sum_{j=1}^{N_{children}}
\frac{\kappa_j \hat{\pi}_j^{(k)}}{\pi_b} \delta_j^{(k)}
Expand All @@ -32,6 +37,14 @@ def posterior_update_mean_continuous_node(
If the child node is a state node, this value was computed by
:py:func:`pyhgf.updates.prediction_errors.nodes.continuous.continuous_node_value_prediction_error`.

For non-linear value coupling:

.. math::
\mu_b^{(k)} = \hat{\mu}_b^{(k)} + \sum_{j=1}^{N_{children}}
\frac{\kappa_j g'_{j,b}({\mu}_b^{(k-1)}) \hat{\pi}_j^{(k)}}{\pi_b}
\delta_j^{(k)}


2. Mean update from volatility coupling.

The new mean of a state node :math:`b` volatility coupled with other input and/or
Expand Down Expand Up @@ -115,9 +128,10 @@ def posterior_update_mean_continuous_node(
# Value coupling updates - update the mean of a value parent
# ----------------------------------------------------------
if edges[node_idx].value_children is not None:
for value_child_idx, value_coupling in zip(
for value_child_idx, value_coupling, coupling_fn in zip(
edges[node_idx].value_children, # type: ignore
attributes[node_idx]["value_coupling_children"],
edges[node_idx].coupling_fn,
):
# get the value prediction error (VAPE)
# if this is jnp.nan (no observation) set the VAPE to 0.0
Expand All @@ -128,11 +142,22 @@ def posterior_update_mean_continuous_node(
# cancel the prediction error if the child value was not observed
value_prediction_error *= attributes[value_child_idx]["observed"]

# get differential of coupling function with value children
if coupling_fn is None: # linear coupling
coupling_fn_prime = 1
else: # non-linear coupling
# Compute the derivative of the coupling function
coupling_fn_prime = grad(coupling_fn)(attributes[node_idx]["mean"])

# expected precisions from the value children
# sum the precision weigthed prediction errors over all children
value_precision_weigthed_prediction_error += (
(
(value_coupling * attributes[value_child_idx]["expected_precision"])
(
value_coupling
* attributes[value_child_idx]["expected_precision"]
* coupling_fn_prime
)
/ node_precision
)
) * value_prediction_error
Expand All @@ -149,7 +174,8 @@ def posterior_update_mean_continuous_node(
"volatility_prediction_error"
]

# retrieve the effective precision (γ) computed during the prediction step
# retrieve the effective precision (γ)
# computed during the prediction step
effective_precision = attributes[volatility_child_idx]["temp"][
"effective_precision"
]
Expand Down Expand Up @@ -197,6 +223,8 @@ def posterior_update_precision_continuous_node(
The new precision of a state node :math:`b` value coupled with other input and/or
state nodes :math:`j` at time :math:`k` is given by:

For linear coupling (default)

.. math::

\pi_b^{(k)} = \hat{\pi}_b^{(k)} + \sum_{j=1}^{N_{children}}
Expand All @@ -210,6 +238,13 @@ def posterior_update_precision_continuous_node(
If the child node is a state node, this value was computed by
:py:func:`pyhgf.updates.prediction_errors.nodes.continuous.continuous_node_value_prediction_error`.

For non-linear value coupling:

.. math::

\pi_b^{(k)} = \hat{\pi}_b^{(k)} + \sum_{j=1}^{N_{children}}
\hat{\pi}_j^{(k)} * (\kappa_j^2 * g'_{j,b}(\mu_b^(k-1))^2 -
g''_{j,b}(\mu_b^(k-1))*\delta_j)

#. Precision update from volatility coupling.

Expand Down Expand Up @@ -284,13 +319,30 @@ def posterior_update_precision_continuous_node(
# Value coupling updates - update the precision of a value parent
# ---------------------------------------------------------------
if edges[node_idx].value_children is not None:
for value_child_idx, value_coupling in zip(
for value_child_idx, value_coupling, coupling_fn in zip(
edges[node_idx].value_children, # type: ignore
attributes[node_idx]["value_coupling_children"],
edges[node_idx].coupling_fn,
):
if coupling_fn is None: # linear coupling
coupling_fn_prime = 1
coupling_fn_second = 0
else: # non-linear coupling
coupling_fn_prime = grad(coupling_fn)(attributes[node_idx]["mean"]) ** 2
value_prediction_error = attributes[value_child_idx]["temp"][
"value_prediction_error"
]
coupling_fn_second = (
grad(grad(coupling_fn))(attributes[node_idx]["mean"])
* value_prediction_error
)

# cancel the prediction error if the child value was not observed
precision_weigthed_prediction_error += (
value_coupling**2 * attributes[value_child_idx]["expected_precision"]
value_coupling**2
* attributes[value_child_idx]["expected_precision"]
* coupling_fn_prime
- coupling_fn_second
) * attributes[value_child_idx]["observed"]

# Volatility coupling updates - update the precision of a volatility parent
Expand Down Expand Up @@ -334,7 +386,7 @@ def posterior_update_precision_continuous_node(
)

# additionnal steps for unobserved values
# ----------------------------------------------------------------------------------
# ---------------------------------------

# List the node's volatility parents
volatility_parents_idxs = edges[node_idx].volatility_parents
Expand Down Expand Up @@ -493,7 +545,10 @@ def continuous_node_update_ehgf(
attributes[node_idx]["mean"] = posterior_mean

posterior_precision = posterior_update_precision_continuous_node(
attributes, edges, node_idx, time_step
attributes,
edges,
node_idx,
time_step,
)
attributes[node_idx]["precision"] = posterior_precision

Expand Down
Loading
Loading