From 9ffdc58e410c96a1e8110d9f7f98b375b42194be Mon Sep 17 00:00:00 2001 From: jdcpni Date: Mon, 2 Dec 2024 11:13:50 -0500 Subject: [PATCH] Feat/composition reset clear results (#3136) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * • composition.py - reset(): add clear_results arg --- psyneulink/core/compositions/composition.py | 53 +++++++++++++-------- tests/composition/test_composition.py | 10 ++++ 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/psyneulink/core/compositions/composition.py b/psyneulink/core/compositions/composition.py index 6cc1a44d39..15362e9097 100644 --- a/psyneulink/core/compositions/composition.py +++ b/psyneulink/core/compositions/composition.py @@ -12880,7 +12880,19 @@ def do_gradient_optimization(self, retain_in_pnl_options, context, optimization_ pass @handle_external_context(fallback_most_recent=True) - def reset(self, values=None, include_unspecified_nodes=True, context=NotImplemented): + def reset(self, values=None, include_unspecified_nodes=True, clear_results=False, context=NotImplemented): + """Reset all stateful functions in the Composition to their initial values. + + If **values** is provided, the `previous_value ` of the corresponding + `stateful functions ` are set to the values specified. If a value is not provided for a + given node, the `previous_value ` is set to the value of its `initializer + `. + + If **include_unspecified_nodes** is False, then all nodes must have corresponding reset values. + The `DEFAULT` keyword can be used in lieu of a numerical value to reset a node's value to its default. + + If **clear_results** is True, the `results ` attribute is set to an empty list. + """ if not values: values = {} @@ -12890,30 +12902,33 @@ def reset(self, values=None, include_unspecified_nodes=True, context=NotImplemen reset_val = values.get(node) node.reset(reset_val, context=context) + if clear_results: + self.parameters.results._set([], context) + @handle_external_context(fallback_most_recent=True) def initialize(self, values=None, include_unspecified_nodes=True, context=None): - """ - Initializes the values of nodes within cycles. If `include_unspecified_nodes` is True and a value is - provided for a given node, the node will be initialized to that value. If `include_unspecified_nodes` is - True and a value is not provided, the node will be initialized to its default value. If - `include_unspecified_nodes` is False, then all nodes must have corresponding initialization values. The - `DEFAULT` keyword can be used in lieu of a numerical value to reset a node's value to its default. + """Initialize the values of nodes within cycles. + If `include_unspecified_nodes` is True and a value is provided for a given node, the node is initialized to + that value. If `include_unspecified_nodes` is True and a value is not provided, the node is initialized to + its default value. If `include_unspecified_nodes` is False, then all nodes must have corresponding + initialization values. The `DEFAULT` keyword can be used in lieu of a numerical value to reset a node's value + to its default. - If a context is not provided, the most recent context under which the Composition has executed will be used. + If a context is not provided, the most recent context under which the Composition has executed is used. - Arguments - ---------- - values: Dict { Node: Node Value } - A dictionary contaning key-value pairs of Nodes and initialization values. Nodes within cycles that are - not included in this dict will be initialized to their default values. + Arguments + ---------- + values: Dict { Node: Node Value } + A dictionary containing key-value pairs of Nodes and initialization values. Nodes within cycles that are + not included in this dict are initialized to their default values. - include_unspecified_nodes: bool - Specifies whether all nodes within cycles should be initialized or only ones specified in the provided - values dictionary. + include_unspecified_nodes: bool + Specifies whether all nodes within cycles should be initialized or only ones specified in the provided + values dictionary. - context: Context - The context under which the nodes should be initialized. context will be set to - self.most_recent_execution_context if one is not specified. + context: Context + The context under which the nodes should be initialized. context are set to + self.most_recent_execution_context if one is not specified. """ # comp must be initialized from context before cycle values are initialized diff --git a/tests/composition/test_composition.py b/tests/composition/test_composition.py index 41b8127e52..8940d1ecab 100644 --- a/tests/composition/test_composition.py +++ b/tests/composition/test_composition.py @@ -7264,6 +7264,16 @@ def test_save_state_before_simulations(self): np.testing.assert_allclose(np.asfarray(run_1_values), [[0.36], [0.056], [0.056]]) np.testing.assert_allclose(np.asfarray(run_2_values), [[0.5904], [0.16384], [0.16384]]) + def test_reset_clear_results(self): + mech = ProcessingMechanism(name='mech') + comp = Composition(nodes=[mech]) + comp.run(inputs={mech: 1}) + assert comp.results == [[1]] + comp.reset() + assert comp.results == [[1]] + comp.reset(clear_results=True) + assert comp.results == [] + class TestNodeRoles: