Skip to content

Commit

Permalink
explicit variable names in the update functions and avoid duplicating…
Browse files Browse the repository at this point in the history
… parameters (#75)
  • Loading branch information
LegrandNico authored Aug 1, 2023
1 parent 3f96315 commit 68cdcf4
Show file tree
Hide file tree
Showing 2 changed files with 343 additions and 278 deletions.
301 changes: 160 additions & 141 deletions pyhgf/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,79 +62,92 @@ def binary_node_update(
"""
# using the current node index, unwrap parameters and parents
node_parameters = parameters_structure[node_idx]
value_parents_idx = node_structure[node_idx].value_parents
volatility_parents_idx = node_structure[node_idx].volatility_parents
value_parent_idxs = node_structure[node_idx].value_parents

# Return here if no parents node are provided
if (value_parents_idx is None) and (volatility_parents_idx is None):
# Return here if no parents node are found
if value_parent_idxs is None:
return parameters_structure

pihat = node_parameters["pihat"]
vape = jnp.subtract(node_parameters["mu"], node_parameters["muhat"])
pihat = parameters_structure[node_idx]["pihat"]
vape = (
parameters_structure[node_idx]["mu"] - parameters_structure[node_idx]["muhat"]
)

#######################################
# Update the continuous value parents #
#######################################
if value_parents_idx is not None:
# unpack the current parent's parameters with value and volatility parents
va_pa_node_parameters = parameters_structure[value_parents_idx[0]]
# va_pa_value_parents_idx = node_structure[value_parents_idx[0]].value_parents
va_pa_volatility_parents_idx = node_structure[
value_parents_idx[0]
].volatility_parents

# 1. get pihat_pa and nu_pa from the value parent (x2)

# 1.1 get new_nu (x2)
# 1.1.1 get logvol
logvol = va_pa_node_parameters["omega"]

# 1.1.2 Look at the (optional) va_pa's volatility parents
# and update logvol accordingly
if va_pa_volatility_parents_idx is not None:
for va_pa_vo_pa, k in zip(
va_pa_volatility_parents_idx, va_pa_node_parameters["kappas_parents"]
):
logvol += k * parameters_structure[va_pa_vo_pa]["mu"]

# 1.1.3 Compute new_nu
nu = time_step * jnp.exp(logvol)
new_nu = jnp.where(nu > 1e-128, nu, jnp.nan)

# 1.2 Compute new value for nu and pihat
pihat_pa, nu_pa = [1 / (1 / va_pa_node_parameters["pi"] + new_nu), new_nu]

# 2.
pi_pa = pihat_pa + 1 / pihat

# 3. get muhat_pa from value parent (x2)

# 3.1
driftrate = va_pa_node_parameters["rho"]

# TODO: this will be decided once we figure out the multi parents/children
# 3.2 Look at the (optional) va_pa's value parents
# and update driftrate accordingly
# if va_pa_value_parents is not None:
# psi
# for va_pa_va_pa, psi in zip(
# va_pa_value_parents, va_pa_node_parameters["psis"]
# ):
# _mu = va_pa_va_pa[0]["mu"]
# driftrate += psi * _mu
# 3.3
muhat_pa = va_pa_node_parameters["mu"] + time_step * driftrate

# 4.
mu_pa = muhat_pa + vape / pi_pa # This line differs from the continuous input

# 5. Update node's parameters and node's parents recursively
parameters_structure[value_parents_idx[0]]["pihat"] = pihat_pa
parameters_structure[value_parents_idx[0]]["pi"] = pi_pa
parameters_structure[value_parents_idx[0]]["muhat"] = muhat_pa
parameters_structure[value_parents_idx[0]]["mu"] = mu_pa
parameters_structure[value_parents_idx[0]]["nu"] = nu_pa
if value_parent_idxs is not None:
for value_parent_idx in value_parent_idxs:
value_parent_value_parent_idxs = node_structure[
value_parent_idx
].value_parents
value_parent_volatility_parent_idxs = node_structure[
value_parent_idx
].volatility_parents

# 1. get pihat_value_parent and nu_value_parent from the value parent (x2)

# 1.1 get new_nu (x2)
# 1.1.1 get logvol
logvol = parameters_structure[value_parent_idx]["omega"]

# 1.1.2 Look at the (optional) va_pa's volatility parents
# and update logvol accordingly
if value_parent_volatility_parent_idxs is not None:
for value_parent_volatility_parent_idx, k in zip(
value_parent_volatility_parent_idxs,
parameters_structure[value_parent_idx]["kappas_parents"],
):
logvol += (
k
* parameters_structure[value_parent_volatility_parent_idx]["mu"]
)

# 1.1.3 Compute new_nu
nu = time_step * jnp.exp(logvol)
new_nu = jnp.where(nu > 1e-128, nu, jnp.nan)

# 1.2 Compute new value for nu and pihat
pihat_value_parent, nu_value_parent = [
1 / (1 / parameters_structure[value_parent_idx]["pi"] + new_nu),
new_nu,
]

# 2.
pi_value_parent = pihat_value_parent + 1 / pihat

# 3. get muhat_value_parent from value parent (x2)

# 3.1
driftrate = parameters_structure[value_parent_idx]["rho"]

# 3.2 Look at the (optional) value parent's value parents
# and update driftrate accordingly
if value_parent_value_parent_idxs is not None:
for value_parent_value_parent_idx, psi in zip(
value_parent_value_parent_idxs,
parameters_structure[value_parent_idx]["psis_parents"],
):
driftrate += (
psi * parameters_structure[value_parent_value_parent_idx]["mu"]
)

# 3.3
muhat_value_parent = (
parameters_structure[value_parent_idx]["mu"] + time_step * driftrate
)

# 4.
mu_value_parent = (
muhat_value_parent + vape / pi_value_parent
) # This line differs from the continuous input

# 5. Update node's parameters and node's parents recursively
parameters_structure[value_parent_idx]["pihat"] = pihat_value_parent
parameters_structure[value_parent_idx]["pi"] = pi_value_parent
parameters_structure[value_parent_idx]["muhat"] = muhat_value_parent
parameters_structure[value_parent_idx]["mu"] = mu_value_parent
parameters_structure[value_parent_idx]["nu"] = nu_value_parent

return parameters_structure

Expand Down Expand Up @@ -190,72 +203,74 @@ def binary_input_update(
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
"""
# using the current node index, unwrap parameters and parents
input_node_parameters = parameters_structure[node_idx]
value_parents_idx = node_structure[node_idx].value_parents
volatility_parents_idx = node_structure[node_idx].volatility_parents
# list value and volatility parents
value_parent_idxs = node_structure[node_idx].value_parents
volatility_parent_idxs = node_structure[node_idx].volatility_parents

if (value_parents_idx is None) and (volatility_parents_idx is None):
if (value_parent_idxs is None) and (volatility_parent_idxs is None):
return parameters_structure

pihat = input_node_parameters["pihat"]

################################
# Update parents (binary node) #
################################

if value_parents_idx is not None:
# unpack the current parent's parameters with value and volatility parents
# va_pa_node_parameters = parameters_structure[value_parents_idx[0]]
va_pa_value_parents_idx = node_structure[value_parents_idx[0]].value_parents
# va_pa_volatility_parents_idx = node_structure[
# value_parents_idx[0]
# ].volatility_parents

# 1. Compute new muhat_pa and pihat_pa from binary node parent
# ------------------------------------------------------------

# 1.1 Compute new_muhat from continuous node parent (x2)

# 1.1.1 get rho from the value parent of the binary node (x2)
driftrate = parameters_structure[va_pa_value_parents_idx[0]]["rho"]

# TODO: this will be decided once we figure out the multi parents/children
# # 1.1.2 Look at the (optional) va_pa's value parents (x3)
# # and update the drift rate accordingly
# if va_pa_value_parents[0][1] is not None:
# for va_pa_va_pa in va_pa_value_parents[0][1]:
# # For each x2's value parents (optional)
# _psi = va_pa_value_parents[0][0]["psis"]
# _mu = va_pa_va_pa[0]["mu"]
# driftrate += _psi * _mu

# 1.1.3 compute new_muhat
muhat_va_pa = (
parameters_structure[va_pa_value_parents_idx[0]]["mu"]
+ time_step * driftrate
)

muhat_va_pa = sgm(muhat_va_pa)
pihat_va_pa = 1 / (muhat_va_pa * (1 - muhat_va_pa))

# 2. Compute surprise
# -------------------
eta0 = input_node_parameters["eta0"]
eta1 = input_node_parameters["eta1"]

mu_va_pa, pi_va_pa, surprise = cond(
pihat == jnp.inf,
input_surprise_inf,
input_surprise_reg,
(pihat, value, eta1, eta0, muhat_va_pa),
)

# Update value parent's parameters
parameters_structure[value_parents_idx[0]]["pihat"] = pihat_va_pa
parameters_structure[value_parents_idx[0]]["pi"] = pi_va_pa
parameters_structure[value_parents_idx[0]]["muhat"] = muhat_va_pa
parameters_structure[value_parents_idx[0]]["mu"] = mu_va_pa
pihat = parameters_structure[node_idx]["pihat"]

#######################################################
# Update the value parent(s) of the binary input node #
#######################################################

if value_parent_idxs is not None:
for value_parent_idx in value_parent_idxs:
# list the (unique) value parents
value_parent_value_parent_idxs = node_structure[
value_parent_idx
].value_parents[0]

# 1. Compute new muhat_value_parent and pihat_value_parent
# --------------------------------------------------------
# 1.1 Compute new_muhat from continuous node parent (x2)
# 1.1.1 get rho from the value parent of the binary node (x2)
driftrate = parameters_structure[value_parent_value_parent_idxs]["rho"]

# # 1.1.2 Look at the (optional) value parent's value parents (x3)
# # and update the drift rate accordingly
if node_structure[value_parent_value_parent_idxs].value_parents is not None:
for value_parent_value_parent_value_parent_idx in node_structure[
value_parent_value_parent_idxs
].value_parents:
# For each x2's value parents (optional)
driftrate += (
parameters_structure[value_parent_value_parent_idxs][
"psis_parents"
]
* parameters_structure[
value_parent_value_parent_value_parent_idx
]["mu"]
)

# 1.1.3 compute new_muhat
muhat_value_parent = (
parameters_structure[value_parent_value_parent_idxs]["mu"]
+ time_step * driftrate
)

muhat_value_parent = sgm(muhat_value_parent)
pihat_value_parent = 1 / (muhat_value_parent * (1 - muhat_value_parent))

# 2. Compute surprise
# -------------------
eta0 = parameters_structure[node_idx]["eta0"]
eta1 = parameters_structure[node_idx]["eta1"]

mu_value_parent, pi_value_parent, surprise = cond(
pihat == jnp.inf,
input_surprise_inf,
input_surprise_reg,
(pihat, value, eta1, eta0, muhat_value_parent),
)

# Update value parent's parameters
parameters_structure[value_parent_idx]["pihat"] = pihat_value_parent
parameters_structure[value_parent_idx]["pi"] = pi_value_parent
parameters_structure[value_parent_idx]["muhat"] = muhat_value_parent
parameters_structure[value_parent_idx]["mu"] = mu_value_parent

parameters_structure[node_idx]["surprise"] = surprise
parameters_structure[node_idx]["time_step"] = time_step
Expand Down Expand Up @@ -322,17 +337,17 @@ def binary_surprise(

def input_surprise_inf(op):
"""Apply special case if pihat is `jnp.inf` (just pass the value through)."""
(pihat, value, eta1, eta0, muhat_va_pa) = op
mu_va_pa = value
pi_va_pa = jnp.inf
surprise = binary_surprise(value, muhat_va_pa)
_, value, _, _, muhat_value_parent = op
mu_value_parent = value
pi_value_parent = jnp.inf
surprise = binary_surprise(value, muhat_value_parent)

return mu_va_pa, pi_va_pa, surprise
return mu_value_parent, pi_value_parent, surprise


def input_surprise_reg(op):
"""Compute the surprise, mu_va_pa and pi_va_pa if pihat is not `jnp.inf`."""
(pihat, value, eta1, eta0, muhat_va_pa) = op
"""Compute the surprise, mu_value_parent and pi_value_parent."""
pihat, value, eta1, eta0, muhat_value_parent = op

# Likelihood under eta1
und1 = jnp.exp(jnp.subtract(0, pihat) / 2 * (jnp.subtract(value, eta1)) ** 2)
Expand All @@ -341,13 +356,17 @@ def input_surprise_reg(op):
und0 = jnp.exp(jnp.subtract(0, pihat) / 2 * (jnp.subtract(value, eta0)) ** 2)

# Eq. 39 in Mathys et al. (2014) (i.e., Bayes)
mu_va_pa = muhat_va_pa * und1 / (muhat_va_pa * und1 + (1 - muhat_va_pa) * und0)
pi_va_pa = 1 / (mu_va_pa * (1 - mu_va_pa))
mu_value_parent = (
muhat_value_parent
* und1
/ (muhat_value_parent * und1 + (1 - muhat_value_parent) * und0)
)
pi_value_parent = 1 / (mu_value_parent * (1 - mu_value_parent))

# Surprise
surprise = -jnp.log(
muhat_va_pa * gaussian_density(value, eta1, pihat)
+ (1 - muhat_va_pa) * gaussian_density(value, eta0, pihat)
muhat_value_parent * gaussian_density(value, eta1, pihat)
+ (1 - muhat_value_parent) * gaussian_density(value, eta0, pihat)
)

return mu_va_pa, pi_va_pa, surprise
return mu_value_parent, pi_value_parent, surprise
Loading

0 comments on commit 68cdcf4

Please sign in to comment.