Skip to content

Commit

Permalink
non-linear test draft (+ minor modification)
Browse files Browse the repository at this point in the history
  • Loading branch information
KoraTMontemagno committed Aug 12, 2024
1 parent 7ff35ae commit 18f0093
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ def add_nodes(
)

# assess children number
children_number = 1
if value_children == None:
children_number = 0
elif isinstance(value_children, int):
Expand Down
65 changes: 64 additions & 1 deletion tests/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,22 @@ def test_continuous_input_update(nodes_attributes):
# one value parent with one volatility parent #
###############################################
attributes = nodes_attributes

def identity(x):
return(x)

edges = (
AdjacencyLists(0, (1,), None, None, None, (None,)),
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (None,)),
)

# define a non linear network behaving like the linear one
edges_nonlinear = (
AdjacencyLists(0, (1,), None, None, None, (None,)),
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (identity,)),
)

# create update sequence
sequence1 = 1, continuous_node_prediction
sequence2 = 2, continuous_node_prediction
Expand Down Expand Up @@ -165,6 +174,14 @@ def test_continuous_input_update(nodes_attributes):
input_data=(data, time_steps, observed),
)

#repeat for non-linear network
new_attributes_nonlinear, _ = beliefs_propagation(
structure=(inputs, edges_nonlinear),
attributes=attributes,
update_sequence=update_sequence,
input_data=(data, time_steps, observed),
)

for idx, val in zip(["time_step", "values"], [1.0, 0.2]):
assert jnp.isclose(new_attributes[0][idx], val)
for idx, val in zip(
Expand All @@ -178,6 +195,19 @@ def test_continuous_input_update(nodes_attributes):
):
assert jnp.isclose(new_attributes[2][idx], val)

for idx, val in zip(["time_step", "values"], [1.0, 0.2]):
assert jnp.isclose(new_attributes_nonlinear[0][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[10000.881, 0.880797, 0.20007047, 1.0],
):
assert jnp.isclose(new_attributes_nonlinear[1][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[0.9794834, 0.95257413, 0.97345114, 1.0],
):
assert jnp.isclose(new_attributes_nonlinear[2][idx], val)


def test_scan_loop(nodes_attributes):
timeserie = load_data("continuous")
Expand All @@ -191,6 +221,16 @@ def test_scan_loop(nodes_attributes):
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (None,)),
)

#testing it with a non-linear coupling
def identity(x):
return(x)

edges_nonlinear = (
AdjacencyLists(0, (1,), None, None, None, (None,)),
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (identity,)),
)

# create update sequence
sequence1 = 2, continuous_node_prediction
Expand All @@ -217,6 +257,12 @@ def test_scan_loop(nodes_attributes):
structure=(inputs, edges),
)

scan_fn_nonlinear = Partial(
beliefs_propagation,
update_sequence=update_sequence,
structure=(inputs, edges),
)

# Create the data (value and time steps vectors)
time_steps = jnp.ones((len(timeserie), 1))
observed = jnp.ones((len(timeserie), 1))
Expand All @@ -236,6 +282,23 @@ def test_scan_loop(nodes_attributes):
):
assert jnp.isclose(last[2][idx], val)

# non linear coupling
# Run the entire for loop
last_nonlinear, _ = scan(scan_fn_nonlinear, attributes,
(timeserie, time_steps, observed))
for idx, val in zip(["time_step", "values"], [1.0, 0.8241]):
assert jnp.isclose(last_nonlinear[0][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[24557.84, 14557.839, 0.8041823, 0.79050046],
):
assert jnp.isclose(last_nonlinear[1][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[1.3334407, 1.3493799, -7.1686087, -7.509615],
):
assert jnp.isclose(last_nonlinear[2][idx], val)


if __name__ == "__main__":
unittest.main(argv=["first-arg-is-ignored"], exit=False)

0 comments on commit 18f0093

Please sign in to comment.