Skip to content

Commit

Permalink
Feat/composition reset clear results (#3136)
Browse files Browse the repository at this point in the history
* • composition.py
  - reset():  add clear_results arg
  • Loading branch information
jdcpni authored Dec 2, 2024
1 parent cd06c82 commit 9ffdc58
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
53 changes: 34 additions & 19 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12880,7 +12880,19 @@ def do_gradient_optimization(self, retain_in_pnl_options, context, optimization_
pass

@handle_external_context(fallback_most_recent=True)
def reset(self, values=None, include_unspecified_nodes=True, context=NotImplemented):
def reset(self, values=None, include_unspecified_nodes=True, clear_results=False, context=NotImplemented):
"""Reset all stateful functions in the Composition to their initial values.

If **values** is provided, the `previous_value <StatefulFunction.previous_value>` of the corresponding
`stateful functions <StatefulFunction>` are set to the values specified. If a value is not provided for a
given node, the `previous_value <StatefulFunction.previous_value>` is set to the value of its `initializer
<StatefulFunction.initializer>`.

If **include_unspecified_nodes** is False, then all nodes must have corresponding reset values.
The `DEFAULT` keyword can be used in lieu of a numerical value to reset a node's value to its default.

If **clear_results** is True, the `results <Composition.results>` attribute is set to an empty list.
"""
if not values:
values = {}

Expand All @@ -12890,30 +12902,33 @@ def reset(self, values=None, include_unspecified_nodes=True, context=NotImplemen
reset_val = values.get(node)
node.reset(reset_val, context=context)

if clear_results:
self.parameters.results._set([], context)

@handle_external_context(fallback_most_recent=True)
def initialize(self, values=None, include_unspecified_nodes=True, context=None):
"""
Initializes the values of nodes within cycles. If `include_unspecified_nodes` is True and a value is
provided for a given node, the node will be initialized to that value. If `include_unspecified_nodes` is
True and a value is not provided, the node will be initialized to its default value. If
`include_unspecified_nodes` is False, then all nodes must have corresponding initialization values. The
`DEFAULT` keyword can be used in lieu of a numerical value to reset a node's value to its default.
"""Initialize the values of nodes within cycles.
If `include_unspecified_nodes` is True and a value is provided for a given node, the node is initialized to
that value. If `include_unspecified_nodes` is True and a value is not provided, the node is initialized to
its default value. If `include_unspecified_nodes` is False, then all nodes must have corresponding
initialization values. The `DEFAULT` keyword can be used in lieu of a numerical value to reset a node's value
to its default.

If a context is not provided, the most recent context under which the Composition has executed will be used.
If a context is not provided, the most recent context under which the Composition has executed is used.

Arguments
----------
values: Dict { Node: Node Value }
A dictionary contaning key-value pairs of Nodes and initialization values. Nodes within cycles that are
not included in this dict will be initialized to their default values.
Arguments
----------
values: Dict { Node: Node Value }
A dictionary containing key-value pairs of Nodes and initialization values. Nodes within cycles that are
not included in this dict are initialized to their default values.

include_unspecified_nodes: bool
Specifies whether all nodes within cycles should be initialized or only ones specified in the provided
values dictionary.
include_unspecified_nodes: bool
Specifies whether all nodes within cycles should be initialized or only ones specified in the provided
values dictionary.

context: Context
The context under which the nodes should be initialized. context will be set to
self.most_recent_execution_context if one is not specified.
context: Context
The context under which the nodes should be initialized. context are set to
self.most_recent_execution_context if one is not specified.

"""
# comp must be initialized from context before cycle values are initialized
Expand Down
10 changes: 10 additions & 0 deletions tests/composition/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7264,6 +7264,16 @@ def test_save_state_before_simulations(self):
np.testing.assert_allclose(np.asfarray(run_1_values), [[0.36], [0.056], [0.056]])
np.testing.assert_allclose(np.asfarray(run_2_values), [[0.5904], [0.16384], [0.16384]])

def test_reset_clear_results(self):
mech = ProcessingMechanism(name='mech')
comp = Composition(nodes=[mech])
comp.run(inputs={mech: 1})
assert comp.results == [[1]]
comp.reset()
assert comp.results == [[1]]
comp.reset(clear_results=True)
assert comp.results == []


class TestNodeRoles:

Expand Down

0 comments on commit 9ffdc58

Please sign in to comment.