Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nodes can store additional parameters #73

Merged
merged 3 commits into from
Jul 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pyhgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def add_input_node(
pihat: Union[float, np.ndarray, ArrayLike] = jnp.inf,
eta0: Union[float, np.ndarray, ArrayLike] = 0.0,
eta1: Union[float, np.ndarray, ArrayLike] = 1.0,
additional_parameters: Optional[Dict] = None,
):
"""Create an input node.

Expand All @@ -477,6 +478,8 @@ def add_input_node(
The lower bound of the binary process (only relevant if `kind="binary"`).
eta1 :
The lower bound of the binary process (only relevant if `kind="binary"`).
additional_parameters :
Add more custom parameters to the input node.

"""
if kind == "continuous":
Expand All @@ -495,6 +498,11 @@ def add_input_node(
"time_step": jnp.nan,
"value": jnp.nan,
}

# add more parameters (optional)
if additional_parameters is not None:
input_node_parameters = {**input_node_parameters, **additional_parameters}

if input_idx == 0:
# this is the first node, create the node structure
self.parameters_structure = {input_idx: input_node_parameters}
Expand Down Expand Up @@ -527,6 +535,7 @@ def add_value_parent(
psis: Optional[Tuple] = None,
omega: Union[float, np.ndarray, ArrayLike] = -4.0,
rho: Union[float, np.ndarray, ArrayLike] = 0.0,
additional_parameters: Optional[Dict] = None,
):
"""Add a value parent to a given set of nodes.

Expand Down Expand Up @@ -559,6 +568,8 @@ def add_value_parent(
volatility parent(s)).
rho :
The drift of the random walk. Defaults to `0.0` (no drift).
additional_parameters :
Add more custom parameters to the node.

"""
# how many nodes in structure
Expand All @@ -578,6 +589,10 @@ def add_value_parent(
"rho": rho,
}

# add more parameters (optional)
if additional_parameters is not None:
node_parameters = {**node_parameters, **additional_parameters}

# add a new node to connection structure with no parents
self.node_structure += (Indexes(None, None, tuple(children_idxs), None),)

Expand Down Expand Up @@ -632,6 +647,7 @@ def add_volatility_parent(
psis: Optional[Tuple] = None,
omega: Union[float, np.ndarray, ArrayLike] = -4.0,
rho: Union[float, np.ndarray, ArrayLike] = 0.0,
additional_parameters: Optional[Dict] = None,
):
"""Add a volatility parent to a given set of nodes.

Expand Down Expand Up @@ -663,6 +679,8 @@ def add_volatility_parent(
volatility parent(s)).
rho :
The drift of the random walk. Defaults to `0.0` (no drift).
additional_parameters :
Add more custom parameters to the node.

"""
# how many nodes in structure
Expand All @@ -682,6 +700,10 @@ def add_volatility_parent(
"rho": rho,
}

# add more parameters (optional)
if additional_parameters is not None:
node_parameters = {**node_parameters, **additional_parameters}

# add a new node to the connection structure with no parents
self.node_structure += (Indexes(None, None, None, tuple(children_idxs)),)

Expand Down
Loading