From f7416d215c9d3bf62dd84855958204776e37f9b2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 18 Dec 2024 10:33:11 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/nn/probabilistic.py | 8 ++--- tensordict/nn/sequence.py | 62 ++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 75c326c7a..4637dd274 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -9,7 +9,7 @@ import warnings from textwrap import indent -from typing import Any, Dict, List, Optional, overload, OrderedDict +from typing import Any, Dict, List, Optional, OrderedDict, overload import torch @@ -800,8 +800,7 @@ def __init__( aggregate_probabilities: bool | None = None, include_sum: bool | None = None, inplace: bool | None = None, - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -812,8 +811,7 @@ def __init__( aggregate_probabilities: bool | None = None, include_sum: bool | None = None, inplace: bool | None = None, - ) -> None: - ... + ) -> None: ... def __init__( self, diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index 3f20b4195..6f8c9ee5c 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -6,6 +6,7 @@ from __future__ import annotations import collections +import contextlib import logging from copy import deepcopy from typing import Any, Callable, Iterable, List, OrderedDict, overload @@ -574,3 +575,64 @@ def __setitem__( def __delitem__(self, index: int | slice | str) -> None: self.module.__delitem__(idx=index) + + def plot(self, example_input: TensorDictBase | None = None, **kwargs): + import pydot + + graph = pydot.Dot( + "my_graph", graph_type="digraph", bgcolor="yellow", splines="curved" + ) + graph.set_bgcolor("white") + + if example_input is not None: + from torch._subclasses.fake_tensor import FakeTensorMode + + fake_mode = FakeTensorMode() + converter = fake_mode.fake_tensor_converter + fake_td = example_input.apply( + lambda x: converter.from_real_tensor(fake_mode, x) + ) + else: + fake_td = None + fake_mode = contextlib.nullcontext() + + with fake_mode: + iterator = ( + enumerate(self._module_iter()) + if not isinstance(self.module, nn.ModuleDict) + else self.module.items() + ) + for name, module in iterator: + graph.add_node( + pydot.Node(str(name), shape="box") + ) # label=str(node.module))) + + # Check if in_keys are there already + in_keys = module.in_keys + for in_key in in_keys: + if in_key not in graph.obj_dict["nodes"]: + in_key_node = pydot.Node( + in_key, label=in_key, shape="plaintext" + ) + graph.add_node(in_key_node) + in_key_edge = pydot.Edge( + in_key, str(name), color="blue", style="arrow" + ) + graph.add_edge(in_key_edge) + + if not isinstance(module, TensorDictModule): + fake_td = self._run_module(module, fake_td, **kwargs) + + out_keys = module.out_keys + for out_key in out_keys: + if out_key not in graph.obj_dict["nodes"]: + out_key_node = pydot.Node( + out_key, label=out_key, shape="plaintext" + ) + graph.add_node(out_key_node) + out_key_edge = pydot.Edge( + str(name), out_key, color="blue", style="arrow" + ) + graph.add_edge(out_key_edge) + + graph.write_png("/Users/vmoens/Downloads/my_graph.png")