
Commit

autodiffcomposition.py, pytorchwrappers.py: revise docstrings re: relationship to Pytorch modules (#3157)

jdcpni authored Jan 8, 2025
1 parent 77e4632 commit 047401e
Showing 2 changed files with 38 additions and 21 deletions.
14 changes: 7 additions & 7 deletions psyneulink/library/compositions/autodiffcomposition.py
@@ -102,7 +102,7 @@
 AutodiffComposition does not (currently) support the *automatic* construction of separate bias parameters.
 Thus, when constructing the PyTorch version of an AutodiffComposition, the `bias
-<https://www.pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ parameter of PyTorch modules are set to False.
+<https://www.pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ parameter of any PyTorch module is set to False.
 However, biases can be implemented using `Composition_Bias_Nodes`.
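The change above concerns the torch `bias` parameter being set to False when the PyTorch version of an AutodiffComposition is constructed. A minimal, torch-free sketch of what that means for the resulting linear map (the `linear` helper below is illustrative, not part of the PsyNeuLink or PyTorch APIs):

```python
# Plain-Python sketch of the linear map built for each Projection when
# bias=False: y = x @ W, with no additive bias term.
def linear(x, W, bias=None):
    """Return y = x @ W, adding `bias` only if one is given."""
    y = [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*W)]
    if bias is not None:  # bias=None mirrors constructing the torch module with bias=False
        y = [yi + bi for yi, bi in zip(y, bias)]
    return y

print(linear([1.0, 2.0], [[1.0, 0.0], [0.0, 1.0]]))  # -> [1.0, 2.0]
```

Biases, if needed, would instead be implemented via `Composition_Bias_Nodes`, as the docstring notes.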
@@ -531,7 +531,7 @@ class AutodiffComposition(Composition):
 but slows performance (see `AutodiffComposition_PyTorch_LearningScale` for information about settings).
 synch_node_variables_with_torch : OPTIMIZATION_STEP, TRIAL, MINIBATCH, EPOCH, RUN or None
-determines when to copy the current input to Pytorch nodes (modules) to the PsyNeuLink `variable
+determines when to copy the current input to Pytorch functions to the PsyNeuLink `variable
 <Mechanism_Base.value>` attribute of the corresponding PsyNeuLink `nodes <Composition_Node>`, if this is not
 specified in the call to `learn <AutodiffComposition.learn>`.
 COMMENT:
@@ -545,11 +545,11 @@ class AutodiffComposition(Composition):
 but slows performance (see `AutodiffComposition_PyTorch_LearningScale` for information about settings).
 synch_node_values_with_torch : OPTIMIZATION_STEP, MINIBATCH, EPOCH or RUN
-determines when to copy the current output of Pytorch nodes (modules) to the PsyNeuLink `value
-<Mechanism_Base.value>` attribute of the corresponding PsyNeuLink `nodes <Composition_Node>`, if this is not
-specified in the call to `learn <AutodiffComposition.learn>`. Copying more frequently keeps the PsyNeuLink
-representation more closely copying more frequently keeps them synchronized with parameter updates in Pytorch,
-but slows performance (see `AutodiffComposition_PyTorch_LearningScale` for information about settings).
+determines when to copy the current output of Pytorch functions to the PsyNeuLink `value <Mechanism_Base.value>`
+attribute of the corresponding PsyNeuLink `nodes <Composition_Node>`, if this is not specified in the call to
+`learn <AutodiffComposition.learn>`. Copying more frequently keeps the PsyNeuLink representation more closely
+synchronized with parameter updates in Pytorch, but slows performance
+(see `AutodiffComposition_PyTorch_LearningScale` for information about settings).
 synch_results_with_torch : OPTIMIZATION_STEP, TRIAL, MINIBATCH, EPOCH or RUN
 determines when to copy the current outputs of Pytorch nodes to the PsyNeuLink `results
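The `synch_*_with_torch` options documented above all control how often torch-side values are copied back to the PsyNeuLink representation. A hedged sketch of that granularity trade-off (the constants and loop structure below are illustrative, not the actual PsyNeuLink API):

```python
# Count how many torch -> PsyNeuLink copies occur at each granularity;
# more frequent copying keeps the two representations more tightly
# synchronized but slows performance.
MINIBATCH, EPOCH, RUN = "minibatch", "epoch", "run"

def copies_per_run(n_epochs, n_minibatches, synch_at):
    copies = 0
    for _epoch in range(n_epochs):
        for _mb in range(n_minibatches):
            # ... a torch optimization step would happen here ...
            if synch_at == MINIBATCH:
                copies += 1   # copy after every minibatch
        if synch_at == EPOCH:
            copies += 1       # copy once per epoch
    if synch_at == RUN:
        copies += 1           # copy only at the end of the run
    return copies

print(copies_per_run(3, 4, MINIBATCH), copies_per_run(3, 4, EPOCH))  # -> 12 3
```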
45 changes: 31 additions & 14 deletions psyneulink/library/compositions/pytorchwrappers.py
@@ -58,12 +58,25 @@ class PytorchCompositionWrapper(torch.nn.Module):
 # # MODIFIED 7/29/24 NEW: NEEDED FOR torch MPS SUPPORT
 # class PytorchCompositionWrapper(torch.jit.ScriptModule):
 # MODIFIED 7/29/24 END
-"""Wrapper for a Composition as a Pytorch Module
-Class that wraps a `Composition <Composition>` as a PyTorch module.
+"""Wrapper for a Composition as a Pytorch Module.
+Wraps an `AutodiffComposition` as a `PyTorch module
+<https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_, with each `Mechanism <Mechanism>` in the
+AutodiffComposition wrapped as a `PytorchMechanismWrapper`, each `Projection <Projection>` wrapped as a
+`PytorchProjectionWrapper`, and any nested Compositions wrapped as `PytorchCompositionWrapper`\\s. Each
+PytorchMechanismWrapper implements a Pytorch version of the `function(s) <Mechanism_Base.function>` of the wrapped
+`Mechanism`, which are executed in the PytorchCompositionWrapper's `forward <PytorchCompositionWrapper.forward>`
+method in the order specified by the AutodiffComposition's `scheduler <Composition.scheduler>`. The
+`matrix <MappingProjection.matrix>` Parameters of each wrapped `Projection` are assigned as parameters of the
+PytorchCompositionWrapper Pytorch module and used, together with a Pytorch `matmul
+<https://pytorch.org/docs/main/generated/torch.matmul.html>`_ operation, to generate the input to each
+PyTorch function as specified by the AutodiffComposition's `graph <Composition.graph>`. The graph
+can be visualized using the AutodiffComposition's `show_graph <ShowGraph.show_graph>` method and setting its
+*show_pytorch* argument to True (see `PytorchShowGraph` for additional information).
 Two main responsibilities:
-1) Set up parameters of PyTorch model & information required for forward computation:
+1) Set up functions and parameters of the PyTorch module required for its forward computation:
 Handle nested compositions (flattened in infer_backpropagation_learning_pathways):
 Deal with Projections into and/or out of a nested Composition as shown in figure below:
 (note: Projections in outer Composition to/from a nested Composition's CIMs are learnable,
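The forward flow the revised docstring describes — each Projection's matrix applied by a matmul to produce the input to a node's function, with functions executed in scheduler order — can be sketched in plain Python. All names below are hypothetical; the real wrapper operates on torch tensors with `torch.matmul`:

```python
# Plain-Python stand-in for the wrapper's forward pass: for each
# (matrix, node_function) pair in scheduler order, apply the projection's
# matrix (a matmul) and then the node's function elementwise.
def matmul(x, W):
    return [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*W)]

def forward(x, layers):
    """layers: list of (matrix, node_function) pairs, in scheduler order."""
    for W, fn in layers:
        x = [fn(v) for v in matmul(x, W)]  # projection matmul, then node function
    return x

relu = lambda v: max(0.0, v)
print(forward([1.0, -1.0], [([[2.0, 0.0], [0.0, 2.0]], relu)]))  # -> [2.0, 0.0]
```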
@@ -115,12 +128,14 @@ class PytorchCompositionWrapper(torch.nn.Module):
 `AutodiffComposition` being wrapped.
 wrapped_nodes : List[PytorchMechanismWrapper]
-list of nodes in the PytorchCompositionWrapper corresponding to PyTorch modules. Generally these are
-`Mechanisms <Mechanism>` wrapped in a `PytorchMechanismWrapper`, however, if the `AutodiffComposition`
-being wrapped is itself a nested Composition, then the wrapped nodes are `PytorchCompositionWrapper` objects.
-When the PyTorch model is executed these are "flattened" into a single PyTorch module, which can be visualized
-using the AutodiffComposition's `show_graph <ShowGraph.show_graph>` method and setting its *show_pytorch*
-argument to True (see `PytorchShowGraph` for additional information).
+list of nodes in the PytorchCompositionWrapper corresponding to the PyTorch functions that comprise the
+forward method of the Pytorch module implemented by the PytorchCompositionWrapper. Generally these are
+`Mechanisms <Mechanism>` wrapped in a `PytorchMechanismWrapper`; however, if a Node of the `AutodiffComposition`
+being wrapped is a nested Composition, then the wrapped node is itself a `PytorchCompositionWrapper` object.
+When the PyTorch model is executed, all of these are "flattened" into a single PyTorch module, corresponding
+to the outermost AutodiffComposition being wrapped, which can be visualized using that AutodiffComposition's
+`show_graph <ShowGraph.show_graph>` method and setting its *show_pytorch* argument to True (see
+`PytorchShowGraph` for additional information).
 nodes_map : Dict[Node: PytorchMechanismWrapper or PytorchCompositionWrapper]
 maps psyneulink `Nodes <Composition_Nodes>` to PytorchCompositionWrapper nodes.
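The "flattening" of nested wrappers mentioned above can be sketched as follows. This is pure illustration: per the docstring, the actual flattening happens in `infer_backpropagation_learning_pathways` over wrapper objects, not Python lists:

```python
# Illustrative sketch of flattening: nested PytorchCompositionWrapper nodes
# (stood in for here by nested lists) are expanded in place, so execution
# sees a single flat sequence of mechanism-level nodes.
def flatten(wrapped_nodes):
    flat = []
    for node in wrapped_nodes:
        if isinstance(node, list):   # stands in for a nested CompositionWrapper
            flat.extend(flatten(node))
        else:                        # stands in for a PytorchMechanismWrapper
            flat.append(node)
    return flat

print(flatten(["A", ["B", ["C"]], "D"]))  # -> ['A', 'B', 'C', 'D']
```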
@@ -140,7 +155,7 @@ class PytorchCompositionWrapper(torch.nn.Module):
 assigned by AutodiffComposition after the wrapper is created, which passes the parameters to the optimizer
 device : torch.device
-device used to process torch Tensors in PyTorch modules
+device used to process torch Tensors in PyTorch functions
 params : nn.ParameterList()
 list of PyTorch parameters (connection weight matrices) in the PyTorch model.
@@ -857,7 +872,7 @@ def detach_all(self):

 class PytorchMechanismWrapper():
 """Wrapper for a Mechanism in a PytorchCompositionWrapper
-These comprise nodes of the PytorchCompositionWrapper, and generally correspond to modules of a Pytorch model.
+These comprise nodes of the PytorchCompositionWrapper, and generally correspond to functions in a Pytorch model.
 Attributes
 ----------
@@ -1128,9 +1143,11 @@ def __repr__(self):
 class PytorchProjectionWrapper():
 """Wrapper for Projection in a PytorchCompositionWrapper
-The matrix of the wrapped `_projection <PytorchProjectionWrapper._projection>` corresponds to the parameters
-(connection weights) of the PyTorch Module that is the `function <Mechanism_Base.function>` of the
-`receiver <Projection_Base.receiver>` of the wrapped Projection.
+The matrix of the wrapped `_projection <PytorchProjectionWrapper._projection>` is assigned as a parameter
+(set of connection weights) of the PyTorch Module that, coupled with a corresponding input and a `torch.matmul
+<https://pytorch.org/docs/main/generated/torch.matmul.html>`_ operation, provides the input to the Pytorch
+function associated with the `Node <Composition_Node>` of the AutodiffComposition that is the `receiver
+<Projection_Base.receiver>` of the wrapped Projection.
 .. note::
 In the case of a nested Composition, the sender and/or receiver attributes may be mapped to different Node(s)
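The projection-wrapper description above — the Projection's matrix held as a weight parameter that, applied by matmul to the sender's output, supplies the receiver's input — can be sketched in plain Python. Summing multiple afferent projections is an assumption made here for illustration, and all names are hypothetical:

```python
# Sketch of how projection matrices produce a receiver node's input:
# each (sender_output, matrix) pair contributes matmul(sender_output, matrix),
# and contributions from multiple afferent projections are summed.
def matmul(x, W):
    return [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*W)]

def receiver_input(projections):
    """projections: list of (sender_output, matrix) pairs into one receiver."""
    totals = None
    for x, W in projections:
        y = matmul(x, W)
        totals = y if totals is None else [a + b for a, b in zip(totals, y)]
    return totals

print(receiver_input([([1.0, 1.0], [[1.0], [1.0]]),
                      ([2.0], [[3.0]])]))  # -> [8.0]
```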
