
Commit

Merge branch 'refactor/learning_pathways_using_ports' of https://github.com/PrincetonUniversity/PsyNeuLink into refactor/learning_pathways_using_ports
jdcpni committed Oct 30, 2023
2 parents ac2abf4 + 81444d0 commit a143219
Showing 4 changed files with 139 additions and 42 deletions.
128 changes: 112 additions & 16 deletions psyneulink/core/compositions/composition.py
@@ -2879,6 +2879,7 @@ def input_function(env, result):
import functools
import inspect
import types
import numbers
import itertools
import logging
import sys
@@ -9806,7 +9807,7 @@ def _infer_target_nodes(self, targets: dict, execution_mode):
ret[node] = values
return ret

def _parse_learning_spec(self, inputs, targets, execution_mode):
def _parse_learning_spec(self, inputs, targets, execution_mode, context):
"""
Converts learning inputs and targets to a standardized form

@@ -9849,7 +9850,7 @@ def _recursive_update(d, u):

# 3) Resize inputs to be of the form [[[]]],
# where each level corresponds to: <TRIALS <PORTS <INPUTS> > >
inputs, num_inputs_sets = self._parse_input_dict(inputs)
inputs, num_inputs_sets = self._parse_input_dict(inputs, context)

return inputs, num_inputs_sets

@@ -9991,7 +9992,7 @@ def _validate_single_input(self, receiver, input):
_input = None
return _input

def _parse_input_dict(self, inputs):
def _parse_input_dict(self, inputs, context=None):
"""
Validate and parse a dict provided as input to a Composition into a standardized form to be used throughout
its execution
@@ -10005,14 +10006,20 @@ def _parse_input_dict(self, inputs):
Number of input sets (i.e., trials' worths of inputs) in dict for each input node in the Composition

"""

# MODIFIED 10/29/23 NEW:
if context and (context.runmode & ContextFlags.LEARNING_MODE) and (context.source & ContextFlags.COMPOSITION):
return inputs, 1
# MODIFIED 10/29/23 END

# parse a user-provided input dict to format it properly for execution.
# compute number of input sets and return that as well
_inputs = self._parse_names_in_inputs(inputs)
_inputs = self._parse_labels(_inputs)
self._validate_input_dict_keys(_inputs)
_inputs = self._instantiate_input_dict(_inputs)
_inputs = self._flatten_nested_dicts(_inputs)
_inputs = self._validate_input_shapes(_inputs)
_inputs = self._validate_input_shapes_and_expand_for_all_trials(_inputs)
num_inputs_sets = len(next(iter(_inputs.values()),[]))
return _inputs, num_inputs_sets

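For orientation, here is a minimal sketch (not part of the diff) of the <TRIALS <PORTS <INPUTS>>> standardization that _parse_input_dict and the new _instantiate_input_dict logic aim to enforce; the helper name, node variable, and values are hypothetical, and the real code also handles labels, nested Compositions, and ragged port shapes.

import numpy as np

def standardize_entry(entry, node_variable):
    """Lift a user-supplied entry to the 3D form <TRIALS <PORTS <INPUTS>>> (sketch only)."""
    if isinstance(entry, (int, float)):
        # single scalar: one trial, one port, one element
        return [[[entry]]]
    if all(isinstance(e, (int, float)) for e in entry):
        # list of scalars: either one trial's input for a multi-element port,
        # or several single-element trials
        if len(entry) == len(np.array(node_variable).squeeze(0)):
            return [[list(entry)]]             # one trial's input for one port
        return [[[e]] for e in entry]          # one single-element input per trial
    # already nested: assume trials x ports x elements
    return [np.atleast_2d(e).tolist() for e in entry]

# a hypothetical Mechanism with a single 2-element InputPort (variable [[0, 0]]):
print(standardize_entry(1.5, [[0]]))          # -> [[[1.5]]]
print(standardize_entry([1, 2], [[0, 0]]))    # -> [[[1, 2]]]  (one trial)
print(standardize_entry([1, 2, 3], [[0]]))    # -> [[[1]], [[2]], [[3]]]  (three trials)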
@@ -10216,9 +10223,72 @@ def _instantiate_input_dict(self, inputs):
input_dict[INPUT_Node] = INPUT_Node.external_input_shape
continue

# If entry is for an INPUT_Node of self, assign the entry directly to input_dict and proceed to next
if INPUT_Node in inputs:
input_dict[INPUT_Node] = inputs[INPUT_Node]
# If entry is for an INPUT_Node of self,
# check format, adjust as needed, assign the entry to input_dict, and proceed to next
# # MODIFIED 10/29/23 OLD:
# input_dict[INPUT_Node] = inputs[INPUT_Node]
# MODIFIED 10/29/23 NEW:
# FIX: 10/29/23
# ENFORCE get_input_format() spec for inputs here
# (e.g., impose outer dimension for single trial or single input_port)
add_dim = False
_inputs = inputs[INPUT_Node]
# Check formatting of first node as proxy for formatting of all items, and make changes accordingly
# (any other errant items will be detected in _validate_input_shapes_and_expand_for_all_trials())
if isinstance(_inputs, dict):
# entry is dict for a nested Composition, which will be handled recursively
pass
elif isinstance(_inputs, numbers.Number):
# Single scalar, so must be single value for single trial
_inputs = [[[_inputs]]]
elif all(isinstance(elem, numbers.Number) for elem in _inputs):
# List of scalars, so determine if it is one trial's input for array variable or array of trials
# (use squeeze to handle both a Mechanism and an InputPort)
if len(_inputs) == len(np.array(INPUT_Node.variable).squeeze(0)):
_inputs = [[_inputs]]
else:
_inputs = [np.atleast_2d(elem).tolist() for elem in _inputs]
else:
# FIX: IF inputs IS 3D, THEN THE FOLLOWING IS FINE,
# BUT IF IT IS 2D, MUST DETERMINE WHETHER
# IT IS A SINGLE TRIAL FOR N INPUT_PORTS OR N TRIALS FOR ONE INPUT_PORT
if len(_inputs) == len(np.array(INPUT_Node.variable).squeeze(0)):

# FIX: THE FOLLOWING ONLY WORKS IF INPUT IS 3D, OR HAVE ALREADY DETERMINED
# THAT IT IS FOR MULTIPLE TRIALS (3D) OR FOR ONLY ONE INPUT_PORT (2D)
entry = convert_to_np_array(_inputs[0])
if entry.dtype == object:
# FIX: NEED TO FIGURE OUT WHERE RAGGED ARRAY IS COMING FROM:
# IF AXIS 1, THEN:
# Entry is ragged, so at some level must be specification for input_ports of NODE
if entry.ndim == 1:
# entry is ragged so must be a list of inputs to the ports of INPUT_Node for a single trial
item = entry
add_dim = True
elif entry.ndim == 2:
# entry itself is not ragged, so must be multiple trials, but input for each must be ragged
item = entry[0]
else:
raise CompositionError(f"BAD ENTRY") # FIX: MSG -> TOO MANY DIMENSIONS TO ENTRY IN INPUT
for i, input_port in zip(item, INPUT_Node.input_ports):
# Make sure each item of the entry is a 1d vector the size of the corresponding input_port
if not np.array(i).ndim == 1 and len(i) == len(input_port.defaults.value):
raise CompositionError(f"BAD ENTRY") # FIX: MSG -> MISMATCH BETWEEN ENTRY AND INPUT_PORT

else:
# entry is a regular array, so if it is 3d we're done
if entry.ndim == 1:
# enforce 3d on each entry
# (validity of shape will be determined in _validate_input_shapes_and_expand_for_all_trials())
add_dim = True

if add_dim:
_inputs = [[input] for input in _inputs]

input_dict[INPUT_Node] = _inputs
# MODIFIED 10/29/23 END

remaining_inputs.remove(INPUT_Node)
continue

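The dtype == object check above is what distinguishes "one trial of input for multiple InputPorts" (a ragged, per-port specification) from "multiple trials for a single InputPort" (a regular 2d array). A small, self-contained illustration with hypothetical port sizes (outside the diff):

import numpy as np

# A node with two InputPorts of different lengths (sizes 3 and 1) makes any
# per-port specification ragged, so NumPy can only store it as dtype=object.
one_trial_two_ports = np.array([[0.1, 0.2, 0.3], [0.4]], dtype=object)
print(one_trial_two_ports.dtype, one_trial_two_ports.ndim)   # object 1 -> per-port spec for a single trial

# Two trials for a single 3-element InputPort is a regular 2D array.
two_trials_one_port = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
print(two_trials_one_port.dtype, two_trials_one_port.ndim)   # float64 2 -> needs an added (port) dimension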
@@ -10384,7 +10454,7 @@ def _flatten_nested_dicts(self, inputs):
_inputs.update({node:inp})
return _inputs

def _validate_input_shapes(self, inputs):
def _validate_input_shapes_and_expand_for_all_trials(self, inputs):
"""
Validates that all inputs provided in input dict are valid

@@ -10411,9 +10481,7 @@ def _validate_input_shapes(self, inputs):
# see if the entire stimulus set provided is a valid input for the receiver
# (i.e. in the case of a call with a single trial of provided input)
_input = self._validate_single_input(receiver, stimulus)
if _input is not None:
_input = [_input]
else:
if _input is None:
# if _input is None, it may mean there are multiple trials of input in the stimulus set,
# so in list comprehension below loop through and validate each individual input;
_input = [self._validate_single_input(receiver, single_trial_input) for single_trial_input in stimulus]
@@ -10438,6 +10506,12 @@ def _validate_input_shapes(self, inputs):
"(or other values) to represent the outside stimulus for " \
"the inhibition InputPort, and for Compositions, put your inputs"
raise RunError(err_msg)
else:
# # MODIFIED 10/29/23 OLD:
# _input = [_input]
# MODIFIED 10/29/23 NEW:
assert True
# MODIFIED 10/29/23 END
_inputs[receiver] = _input
input_length = len(_input)
if input_length == 1:
@@ -10500,9 +10574,10 @@ def _parse_run_inputs(self, inputs, context=None):

return _inputs, num_inputs_sets

def _parse_trial_inputs(self, inputs, trial_num):
def _parse_trial_inputs(self, inputs, trial_num, context):
"""
Extracts inputs for a single trial and parses it in accordance with its type
Note: this method is intended to run BEFORE a call to Composition.execute

Returns
-------
@@ -10511,11 +10586,31 @@ def _parse_trial_inputs(self, inputs, trial_num):
Input dict parsed for a single trial of a Composition's execution

"""

# parse and return a single trial's worth of inputs.
# this method is intended to run BEFORE a call to Composition.execute
# # MODIFIED 10/29/23 OLD:
# if callable(inputs):
# try:
# inputs, _ = self._parse_input_dict(inputs(trial_num), context)
# i = 0
# except TypeError as e:
# error_text = e.args[0]
# if f" takes 0 positional arguments but 1 was given" in error_text:
# raise CompositionError(f"{error_text}: requires arg for trial number")
# else:
# raise CompositionError(f"Problem with function provided to 'inputs' arg of {self.name}.run")
# elif isgenerator(inputs):
# inputs, _ = self._parse_input_dict(inputs.__next__(), context)
# i = 0
# else:
# num_inputs_sets = len(next(iter(inputs.values())))
# i = trial_num % num_inputs_sets
# next_inputs = {node:inp[i] for node, inp in inputs.items()}
# next_inputs = inputs
# MODIFIED 10/29/23 NEW:
if callable(inputs):
try:
inputs, _ = self._parse_input_dict(inputs(trial_num))
next_inputs, _ = self._parse_input_dict(inputs(trial_num), context)
i = 0
except TypeError as e:
error_text = e.args[0]
@@ -10524,12 +10619,13 @@ def _parse_trial_inputs(self, inputs, trial_num):
else:
raise CompositionError(f"Problem with function provided to 'inputs' arg of {self.name}.run")
elif isgenerator(inputs):
inputs, _ = self._parse_input_dict(inputs.__next__())
next_inputs, _ = self._parse_input_dict(inputs.__next__(), context)
i = 0
else:
num_inputs_sets = len(next(iter(inputs.values())))
i = trial_num % num_inputs_sets
next_inputs = {node:inp[i] for node, inp in inputs.items()}
next_inputs = {node:inp[i] for node, inp in inputs.items()}
# MODIFIED 10/29/23 END
return next_inputs

def _validate_execution_inputs(self, inputs):
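A condensed sketch (outside the diff, names hypothetical) of the three input styles _parse_trial_inputs dispatches on — a callable of the trial number, a generator, or a pre-parsed dict cycled by trial index:

def get_trial_inputs(inputs, trial_num):
    # Sketch of the dispatch in _parse_trial_inputs (context handling and error messages omitted).
    if callable(inputs):
        return inputs(trial_num)                      # function of the trial number
    if hasattr(inputs, "__next__"):
        return next(inputs)                           # generator yielding one trial at a time
    num_inputs_sets = len(next(iter(inputs.values())))
    i = trial_num % num_inputs_sets                   # cycle through the supplied trials
    return {node: inp[i] for node, inp in inputs.items()}

# dict with two trials of input for a hypothetical node "mech"; trial 2 wraps around to trial 0
inputs = {"mech": [[[0.0]], [[1.0]]]}
print(get_trial_inputs(inputs, 2))                    # -> {'mech': [[0.0]]}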
@@ -11092,7 +11188,7 @@ def run(
# PROCESSING ------------------------------------------------------------------------
# Prepare stimuli from the outside world -- collect the inputs for this TRIAL and store them in a dict
try:
execution_stimuli = self._parse_trial_inputs(inputs, trial_num)
execution_stimuli = self._parse_trial_inputs(inputs, trial_num, context)
except StopIteration:
break

46 changes: 22 additions & 24 deletions psyneulink/library/compositions/autodiffcomposition.py
@@ -664,9 +664,15 @@ def create_pathway(node)->list:
# target_mechs = [ProcessingMechanism(default_variable = np.zeros_like(mech.value),
# name= 'TARGET for ' + mech.name)
# for mech in output_mechs if mech not in self.target_output_map.values()]
# MODIFIED 10/27/23 NEW:
# # MODIFIED 10/27/23 NEW:
# target_mechs = [ProcessingMechanism(default_variable = np.array([np.zeros_like(value)
# for value in output_mechs[0].value],
# dtype=object),
# name= 'TARGET for ' + mech.name)
# for mech in output_mechs if mech not in self.target_output_map.values()]
# # MODIFIED 10/27/23 NEWER:
target_mechs = [ProcessingMechanism(default_variable = np.array([np.zeros_like(value)
for value in output_mechs[0].value],
for value in mech.value],
dtype=object),
name= 'TARGET for ' + mech.name)
for mech in output_mechs if mech not in self.target_output_map.values()]
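The retained comprehension sizes each TARGET mechanism from its own output mechanism's value, one zeroed template per OutputPort. A small illustration of the template it builds (the mechanism value here is hypothetical):

import numpy as np

# a hypothetical OUTPUT mechanism value with two OutputPorts of different sizes
mech_value = [np.array([0.1, 0.2, 0.3]), np.array([0.4])]

# template for the corresponding TARGET mechanism's default_variable, as in the new comprehension
target_template = np.array([np.zeros_like(value) for value in mech_value], dtype=object)
print([v.tolist() for v in target_template])   # -> [[0.0, 0.0, 0.0], [0.0]]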
@@ -777,39 +783,30 @@ def autodiff_training(self, inputs, targets, context=None, scheduler=None):
curr_tensor_targets = {}
for component in inputs.keys():
if isinstance(component, Mechanism_Base):
# FIX 10/1/23: SHOULD REALLY CYCLE THROUGH INPUT PORTS FOR A MECHANISM
# FIX 10/1/23 f/u: 10/29/23: SHOULD REALLY CYCLE THROUGH INPUT PORTS FOR A MECHANISM
# RATHER THAN JUST ASSUMING ONE INPUT AND USING [0]
input = inputs[component][0]
# # MODIFIED 10/29/23 OLD:
# input = inputs[component][0]
# MODIFIED 10/29/23 NEW:
input = inputs[component]
# MODIFIED 10/29/23 END
else:
input = inputs[component]
curr_tensor_inputs[component] = torch.tensor(input, device=self.device).double()
for component in targets.keys():
# # MODIFIED 10/27/23 OLD:
# curr_tensor_targets[self.target_output_map[component]] = [torch.tensor(target, device=self.device).double()
# for target in targets[component]]
# MODIFIED 10/27/23 NEW:
terminal_node = self.target_output_map[component]
curr_tensor_targets[terminal_node] = [torch.tensor(target_port_input,
device=self.device).double()
for target in targets[component] for
target_port_input in target]
# MODIFIED 10/27/23 END
curr_tensor_targets[self.target_output_map[component]] = [torch.tensor(target, device=self.device).double()
for target in targets[component]]

# do forward computation on current inputs
curr_tensor_outputs = self.parameters.pytorch_representation._get(context).forward(curr_tensor_inputs, context)

for component in curr_tensor_outputs.keys():
# possibly add custom loss option, which is a loss function that takes many args
# (outputs, targets, weights, and more) and returns a scalar
# # MODIFIED 10/27/23 OLD:
# new_loss = self.loss(curr_tensor_outputs[component][0],
# curr_tensor_targets[component][0])
# MODIFIED 10/27/23 NEW:
# FIX: 10/26/23 - SHOULD HANDLE MULTIPLE OUTPUT PORTS curr_tensor_outputs[component][idx] as per below
new_loss = 0
for i in range(len(curr_tensor_outputs[component])):
new_loss = self.loss(curr_tensor_outputs[component][i],
new_loss += self.loss(curr_tensor_outputs[component][i],
curr_tensor_targets[component][i])
# MODIFIED 10/27/23 END
tracked_loss += new_loss

outputs = []
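The change above accumulates the per-port losses (new_loss += ...) instead of overwriting new_loss on each iteration, so every output port of a node contributes to tracked_loss. A minimal torch sketch of that accumulation, with hypothetical tensors and an MSE loss standing in for self.loss:

import torch

loss_fn = torch.nn.MSELoss(reduction='mean')

# hypothetical outputs and targets for a node with two output ports
outputs = [torch.tensor([0.2, 0.8]), torch.tensor([0.5])]
targets = [torch.tensor([0.0, 1.0]), torch.tensor([1.0])]

new_loss = 0
for i in range(len(outputs)):
    new_loss += loss_fn(outputs[i], targets[i])   # accumulate, rather than overwrite, per port
print(new_loss.item())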
@@ -863,7 +860,8 @@ def _infer_output_nodes(self, nodes: dict):
---------
A dict mapping TARGET Nodes -> target values
"""
return {node:value for node,value in nodes.items() if node in self.target_output_map}
# 10/29/23: FIX - VALUES SHOULD BE 2D HERE
return {node:value for node, value in nodes.items() if node in self.target_output_map}

def _infer_input_nodes(self, nodes: dict):
"""Remove TARGET Nodes, and return dict with values of INPUT Nodes for single trial
@@ -890,8 +888,8 @@ def _infer_input_nodes(self, nodes: dict):
input_nodes[node] = values
return input_nodes

def _parse_learning_spec(self, inputs, targets, execution_mode):
stim_input, num_input_trials = super()._parse_learning_spec(inputs, targets, execution_mode)
def _parse_learning_spec(self, inputs, targets, execution_mode, context):
stim_input, num_input_trials = super()._parse_learning_spec(inputs, targets, execution_mode, context)

if not callable(inputs):
input_ports_for_INPUT_Nodes = self._get_input_receivers()
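For reference, the _infer_output_nodes / _infer_input_nodes pair above amounts to partitioning one trial's dict by membership in target_output_map. A toy version with hypothetical node names (the real _infer_input_nodes includes additional handling not shown here):

# hypothetical map from TARGET node -> OUTPUT node, and one trial's worth of values
target_output_map = {"TARGET for out_mech": "out_mech"}
trial_dict = {"in_mech": [[0.0, 1.0]], "TARGET for out_mech": [[1.0]]}

target_values = {n: v for n, v in trial_dict.items() if n in target_output_map}      # as in _infer_output_nodes
input_values  = {n: v for n, v in trial_dict.items() if n not in target_output_map}  # roughly _infer_input_nodes
print(target_values)   # {'TARGET for out_mech': [[1.0]]}
print(input_values)    # {'in_mech': [[0.0, 1.0]]}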
6 changes: 4 additions & 2 deletions psyneulink/library/compositions/compositionrunner.py
@@ -121,7 +121,8 @@ def _batch_function_inputs(self,
try:
trial_input, _ = self._composition._parse_learning_spec(inputs=inputs(idx),
targets=None,
execution_mode=execution_mode)
execution_mode=execution_mode,
context=context)
except:
break
if trial_input is None:
@@ -223,7 +224,8 @@ def run_learning(self,

stim_input, num_input_trials = self._composition._parse_learning_spec(inputs=stim_input,
targets=stim_target,
execution_mode=execution_mode)
execution_mode=execution_mode,
context=context)
if num_trials is None:
num_trials = num_input_trials

1 change: 1 addition & 0 deletions psyneulink/library/compositions/emcomposition.py
@@ -129,6 +129,7 @@
# - Add warning of this on initial call to learn()
#
# - Composition:
# - _validate_input_shapes_and_expand_for_all_trials: consolidate with get_input_format()
# - Generalize treatment of FEEDBACK specification:
# - FIX: ADD TESTS FOR FEEDBACK TUPLE SPECIFICATION OF Projection, DIRECT SPECIFICATION IN CONSTRUCTOR
# - FIX: why aren't FEEDBACK_SENDER and FEEDBACK_RECEIVER roles being assigned when feedback is specified?
