Skip to content

Commit

Permalink
[skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcpni committed Oct 29, 2023
1 parent 6b72ef2 commit 251d124
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions psyneulink/library/compositions/autodiffcomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,9 +664,15 @@ def create_pathway(node)->list:
# target_mechs = [ProcessingMechanism(default_variable = np.zeros_like(mech.value),
# name= 'TARGET for ' + mech.name)
# for mech in output_mechs if mech not in self.target_output_map.values()]
# MODIFIED 10/27/23 NEW:
# # MODIFIED 10/27/23 NEW:
# target_mechs = [ProcessingMechanism(default_variable = np.array([np.zeros_like(value)
# for value in output_mechs[0].value],
# dtype=object),
# name= 'TARGET for ' + mech.name)
# for mech in output_mechs if mech not in self.target_output_map.values()]
# # MODIFIED 10/27/23 NEWER:
target_mechs = [ProcessingMechanism(default_variable = np.array([np.zeros_like(value)
for value in output_mechs[0].value],
for value in mech.value],
dtype=object),
name= 'TARGET for ' + mech.name)
for mech in output_mechs if mech not in self.target_output_map.values()]
Expand Down Expand Up @@ -784,15 +790,30 @@ def autodiff_training(self, inputs, targets, context=None, scheduler=None):
input = inputs[component]
curr_tensor_inputs[component] = torch.tensor(input, device=self.device).double()
for component in targets.keys():
# # MODIFIED 10/27/23 OLD:
# FIX: 10/28/23: BRANCH ON WHETHER TARGET ARRAY IS REGULAR OR RAGGED?
# # MODIFIED 10/27/23 OLD: WORKS FOR TESTS BUT NOT SCRATCH PAD BECAUSE IT FORCES A RAGGED ARRAY INTO A TENSOR
# curr_tensor_targets[self.target_output_map[component]] = [torch.tensor(target, device=self.device).double()
# for target in targets[component]]
# MODIFIED 10/27/23 NEW:
# # # MODIFIED 10/27/23 NEW:
# terminal_node = self.target_output_map[component]
# curr_tensor_targets[terminal_node] = [torch.tensor(target_port_input,
# device=self.device).double()
# for target in targets[component] for
# target_port_input in target]
# # # MODIFIED 10/27/23 NEWER: WORKS FOR SCRATCH PAD BUT NOT TESTS BECAUSE IT REDUCES ONE DIMENSION TOO FAR
# # AND PRODUCES BAD RESULT FOR TESTS BECAUSE IT ADDS A DIMENSION TO THE TENSORS
# terminal_node = self.target_output_map[component]
# curr_tensor_targets[terminal_node] = [torch.tensor([target_port_input],
# device=self.device).double()
# for target in targets[component] for
# target_port_input in target]
# # # MODIFIED 10/27/23 NEWEST:
terminal_node = self.target_output_map[component]
curr_tensor_targets[terminal_node] = [torch.tensor(target_port_input,
device=self.device).double()
for target in targets[component] for
target_port_input in target]
curr_tensor_target = []
for target in targets[component]:
for target_port_input in target:
curr_tensor_target.append(torch.tensor(target_port_input, device=self.device).double())
curr_tensor_targets[terminal_node] = curr_tensor_target
# MODIFIED 10/27/23 END

# do forward computation on current inputs
Expand All @@ -805,9 +826,9 @@ def autodiff_training(self, inputs, targets, context=None, scheduler=None):
# new_loss = self.loss(curr_tensor_outputs[component][0],
# curr_tensor_targets[component][0])
# MODIFIED 10/27/23 NEW:
# FIX: 10/26/23 - SHOULD HANDLE MULTIPLE OUTPUT PORTS curr_tensor_outputs[component][idx] as per below
new_loss = 0
for i in range(len(curr_tensor_outputs[component])):
new_loss = self.loss(curr_tensor_outputs[component][i],
new_loss += self.loss(curr_tensor_outputs[component][i],
curr_tensor_targets[component][i])
# MODIFIED 10/27/23 END
tracked_loss += new_loss
Expand Down Expand Up @@ -863,7 +884,7 @@ def _infer_output_nodes(self, nodes: dict):
---------
A dict mapping TARGET Nodes -> target values
"""
return {node:value for node,value in nodes.items() if node in self.target_output_map}
return {node:value for node, value in nodes.items() if node in self.target_output_map}

def _infer_input_nodes(self, nodes: dict):
"""Remove TARGET Nodes, and return dict with values of INPUT Nodes for single trial
Expand Down

0 comments on commit 251d124

Please sign in to comment.