diff --git a/pyhgf/model.py b/pyhgf/model.py index bdc5a1685..52115cca7 100644 --- a/pyhgf/model.py +++ b/pyhgf/model.py @@ -459,6 +459,7 @@ def add_input_node( pihat: Union[float, np.ndarray, ArrayLike] = jnp.inf, eta0: Union[float, np.ndarray, ArrayLike] = 0.0, eta1: Union[float, np.ndarray, ArrayLike] = 1.0, + additional_parameters: Optional[Dict] = None, ): """Create an input node. @@ -477,6 +478,8 @@ def add_input_node( The lower bound of the binary process (only relevant if `kind="binary"`). eta1 : The lower bound of the binary process (only relevant if `kind="binary"`). + additional_parameters : + Add more custom parameters to the input node. """ if kind == "continuous": @@ -495,6 +498,11 @@ def add_input_node( "time_step": jnp.nan, "value": jnp.nan, } + + # add more parameters (optional) + if additional_parameters is not None: + input_node_parameters = {**input_node_parameters, **additional_parameters} + if input_idx == 0: # this is the first node, create the node structure self.parameters_structure = {input_idx: input_node_parameters} @@ -527,6 +535,7 @@ def add_value_parent( psis: Optional[Tuple] = None, omega: Union[float, np.ndarray, ArrayLike] = -4.0, rho: Union[float, np.ndarray, ArrayLike] = 0.0, + additional_parameters: Optional[Dict] = None, ): """Add a value parent to a given set of nodes. @@ -559,6 +568,8 @@ def add_value_parent( volatility parent(s)). rho : The drift of the random walk. Defaults to `0.0` (no drift). + additional_parameters : + Add more custom parameters to the node. """ # how many nodes in structure @@ -578,6 +589,10 @@ def add_value_parent( "rho": rho, } + # add more parameters (optional) + if additional_parameters is not None: + node_parameters = {**node_parameters, **additional_parameters} + # add a new node to connection structure with no parents self.node_structure += (Indexes(None, None, tuple(children_idxs), None),) @@ -632,6 +647,7 @@ def add_volatility_parent( psis: Optional[Tuple] = None, omega: Union[float, np.ndarray, ArrayLike] = -4.0, rho: Union[float, np.ndarray, ArrayLike] = 0.0, + additional_parameters: Optional[Dict] = None, ): """Add a volatility parent to a given set of nodes. @@ -663,6 +679,8 @@ def add_volatility_parent( volatility parent(s)). rho : The drift of the random walk. Defaults to `0.0` (no drift). + additional_parameters : + Add more custom parameters to the node. """ # how many nodes in structure @@ -682,6 +700,10 @@ def add_volatility_parent( "rho": rho, } + # add more parameters (optional) + if additional_parameters is not None: + node_parameters = {**node_parameters, **additional_parameters} + # add a new node to the connection structure with no parents self.node_structure += (Indexes(None, None, None, tuple(children_idxs)),)