diff --git a/pyhgf/binary.py b/pyhgf/binary.py index fda051be3..1b38bd75e 100644 --- a/pyhgf/binary.py +++ b/pyhgf/binary.py @@ -62,79 +62,92 @@ def binary_node_update( """ # using the current node index, unwrap parameters and parents - node_parameters = parameters_structure[node_idx] - value_parents_idx = node_structure[node_idx].value_parents - volatility_parents_idx = node_structure[node_idx].volatility_parents + value_parent_idxs = node_structure[node_idx].value_parents - # Return here if no parents node are provided - if (value_parents_idx is None) and (volatility_parents_idx is None): + # Return here if no parents node are found + if value_parent_idxs is None: return parameters_structure - pihat = node_parameters["pihat"] - vape = jnp.subtract(node_parameters["mu"], node_parameters["muhat"]) + pihat = parameters_structure[node_idx]["pihat"] + vape = ( + parameters_structure[node_idx]["mu"] - parameters_structure[node_idx]["muhat"] + ) ####################################### # Update the continuous value parents # ####################################### - if value_parents_idx is not None: - # unpack the current parent's parameters with value and volatility parents - va_pa_node_parameters = parameters_structure[value_parents_idx[0]] - # va_pa_value_parents_idx = node_structure[value_parents_idx[0]].value_parents - va_pa_volatility_parents_idx = node_structure[ - value_parents_idx[0] - ].volatility_parents - - # 1. get pihat_pa and nu_pa from the value parent (x2) - - # 1.1 get new_nu (x2) - # 1.1.1 get logvol - logvol = va_pa_node_parameters["omega"] - - # 1.1.2 Look at the (optional) va_pa's volatility parents - # and update logvol accordingly - if va_pa_volatility_parents_idx is not None: - for va_pa_vo_pa, k in zip( - va_pa_volatility_parents_idx, va_pa_node_parameters["kappas_parents"] - ): - logvol += k * parameters_structure[va_pa_vo_pa]["mu"] - - # 1.1.3 Compute new_nu - nu = time_step * jnp.exp(logvol) - new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) - - # 1.2 Compute new value for nu and pihat - pihat_pa, nu_pa = [1 / (1 / va_pa_node_parameters["pi"] + new_nu), new_nu] - - # 2. - pi_pa = pihat_pa + 1 / pihat - - # 3. get muhat_pa from value parent (x2) - - # 3.1 - driftrate = va_pa_node_parameters["rho"] - - # TODO: this will be decided once we figure out the multi parents/children - # 3.2 Look at the (optional) va_pa's value parents - # and update driftrate accordingly - # if va_pa_value_parents is not None: - # psi - # for va_pa_va_pa, psi in zip( - # va_pa_value_parents, va_pa_node_parameters["psis"] - # ): - # _mu = va_pa_va_pa[0]["mu"] - # driftrate += psi * _mu - # 3.3 - muhat_pa = va_pa_node_parameters["mu"] + time_step * driftrate - - # 4. - mu_pa = muhat_pa + vape / pi_pa # This line differs from the continuous input - - # 5. Update node's parameters and node's parents recursively - parameters_structure[value_parents_idx[0]]["pihat"] = pihat_pa - parameters_structure[value_parents_idx[0]]["pi"] = pi_pa - parameters_structure[value_parents_idx[0]]["muhat"] = muhat_pa - parameters_structure[value_parents_idx[0]]["mu"] = mu_pa - parameters_structure[value_parents_idx[0]]["nu"] = nu_pa + if value_parent_idxs is not None: + for value_parent_idx in value_parent_idxs: + value_parent_value_parent_idxs = node_structure[ + value_parent_idx + ].value_parents + value_parent_volatility_parent_idxs = node_structure[ + value_parent_idx + ].volatility_parents + + # 1. get pihat_value_parent and nu_value_parent from the value parent (x2) + + # 1.1 get new_nu (x2) + # 1.1.1 get logvol + logvol = parameters_structure[value_parent_idx]["omega"] + + # 1.1.2 Look at the (optional) va_pa's volatility parents + # and update logvol accordingly + if value_parent_volatility_parent_idxs is not None: + for value_parent_volatility_parent_idx, k in zip( + value_parent_volatility_parent_idxs, + parameters_structure[value_parent_idx]["kappas_parents"], + ): + logvol += ( + k + * parameters_structure[value_parent_volatility_parent_idx]["mu"] + ) + + # 1.1.3 Compute new_nu + nu = time_step * jnp.exp(logvol) + new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) + + # 1.2 Compute new value for nu and pihat + pihat_value_parent, nu_value_parent = [ + 1 / (1 / parameters_structure[value_parent_idx]["pi"] + new_nu), + new_nu, + ] + + # 2. + pi_value_parent = pihat_value_parent + 1 / pihat + + # 3. get muhat_value_parent from value parent (x2) + + # 3.1 + driftrate = parameters_structure[value_parent_idx]["rho"] + + # 3.2 Look at the (optional) value parent's value parents + # and update driftrate accordingly + if value_parent_value_parent_idxs is not None: + for value_parent_value_parent_idx, psi in zip( + value_parent_value_parent_idxs, + parameters_structure[value_parent_idx]["psis_parents"], + ): + driftrate += ( + psi * parameters_structure[value_parent_value_parent_idx]["mu"] + ) + + # 3.3 + muhat_value_parent = ( + parameters_structure[value_parent_idx]["mu"] + time_step * driftrate + ) + + # 4. + mu_value_parent = ( + muhat_value_parent + vape / pi_value_parent + ) # This line differs from the continuous input + + # 5. Update node's parameters and node's parents recursively + parameters_structure[value_parent_idx]["pihat"] = pihat_value_parent + parameters_structure[value_parent_idx]["pi"] = pi_value_parent + parameters_structure[value_parent_idx]["muhat"] = muhat_value_parent + parameters_structure[value_parent_idx]["mu"] = mu_value_parent + parameters_structure[value_parent_idx]["nu"] = nu_value_parent return parameters_structure @@ -190,72 +203,74 @@ def binary_input_update( arXiv. https://doi.org/10.48550/ARXIV.2305.10937 """ - # using the current node index, unwrap parameters and parents - input_node_parameters = parameters_structure[node_idx] - value_parents_idx = node_structure[node_idx].value_parents - volatility_parents_idx = node_structure[node_idx].volatility_parents + # list value and volatility parents + value_parent_idxs = node_structure[node_idx].value_parents + volatility_parent_idxs = node_structure[node_idx].volatility_parents - if (value_parents_idx is None) and (volatility_parents_idx is None): + if (value_parent_idxs is None) and (volatility_parent_idxs is None): return parameters_structure - pihat = input_node_parameters["pihat"] - - ################################ - # Update parents (binary node) # - ################################ - - if value_parents_idx is not None: - # unpack the current parent's parameters with value and volatility parents - # va_pa_node_parameters = parameters_structure[value_parents_idx[0]] - va_pa_value_parents_idx = node_structure[value_parents_idx[0]].value_parents - # va_pa_volatility_parents_idx = node_structure[ - # value_parents_idx[0] - # ].volatility_parents - - # 1. Compute new muhat_pa and pihat_pa from binary node parent - # ------------------------------------------------------------ - - # 1.1 Compute new_muhat from continuous node parent (x2) - - # 1.1.1 get rho from the value parent of the binary node (x2) - driftrate = parameters_structure[va_pa_value_parents_idx[0]]["rho"] - - # TODO: this will be decided once we figure out the multi parents/children - # # 1.1.2 Look at the (optional) va_pa's value parents (x3) - # # and update the drift rate accordingly - # if va_pa_value_parents[0][1] is not None: - # for va_pa_va_pa in va_pa_value_parents[0][1]: - # # For each x2's value parents (optional) - # _psi = va_pa_value_parents[0][0]["psis"] - # _mu = va_pa_va_pa[0]["mu"] - # driftrate += _psi * _mu - - # 1.1.3 compute new_muhat - muhat_va_pa = ( - parameters_structure[va_pa_value_parents_idx[0]]["mu"] - + time_step * driftrate - ) - - muhat_va_pa = sgm(muhat_va_pa) - pihat_va_pa = 1 / (muhat_va_pa * (1 - muhat_va_pa)) - - # 2. Compute surprise - # ------------------- - eta0 = input_node_parameters["eta0"] - eta1 = input_node_parameters["eta1"] - - mu_va_pa, pi_va_pa, surprise = cond( - pihat == jnp.inf, - input_surprise_inf, - input_surprise_reg, - (pihat, value, eta1, eta0, muhat_va_pa), - ) - - # Update value parent's parameters - parameters_structure[value_parents_idx[0]]["pihat"] = pihat_va_pa - parameters_structure[value_parents_idx[0]]["pi"] = pi_va_pa - parameters_structure[value_parents_idx[0]]["muhat"] = muhat_va_pa - parameters_structure[value_parents_idx[0]]["mu"] = mu_va_pa + pihat = parameters_structure[node_idx]["pihat"] + + ####################################################### + # Update the value parent(s) of the binary input node # + ####################################################### + + if value_parent_idxs is not None: + for value_parent_idx in value_parent_idxs: + # list the (unique) value parents + value_parent_value_parent_idxs = node_structure[ + value_parent_idx + ].value_parents[0] + + # 1. Compute new muhat_value_parent and pihat_value_parent + # -------------------------------------------------------- + # 1.1 Compute new_muhat from continuous node parent (x2) + # 1.1.1 get rho from the value parent of the binary node (x2) + driftrate = parameters_structure[value_parent_value_parent_idxs]["rho"] + + # # 1.1.2 Look at the (optional) value parent's value parents (x3) + # # and update the drift rate accordingly + if node_structure[value_parent_value_parent_idxs].value_parents is not None: + for value_parent_value_parent_value_parent_idx in node_structure[ + value_parent_value_parent_idxs + ].value_parents: + # For each x2's value parents (optional) + driftrate += ( + parameters_structure[value_parent_value_parent_idxs][ + "psis_parents" + ] + * parameters_structure[ + value_parent_value_parent_value_parent_idx + ]["mu"] + ) + + # 1.1.3 compute new_muhat + muhat_value_parent = ( + parameters_structure[value_parent_value_parent_idxs]["mu"] + + time_step * driftrate + ) + + muhat_value_parent = sgm(muhat_value_parent) + pihat_value_parent = 1 / (muhat_value_parent * (1 - muhat_value_parent)) + + # 2. Compute surprise + # ------------------- + eta0 = parameters_structure[node_idx]["eta0"] + eta1 = parameters_structure[node_idx]["eta1"] + + mu_value_parent, pi_value_parent, surprise = cond( + pihat == jnp.inf, + input_surprise_inf, + input_surprise_reg, + (pihat, value, eta1, eta0, muhat_value_parent), + ) + + # Update value parent's parameters + parameters_structure[value_parent_idx]["pihat"] = pihat_value_parent + parameters_structure[value_parent_idx]["pi"] = pi_value_parent + parameters_structure[value_parent_idx]["muhat"] = muhat_value_parent + parameters_structure[value_parent_idx]["mu"] = mu_value_parent parameters_structure[node_idx]["surprise"] = surprise parameters_structure[node_idx]["time_step"] = time_step @@ -322,17 +337,17 @@ def binary_surprise( def input_surprise_inf(op): """Apply special case if pihat is `jnp.inf` (just pass the value through).""" - (pihat, value, eta1, eta0, muhat_va_pa) = op - mu_va_pa = value - pi_va_pa = jnp.inf - surprise = binary_surprise(value, muhat_va_pa) + _, value, _, _, muhat_value_parent = op + mu_value_parent = value + pi_value_parent = jnp.inf + surprise = binary_surprise(value, muhat_value_parent) - return mu_va_pa, pi_va_pa, surprise + return mu_value_parent, pi_value_parent, surprise def input_surprise_reg(op): - """Compute the surprise, mu_va_pa and pi_va_pa if pihat is not `jnp.inf`.""" - (pihat, value, eta1, eta0, muhat_va_pa) = op + """Compute the surprise, mu_value_parent and pi_value_parent.""" + pihat, value, eta1, eta0, muhat_value_parent = op # Likelihood under eta1 und1 = jnp.exp(jnp.subtract(0, pihat) / 2 * (jnp.subtract(value, eta1)) ** 2) @@ -341,13 +356,17 @@ def input_surprise_reg(op): und0 = jnp.exp(jnp.subtract(0, pihat) / 2 * (jnp.subtract(value, eta0)) ** 2) # Eq. 39 in Mathys et al. (2014) (i.e., Bayes) - mu_va_pa = muhat_va_pa * und1 / (muhat_va_pa * und1 + (1 - muhat_va_pa) * und0) - pi_va_pa = 1 / (mu_va_pa * (1 - mu_va_pa)) + mu_value_parent = ( + muhat_value_parent + * und1 + / (muhat_value_parent * und1 + (1 - muhat_value_parent) * und0) + ) + pi_value_parent = 1 / (mu_value_parent * (1 - mu_value_parent)) # Surprise surprise = -jnp.log( - muhat_va_pa * gaussian_density(value, eta1, pihat) - + (1 - muhat_va_pa) * gaussian_density(value, eta0, pihat) + muhat_value_parent * gaussian_density(value, eta1, pihat) + + (1 - muhat_value_parent) * gaussian_density(value, eta0, pihat) ) - return mu_va_pa, pi_va_pa, surprise + return mu_value_parent, pi_value_parent, surprise diff --git a/pyhgf/continuous.py b/pyhgf/continuous.py index 1bd66e0db..4c53edf03 100644 --- a/pyhgf/continuous.py +++ b/pyhgf/continuous.py @@ -67,130 +67,149 @@ def continuous_node_update( arXiv. https://doi.org/10.48550/ARXIV.2305.10937 """ - # using the current node index, unwrap parameters and parents - node_parameters = parameters_structure[node_idx] - value_parents_idx = node_structure[node_idx].value_parents - volatility_parents_idx = node_structure[node_idx].volatility_parents + # list value and volatility parents + value_parents_idxs = node_structure[node_idx].value_parents + volatility_parents_idxs = node_structure[node_idx].volatility_parents # return here if no parents node are provided - if (value_parents_idx is None) and (volatility_parents_idx is None): + if (value_parents_idxs is None) and (volatility_parents_idxs is None): return parameters_structure - pihat = node_parameters["pihat"] + pihat = parameters_structure[node_idx]["pihat"] - ####################################### - # Update the continuous value parents # - ####################################### - if value_parents_idx is not None: - # the strength of the value coupling with parents nodes - psis = node_parameters["psis_parents"] + ######################## + # Update value parents # + ######################## + if value_parents_idxs is not None: + # the strength of the value coupling between the base node and the parents nodes + psis = parameters_structure[node_idx]["psis_parents"] - for va_pa_idx, psi in zip(value_parents_idx, psis): + for value_parents_idx, psi in zip(value_parents_idxs, psis): # if this child is the last one relative to this parent's family, all the - # child will update the parent at once, otherwise just pass and wait - if node_structure[va_pa_idx].value_children[-1] == node_idx: - # unpack the parent's parameters with the value and volatility parents - va_pa_node_parameters = parameters_structure[va_pa_idx] - va_pa_value_parents_idx = node_structure[va_pa_idx].value_parents - va_pa_volatility_parents_idx = node_structure[ - va_pa_idx + # children will update the parent at once, otherwise just pass and wait + if node_structure[value_parents_idx].value_children[-1] == node_idx: + # list the value and volatility parents + value_parent_value_parents_idxs = node_structure[ + value_parents_idx + ].value_parents + value_parent_volatility_parents_idxs = node_structure[ + value_parents_idx ].volatility_parents # Compute new value for nu and pihat - logvol = va_pa_node_parameters["omega"] + logvol = parameters_structure[value_parents_idx]["omega"] # Look at the (optional) va_pa's volatility parents # and update logvol accordingly - if va_pa_volatility_parents_idx is not None: - for va_pa_vo_pa, k in zip( - va_pa_volatility_parents_idx, - va_pa_node_parameters["kappas_parents"], + if value_parent_volatility_parents_idxs is not None: + for value_parent_volatility_parents_idx, k in zip( + value_parent_volatility_parents_idxs, + parameters_structure[value_parents_idx]["kappas_parents"], ): - logvol += k * parameters_structure[va_pa_vo_pa]["mu"] + logvol += ( + k + * parameters_structure[value_parent_volatility_parents_idx][ + "mu" + ] + ) # Estimate new_nu nu = time_step * jnp.exp(logvol) new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) - pihat_pa, nu_pa = [ - 1 / (1 / va_pa_node_parameters["pi"] + new_nu), + pihat_value_parent, nu_value_parent = [ + 1 / (1 / parameters_structure[value_parents_idx]["pi"] + new_nu), new_nu, ] # gather precision updates from other nodes if the parent has many - # children - this part corresponds to the sum of children required for - # the multi-children situations + # children - this part corresponds to the sum over children + # required for the multi-children situations pi_children = 0.0 for child_idx, psi_child in zip( - node_structure[va_pa_idx].value_children, - parameters_structure[va_pa_idx]["psis_children"], + node_structure[value_parents_idx].value_children, + parameters_structure[value_parents_idx]["psis_children"], ): pihat_child = parameters_structure[child_idx]["pihat"] pi_children += psi_child**2 * pihat_child - pi_pa = pihat_pa + pi_children + pi_value_parent = pihat_value_parent + pi_children # Compute new muhat - driftrate = va_pa_node_parameters["rho"] + driftrate = parameters_structure[value_parents_idx]["rho"] # Look at the (optional) va_pa's value parents # and update drift rate accordingly - if va_pa_value_parents_idx is not None: - for va_pa_va_pa in va_pa_value_parents_idx: - driftrate += psi * va_pa_va_pa["mu"] + if value_parent_value_parents_idxs is not None: + for va_pa_va_pa in value_parent_value_parents_idxs: + driftrate += psi * parameters_structure[va_pa_va_pa]["mu"] - muhat_pa = va_pa_node_parameters["mu"] + time_step * driftrate + muhat_value_parent = ( + parameters_structure[value_parents_idx]["mu"] + + time_step * driftrate + ) # gather PE updates from other nodes if the parent has many children # this part corresponds to the sum of children required for the # multi-children situations pe_children = 0.0 for child_idx, psi_child in zip( - node_structure[va_pa_idx].value_children, - parameters_structure[va_pa_idx]["psis_children"], + node_structure[value_parents_idx].value_children, + parameters_structure[value_parents_idx]["psis_children"], ): vape_child = ( parameters_structure[child_idx]["mu"] - parameters_structure[child_idx]["muhat"] ) pihat_child = parameters_structure[child_idx]["pihat"] - pe_children += (psi_child * pihat_child * vape_child) / pi_pa + pe_children += ( + psi_child * pihat_child * vape_child + ) / pi_value_parent - mu_pa = muhat_pa + pe_children + mu_value_parent = muhat_value_parent + pe_children # Update this parent's parameters - parameters_structure[va_pa_idx]["pihat"] = pihat_pa - parameters_structure[va_pa_idx]["pi"] = pi_pa - parameters_structure[va_pa_idx]["muhat"] = muhat_pa - parameters_structure[va_pa_idx]["mu"] = mu_pa - parameters_structure[va_pa_idx]["nu"] = nu_pa + parameters_structure[value_parents_idx]["pihat"] = pihat_value_parent + parameters_structure[value_parents_idx]["pi"] = pi_value_parent + parameters_structure[value_parents_idx]["muhat"] = muhat_value_parent + parameters_structure[value_parents_idx]["mu"] = mu_value_parent + parameters_structure[value_parents_idx]["nu"] = nu_value_parent ############################# # Update volatility parents # ############################# - if volatility_parents_idx is not None: - nu = node_parameters["nu"] - kappas = node_parameters["kappas_parents"] - vope = ( - 1 / node_parameters["pi"] - + (node_parameters["mu"] - node_parameters["muhat"]) ** 2 - ) * node_parameters["pihat"] - 1 + if volatility_parents_idxs is not None: + # the strength of the value coupling between the base node and the parents nodes + kappas = parameters_structure[node_idx]["kappas_parents"] - for vo_pa_idx, kappa in zip(volatility_parents_idx, kappas): - # unpack the current parent's parameters with value and volatility parents - vo_pa_node_parameters = parameters_structure[vo_pa_idx] - vo_pa_value_parents_idx = node_structure[vo_pa_idx].value_parents - vo_pa_volatility_parents_idx = node_structure[vo_pa_idx].volatility_parents + nu = parameters_structure[node_idx]["nu"] + vope = ( + 1 / parameters_structure[node_idx]["pi"] + + ( + parameters_structure[node_idx]["mu"] + - parameters_structure[node_idx]["muhat"] + ) + ** 2 + ) * parameters_structure[node_idx]["pihat"] - 1 + + for volatility_parents_idx, kappa in zip(volatility_parents_idxs, kappas): + # list the value and volatility parents + volatility_parent_value_parents_idx = node_structure[ + volatility_parents_idx + ].value_parents + volatility_parent_volatility_parents_idx = node_structure[ + volatility_parents_idx + ].volatility_parents # Compute new value for nu and pihat - logvol = vo_pa_node_parameters["omega"] + logvol = parameters_structure[volatility_parents_idx]["omega"] # Look at the (optional) vo_pa's volatility parents # and update logvol accordingly - if vo_pa_volatility_parents_idx is not None: + if volatility_parent_volatility_parents_idx is not None: for vo_pa_vo_pa, k in zip( - vo_pa_volatility_parents_idx, - vo_pa_node_parameters["kappas_parents"], + volatility_parent_volatility_parents_idx, + parameters_structure[volatility_parents_idx]["kappas_parents"], ): logvol += k * parameters_structure[vo_pa_vo_pa]["mu"] @@ -198,31 +217,46 @@ def continuous_node_update( new_nu = time_step * jnp.exp(logvol) new_nu = jnp.where(new_nu > 1e-128, new_nu, jnp.nan) - pihat_pa, nu_pa = [1 / (1 / vo_pa_node_parameters["pi"] + new_nu), new_nu] + pihat_volatility_parent, nu_volatility_parent = [ + 1 / (1 / parameters_structure[volatility_parents_idx]["pi"] + new_nu), + new_nu, + ] - pi_pa = pihat_pa + 0.5 * (kappa * nu * pihat) ** 2 * ( - 1 + (1 - 1 / (nu * node_parameters["pi"])) * vope + pi_volatility_parent = pihat_volatility_parent + 0.5 * ( + kappa * nu * pihat + ) ** 2 * (1 + (1 - 1 / (nu * parameters_structure[node_idx]["pi"])) * vope) + pi_volatility_parent = jnp.where( + pi_volatility_parent <= 0, jnp.nan, pi_volatility_parent ) - pi_pa = jnp.where(pi_pa <= 0, jnp.nan, pi_pa) # Compute new muhat - driftrate = vo_pa_node_parameters["rho"] + driftrate = parameters_structure[volatility_parents_idx]["rho"] # Look at the (optional) va_pa's value parents # and update drift rate accordingly - if vo_pa_value_parents_idx is not None: - for vo_pa_va_pa in vo_pa_value_parents_idx: + if volatility_parent_value_parents_idx is not None: + for vo_pa_va_pa in volatility_parent_value_parents_idx: driftrate += psi * parameters_structure[vo_pa_va_pa]["mu"] - muhat_pa = vo_pa_node_parameters["mu"] + time_step * driftrate - mu_pa = muhat_pa + 0.5 * kappa * nu * pihat / pi_pa * vope + muhat_volatility_parent = ( + parameters_structure[volatility_parents_idx]["mu"] + + time_step * driftrate + ) + mu_volatility_parent = ( + muhat_volatility_parent + + 0.5 * kappa * nu * pihat / pi_volatility_parent * vope + ) # Update this parent's parameters - parameters_structure[vo_pa_idx]["pihat"] = pihat_pa - parameters_structure[vo_pa_idx]["pi"] = pi_pa - parameters_structure[vo_pa_idx]["muhat"] = muhat_pa - parameters_structure[vo_pa_idx]["mu"] = mu_pa - parameters_structure[vo_pa_idx]["nu"] = nu_pa + parameters_structure[volatility_parents_idx][ + "pihat" + ] = pihat_volatility_parent + parameters_structure[volatility_parents_idx]["pi"] = pi_volatility_parent + parameters_structure[volatility_parents_idx][ + "muhat" + ] = muhat_volatility_parent + parameters_structure[volatility_parents_idx]["mu"] = mu_volatility_parent + parameters_structure[volatility_parents_idx]["nu"] = nu_volatility_parent return parameters_structure @@ -277,76 +311,88 @@ def continuous_input_update( arXiv. https://doi.org/10.48550/ARXIV.2305.10937 """ - # using the current node index, unwrap parameters and parents - input_node_parameters = parameters_structure[node_idx] - value_parents_idx = node_structure[node_idx].value_parents - volatility_parents_idx = node_structure[node_idx].volatility_parents + # list value and volatility parents + value_parents_idxs = node_structure[node_idx].value_parents + volatility_parents_idxs = node_structure[node_idx].volatility_parents - lognoise = input_node_parameters["omega"] + lognoise = parameters_structure[node_idx]["omega"] - if volatility_parents_idx is not None: - for vo_pa_idx, k in zip( - volatility_parents_idx, input_node_parameters["kappas_parents"] + if volatility_parents_idxs is not None: + for volatility_parents_idx, k in zip( + volatility_parents_idxs, parameters_structure[node_idx]["kappas_parents"] ): - lognoise += k * parameters_structure[vo_pa_idx]["mu"] + lognoise += k * parameters_structure[volatility_parents_idx]["mu"] pihat = 1 / jnp.exp(lognoise) ######################## # Update value parents # ######################## - if value_parents_idx is not None: - # unpack the current parent's parameters with value and volatility parents - va_pa_node_parameters = parameters_structure[value_parents_idx[0]] - va_pa_value_parents_idx = node_structure[value_parents_idx[0]].value_parents - va_pa_volatility_parents_idx = node_structure[ - value_parents_idx[0] - ].volatility_parents - - vape = va_pa_node_parameters["mu"] - va_pa_node_parameters["muhat"] - - # Compute new value for nu and pihat - logvol = va_pa_node_parameters["omega"] - - # Look at the (optional) va_pa's volatility parents - # and update logvol accordingly - if va_pa_volatility_parents_idx is not None: - for va_pa_vo_pa, k in zip( - va_pa_volatility_parents_idx, va_pa_node_parameters["kappas_parents"] - ): - logvol += k * parameters_structure[va_pa_vo_pa]["mu"] - - # Estimate new_nu - nu = time_step * jnp.exp(logvol) - new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) - - pihat_va_pa, nu_va_pa = [ - 1 / (1 / va_pa_node_parameters["pi"] + new_nu), - new_nu, - ] - pi_va_pa = pihat_va_pa + pihat - - # Compute new muhat - driftrate = va_pa_node_parameters["rho"] - - # Look at the (optional) va_pa's value parents - # and update drift rate accordingly - if va_pa_value_parents_idx is not None: - for va_pa_va_pa in va_pa_value_parents_idx: - driftrate += ( - va_pa_node_parameters["psis"][0] - * parameters_structure[va_pa_va_pa]["mu"] - ) + if value_parents_idxs is not None: + for value_parents_idx in value_parents_idxs: + # unpack the current parent's parameters with value and volatility parents + value_parent_value_parents_idxs = node_structure[ + value_parents_idx + ].value_parents + value_parent_volatility_parents_idxs = node_structure[ + value_parents_idx + ].volatility_parents + + vape = ( + parameters_structure[value_parents_idx]["mu"] + - parameters_structure[value_parents_idx]["muhat"] + ) + + # Compute new value for nu and pihat + logvol = parameters_structure[value_parents_idx]["omega"] + + # Look at the (optional) va_pa's volatility parents + # and update logvol accordingly + if value_parent_volatility_parents_idxs is not None: + for value_parent_volatility_parents_idx, k in zip( + value_parent_volatility_parents_idxs, + parameters_structure[value_parents_idx]["kappas_parents"], + ): + logvol += ( + k + * parameters_structure[value_parent_volatility_parents_idx][ + "mu" + ] + ) + + # Estimate new_nu + nu = time_step * jnp.exp(logvol) + new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) - muhat_va_pa = va_pa_node_parameters["mu"] + time_step * driftrate - vape = value - muhat_va_pa - mu_va_pa = muhat_va_pa + pihat / pi_va_pa * vape + pihat_value_parent, nu_value_parent = [ + 1 / (1 / parameters_structure[value_parents_idx]["pi"] + new_nu), + new_nu, + ] + pi_value_parent = pihat_value_parent + pihat - # update input node's parameters - parameters_structure[value_parents_idx[0]]["pihat"] = pihat_va_pa - parameters_structure[value_parents_idx[0]]["pi"] = pi_va_pa - parameters_structure[value_parents_idx[0]]["muhat"] = muhat_va_pa - parameters_structure[value_parents_idx[0]]["mu"] = mu_va_pa - parameters_structure[value_parents_idx[0]]["nu"] = nu_va_pa + # Compute new muhat + driftrate = parameters_structure[value_parents_idx]["rho"] + + # Look at the (optional) va_pa's value parents + # and update drift rate accordingly + if value_parent_value_parents_idxs is not None: + for value_parent_value_parents_idx in value_parent_value_parents_idxs: + driftrate += ( + parameters_structure[value_parents_idx]["psis_parents"][0] + * parameters_structure[value_parent_value_parents_idx]["mu"] + ) + + muhat_value_parent = ( + parameters_structure[value_parents_idx]["mu"] + time_step * driftrate + ) + vape = value - muhat_value_parent + mu_value_parent = muhat_value_parent + pihat / pi_value_parent * vape + + # update input node's parameters + parameters_structure[value_parents_idx]["pihat"] = pihat_value_parent + parameters_structure[value_parents_idx]["pi"] = pi_value_parent + parameters_structure[value_parents_idx]["muhat"] = muhat_value_parent + parameters_structure[value_parents_idx]["mu"] = mu_value_parent + parameters_structure[value_parents_idx]["nu"] = nu_value_parent # store value and timestep in the node's parameters parameters_structure[node_idx]["time_step"] = time_step