Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

explicit variable names in the update functions and avoid duplicating parameters #75

Merged
merged 4 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 160 additions & 141 deletions pyhgf/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,79 +62,92 @@ def binary_node_update(

"""
# using the current node index, unwrap parameters and parents
node_parameters = parameters_structure[node_idx]
value_parents_idx = node_structure[node_idx].value_parents
volatility_parents_idx = node_structure[node_idx].volatility_parents
value_parent_idxs = node_structure[node_idx].value_parents

# Return here if no parents node are provided
if (value_parents_idx is None) and (volatility_parents_idx is None):
# Return here if no parents node are found
if value_parent_idxs is None:
return parameters_structure

pihat = node_parameters["pihat"]
vape = jnp.subtract(node_parameters["mu"], node_parameters["muhat"])
pihat = parameters_structure[node_idx]["pihat"]
vape = (
parameters_structure[node_idx]["mu"] - parameters_structure[node_idx]["muhat"]
)

#######################################
# Update the continuous value parents #
#######################################
if value_parents_idx is not None:
# unpack the current parent's parameters with value and volatility parents
va_pa_node_parameters = parameters_structure[value_parents_idx[0]]
# va_pa_value_parents_idx = node_structure[value_parents_idx[0]].value_parents
va_pa_volatility_parents_idx = node_structure[
value_parents_idx[0]
].volatility_parents

# 1. get pihat_pa and nu_pa from the value parent (x2)

# 1.1 get new_nu (x2)
# 1.1.1 get logvol
logvol = va_pa_node_parameters["omega"]

# 1.1.2 Look at the (optional) va_pa's volatility parents
# and update logvol accordingly
if va_pa_volatility_parents_idx is not None:
for va_pa_vo_pa, k in zip(
va_pa_volatility_parents_idx, va_pa_node_parameters["kappas_parents"]
):
logvol += k * parameters_structure[va_pa_vo_pa]["mu"]

# 1.1.3 Compute new_nu
nu = time_step * jnp.exp(logvol)
new_nu = jnp.where(nu > 1e-128, nu, jnp.nan)

# 1.2 Compute new value for nu and pihat
pihat_pa, nu_pa = [1 / (1 / va_pa_node_parameters["pi"] + new_nu), new_nu]

# 2.
pi_pa = pihat_pa + 1 / pihat

# 3. get muhat_pa from value parent (x2)

# 3.1
driftrate = va_pa_node_parameters["rho"]

# TODO: this will be decided once we figure out the multi parents/children
# 3.2 Look at the (optional) va_pa's value parents
# and update driftrate accordingly
# if va_pa_value_parents is not None:
# psi
# for va_pa_va_pa, psi in zip(
# va_pa_value_parents, va_pa_node_parameters["psis"]
# ):
# _mu = va_pa_va_pa[0]["mu"]
# driftrate += psi * _mu
# 3.3
muhat_pa = va_pa_node_parameters["mu"] + time_step * driftrate

# 4.
mu_pa = muhat_pa + vape / pi_pa # This line differs from the continuous input

# 5. Update node's parameters and node's parents recursively
parameters_structure[value_parents_idx[0]]["pihat"] = pihat_pa
parameters_structure[value_parents_idx[0]]["pi"] = pi_pa
parameters_structure[value_parents_idx[0]]["muhat"] = muhat_pa
parameters_structure[value_parents_idx[0]]["mu"] = mu_pa
parameters_structure[value_parents_idx[0]]["nu"] = nu_pa
if value_parent_idxs is not None:
for value_parent_idx in value_parent_idxs:
value_parent_value_parent_idxs = node_structure[
value_parent_idx
].value_parents
value_parent_volatility_parent_idxs = node_structure[
value_parent_idx
].volatility_parents

# 1. get pihat_value_parent and nu_value_parent from the value parent (x2)

# 1.1 get new_nu (x2)
# 1.1.1 get logvol
logvol = parameters_structure[value_parent_idx]["omega"]

# 1.1.2 Look at the (optional) va_pa's volatility parents
# and update logvol accordingly
if value_parent_volatility_parent_idxs is not None:
for value_parent_volatility_parent_idx, k in zip(
value_parent_volatility_parent_idxs,
parameters_structure[value_parent_idx]["kappas_parents"],
):
logvol += (
k
* parameters_structure[value_parent_volatility_parent_idx]["mu"]
)

# 1.1.3 Compute new_nu
nu = time_step * jnp.exp(logvol)
new_nu = jnp.where(nu > 1e-128, nu, jnp.nan)

# 1.2 Compute new value for nu and pihat
pihat_value_parent, nu_value_parent = [
1 / (1 / parameters_structure[value_parent_idx]["pi"] + new_nu),
new_nu,
]

# 2.
pi_value_parent = pihat_value_parent + 1 / pihat

# 3. get muhat_value_parent from value parent (x2)

# 3.1
driftrate = parameters_structure[value_parent_idx]["rho"]

# 3.2 Look at the (optional) value parent's value parents
# and update driftrate accordingly
if value_parent_value_parent_idxs is not None:
for value_parent_value_parent_idx, psi in zip(
value_parent_value_parent_idxs,
parameters_structure[value_parent_idx]["psis_parents"],
):
driftrate += (
psi * parameters_structure[value_parent_value_parent_idx]["mu"]
)

# 3.3
muhat_value_parent = (
parameters_structure[value_parent_idx]["mu"] + time_step * driftrate
)

# 4.
mu_value_parent = (
muhat_value_parent + vape / pi_value_parent
) # This line differs from the continuous input

# 5. Update node's parameters and node's parents recursively
parameters_structure[value_parent_idx]["pihat"] = pihat_value_parent
parameters_structure[value_parent_idx]["pi"] = pi_value_parent
parameters_structure[value_parent_idx]["muhat"] = muhat_value_parent
parameters_structure[value_parent_idx]["mu"] = mu_value_parent
parameters_structure[value_parent_idx]["nu"] = nu_value_parent

return parameters_structure

Expand Down Expand Up @@ -190,72 +203,74 @@ def binary_input_update(
arXiv. https://doi.org/10.48550/ARXIV.2305.10937

"""
# using the current node index, unwrap parameters and parents
input_node_parameters = parameters_structure[node_idx]
value_parents_idx = node_structure[node_idx].value_parents
volatility_parents_idx = node_structure[node_idx].volatility_parents
# list value and volatility parents
value_parent_idxs = node_structure[node_idx].value_parents
volatility_parent_idxs = node_structure[node_idx].volatility_parents

if (value_parents_idx is None) and (volatility_parents_idx is None):
if (value_parent_idxs is None) and (volatility_parent_idxs is None):
return parameters_structure

pihat = input_node_parameters["pihat"]

################################
# Update parents (binary node) #
################################

if value_parents_idx is not None:
# unpack the current parent's parameters with value and volatility parents
# va_pa_node_parameters = parameters_structure[value_parents_idx[0]]
va_pa_value_parents_idx = node_structure[value_parents_idx[0]].value_parents
# va_pa_volatility_parents_idx = node_structure[
# value_parents_idx[0]
# ].volatility_parents

# 1. Compute new muhat_pa and pihat_pa from binary node parent
# ------------------------------------------------------------

# 1.1 Compute new_muhat from continuous node parent (x2)

# 1.1.1 get rho from the value parent of the binary node (x2)
driftrate = parameters_structure[va_pa_value_parents_idx[0]]["rho"]

# TODO: this will be decided once we figure out the multi parents/children
# # 1.1.2 Look at the (optional) va_pa's value parents (x3)
# # and update the drift rate accordingly
# if va_pa_value_parents[0][1] is not None:
# for va_pa_va_pa in va_pa_value_parents[0][1]:
# # For each x2's value parents (optional)
# _psi = va_pa_value_parents[0][0]["psis"]
# _mu = va_pa_va_pa[0]["mu"]
# driftrate += _psi * _mu

# 1.1.3 compute new_muhat
muhat_va_pa = (
parameters_structure[va_pa_value_parents_idx[0]]["mu"]
+ time_step * driftrate
)

muhat_va_pa = sgm(muhat_va_pa)
pihat_va_pa = 1 / (muhat_va_pa * (1 - muhat_va_pa))

# 2. Compute surprise
# -------------------
eta0 = input_node_parameters["eta0"]
eta1 = input_node_parameters["eta1"]

mu_va_pa, pi_va_pa, surprise = cond(
pihat == jnp.inf,
input_surprise_inf,
input_surprise_reg,
(pihat, value, eta1, eta0, muhat_va_pa),
)

# Update value parent's parameters
parameters_structure[value_parents_idx[0]]["pihat"] = pihat_va_pa
parameters_structure[value_parents_idx[0]]["pi"] = pi_va_pa
parameters_structure[value_parents_idx[0]]["muhat"] = muhat_va_pa
parameters_structure[value_parents_idx[0]]["mu"] = mu_va_pa
pihat = parameters_structure[node_idx]["pihat"]

#######################################################
# Update the value parent(s) of the binary input node #
#######################################################

if value_parent_idxs is not None:
for value_parent_idx in value_parent_idxs:
# list the (unique) value parents
value_parent_value_parent_idxs = node_structure[
value_parent_idx
].value_parents[0]

# 1. Compute new muhat_value_parent and pihat_value_parent
# --------------------------------------------------------
# 1.1 Compute new_muhat from continuous node parent (x2)
# 1.1.1 get rho from the value parent of the binary node (x2)
driftrate = parameters_structure[value_parent_value_parent_idxs]["rho"]

# # 1.1.2 Look at the (optional) value parent's value parents (x3)
# # and update the drift rate accordingly
if node_structure[value_parent_value_parent_idxs].value_parents is not None:
for value_parent_value_parent_value_parent_idx in node_structure[
value_parent_value_parent_idxs
].value_parents:
# For each x2's value parents (optional)
driftrate += (
parameters_structure[value_parent_value_parent_idxs][
"psis_parents"
]
* parameters_structure[
value_parent_value_parent_value_parent_idx
]["mu"]
)

# 1.1.3 compute new_muhat
muhat_value_parent = (
parameters_structure[value_parent_value_parent_idxs]["mu"]
+ time_step * driftrate
)

muhat_value_parent = sgm(muhat_value_parent)
pihat_value_parent = 1 / (muhat_value_parent * (1 - muhat_value_parent))

# 2. Compute surprise
# -------------------
eta0 = parameters_structure[node_idx]["eta0"]
eta1 = parameters_structure[node_idx]["eta1"]

mu_value_parent, pi_value_parent, surprise = cond(
pihat == jnp.inf,
input_surprise_inf,
input_surprise_reg,
(pihat, value, eta1, eta0, muhat_value_parent),
)

# Update value parent's parameters
parameters_structure[value_parent_idx]["pihat"] = pihat_value_parent
parameters_structure[value_parent_idx]["pi"] = pi_value_parent
parameters_structure[value_parent_idx]["muhat"] = muhat_value_parent
parameters_structure[value_parent_idx]["mu"] = mu_value_parent

parameters_structure[node_idx]["surprise"] = surprise
parameters_structure[node_idx]["time_step"] = time_step
Expand Down Expand Up @@ -322,17 +337,17 @@ def binary_surprise(

def input_surprise_inf(op):
"""Apply special case if pihat is `jnp.inf` (just pass the value through)."""
(pihat, value, eta1, eta0, muhat_va_pa) = op
mu_va_pa = value
pi_va_pa = jnp.inf
surprise = binary_surprise(value, muhat_va_pa)
_, value, _, _, muhat_value_parent = op
mu_value_parent = value
pi_value_parent = jnp.inf
surprise = binary_surprise(value, muhat_value_parent)

return mu_va_pa, pi_va_pa, surprise
return mu_value_parent, pi_value_parent, surprise


def input_surprise_reg(op):
"""Compute the surprise, mu_va_pa and pi_va_pa if pihat is not `jnp.inf`."""
(pihat, value, eta1, eta0, muhat_va_pa) = op
"""Compute the surprise, mu_value_parent and pi_value_parent."""
pihat, value, eta1, eta0, muhat_value_parent = op

# Likelihood under eta1
und1 = jnp.exp(jnp.subtract(0, pihat) / 2 * (jnp.subtract(value, eta1)) ** 2)
Expand All @@ -341,13 +356,17 @@ def input_surprise_reg(op):
und0 = jnp.exp(jnp.subtract(0, pihat) / 2 * (jnp.subtract(value, eta0)) ** 2)

# Eq. 39 in Mathys et al. (2014) (i.e., Bayes)
mu_va_pa = muhat_va_pa * und1 / (muhat_va_pa * und1 + (1 - muhat_va_pa) * und0)
pi_va_pa = 1 / (mu_va_pa * (1 - mu_va_pa))
mu_value_parent = (
muhat_value_parent
* und1
/ (muhat_value_parent * und1 + (1 - muhat_value_parent) * und0)
)
pi_value_parent = 1 / (mu_value_parent * (1 - mu_value_parent))

# Surprise
surprise = -jnp.log(
muhat_va_pa * gaussian_density(value, eta1, pihat)
+ (1 - muhat_va_pa) * gaussian_density(value, eta0, pihat)
muhat_value_parent * gaussian_density(value, eta1, pihat)
+ (1 - muhat_value_parent) * gaussian_density(value, eta0, pihat)
)

return mu_va_pa, pi_va_pa, surprise
return mu_value_parent, pi_value_parent, surprise
Loading
Loading