Skip to content

Commit

Permalink
[PT] Introduce GraphBuilderMode (#3029)
Browse files Browse the repository at this point in the history
### Changes

Add GraphBuilderMode to build graph of pytorch model based on
FunctionHookMode

### Related tickets

154914

### Tests

tests/torch2/function_hook/graph/test_build_graph_mode.py
tests/torch2/function_hook/graph/test_graph_visualisation.py
  • Loading branch information
AlexanderDokuchaev authored Oct 28, 2024
1 parent 6afb13d commit 94f1006
Show file tree
Hide file tree
Showing 19 changed files with 1,236 additions and 55 deletions.
1 change: 1 addition & 0 deletions nncf/experimental/torch2/function_hook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph as build_graph
from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage as get_hook_storage
from nncf.experimental.torch2.function_hook.wrapper import is_wrapped as is_wrapped
from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook as register_post_function_hook
Expand Down
10 changes: 10 additions & 0 deletions nncf/experimental/torch2/function_hook/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
369 changes: 369 additions & 0 deletions nncf/experimental/torch2/function_hook/graph/build_graph_mode.py

Large diffs are not rendered by default.

102 changes: 102 additions & 0 deletions nncf/experimental/torch2/function_hook/graph/graph_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple

import torch


class NodeType(Enum):
const = "const"
fn_call = "function_call"
input = "input"
output = "output"

def __str__(self) -> str:
return self.value


class TensorSource(Enum):
buffer = "buffer"
function = "function"
input = "input"
output = "output"
parameter = "parameter"

def __str__(self) -> str:
return self.value


@dataclass
class TensorMeta:
dtype: torch.dtype
shape: Tuple[int, ...]
requires_grad: bool

@staticmethod
def from_tensor(tensor: torch.Tensor) -> TensorMeta:
return TensorMeta(tensor.dtype, tuple(tensor.shape), tensor.requires_grad)


@dataclass
class ConstMeta:
dtype: torch.dtype
shape: Tuple[int, ...]
name_in_model: str

@staticmethod
def from_tensor(tensor: torch.Tensor, name_in_model: str) -> ConstMeta:
return ConstMeta(tensor.dtype, tuple(tensor.shape), name_in_model)


@dataclass
class InOutMeta:
dtype: torch.dtype
shape: Tuple[int, ...]
name: str

@staticmethod
def from_tensor(tensor: torch.Tensor, name: str) -> InOutMeta:
return InOutMeta(tensor.dtype, tuple(tensor.shape), name)


@dataclass
class FunctionMeta:
op_name: str
fn_name: str
args: Tuple[Any, ...]
kwargs: Dict[str, Any]


@dataclass
class EdgeMeta:
dtype: torch.dtype
shape: Tuple[int, ...]
input_port: int
output_port: int

@staticmethod
def from_tensor(tensor: torch.Tensor, input_port: int, output_port: int) -> EdgeMeta:
return EdgeMeta(tensor.dtype, tuple(tensor.shape), input_port, output_port)


@dataclass
class TensorInfo:
tensor_source: TensorSource
shape: Tuple[int, ...]
dtype: torch.dtype
output_port_id: int
source_node_id: Optional[int]
name_in_model: Optional[str]
226 changes: 226 additions & 0 deletions nncf/experimental/torch2/function_hook/graph/graph_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
from enum import Enum
from typing import Any, Dict, Tuple

import networkx as nx # type: ignore[import-untyped]
import pydot # type: ignore[import-untyped]

from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import EdgeMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta


class PydotStyleTemplate(Enum):
"""
Enum to define different styles for Pydot graph representation.
- disable: labels contain only names, used for tests (not recommend to convert to svg)
- short: labels contain names, add edge info
- full: labels contain all full information about nodes
"""

disable = "disable"
short = "short"
full = "full"

def __str__(self) -> str:
return self.value


def fix_dot_label(label: str) -> str:
"""
Escapes curly braces in a DOT label to avoid syntax errors.
:param label: The label string to be fixed.
:return: The label with escaped curly braces.
"""
return label.replace("{", r"\{").replace("}", r"\}")


def args_to_label(args: Tuple[Any, ...]) -> str:
"""
Converts function arguments to a formatted string label.
:param args: Function arguments.
:return: Formatted string label of arguments.
"""
if not args:
return "[]"
ret = "["
for arg in args:
ret += f"\n{arg},"
return ret + "\n]"


def kwargs_to_label(kwargs: Dict[str, Any]) -> str:
"""
Converts function keyword arguments to a formatted string label.
:param kwargs: Function keyword arguments.
:return: Formatted string label of keyword arguments.
"""
if not kwargs:
return "{}"
ret = "{"
for key, val in kwargs.items():
ret += f"\n{key} : {str(val)[:50]}"
return ret + "\n}"


def get_label_from_node_data(node_data: Dict[str, Any], style: PydotStyleTemplate) -> str:
"""
Generates a label for a graph node based on its metadata and the desired style.
:param node_data: Metadata of the node.
:param style: Style template to determine the label format.
:return: Formatted label for the node.
"""
meta = node_data["meta"]
node_type = node_data["type"]
if style == PydotStyleTemplate.full:
rows = []
if isinstance(meta, InOutMeta):
rows = [
f"type: {node_type}",
f"name: {meta.name}",
f"dtype: {meta.dtype}",
f"shape: {meta.shape}",
]
elif isinstance(meta, ConstMeta):
rows = [
f"type: {node_type}",
f"name: {meta.name_in_model}",
f"dtype: {meta.dtype}",
f"shape: {meta.shape}",
]
if isinstance(meta, FunctionMeta):
rows = [
f"type: {node_type}",
f"op_name: {meta.op_name}",
f"fn_name: {meta.fn_name}",
f"args: {args_to_label(meta.args)}",
f"kwargs: {kwargs_to_label(meta.kwargs)}",
]
return "{" + fix_dot_label("|".join(rows)) + "}"
else:
if isinstance(meta, InOutMeta):
return f"{meta.name}"
if isinstance(meta, ConstMeta):
return f"{meta.name_in_model}"
if isinstance(meta, FunctionMeta):
return f"{meta.op_name}"
raise ValueError(f"Unknown meta node {type(meta)}")


def get_label_from_edge_data(node_data: Dict[str, Any], style: PydotStyleTemplate) -> str:
"""
Generates a label for a graph edge based on its metadata and the desired style.
:param edge_data: Metadata of the edge.
:param style: Style template to determine the label format.
:return: Formatted label for the edge.
"""
meta = node_data["meta"]
assert isinstance(meta, EdgeMeta)

if style == PydotStyleTemplate.disable:
return f"{meta.output_port}{meta.input_port}"
else:
return f"{meta.shape}\n{meta.output_port}{meta.input_port}"


_RAINBOW_COLORS = [
"#ffadad",
"#ffc2a9",
"#ffd6a5",
"#fdffb6",
"#caffbf",
"#b3fbdf",
"#aae0ef",
"#a0c4ff",
"#bdb2ff",
"#ffc6ff",
]


def color_picker(data: str) -> str:
"""
Picks a color from a predefined set of colors based on the input string.
:param data: Input string to determine the color.
:return: Hex code of the selected color.
"""
data = "".join(d for d in data if d.isalpha())
hash_object = hashlib.sha256(data.encode())
hex_int = int(hash_object.hexdigest()[:6], 16)
return _RAINBOW_COLORS[hex_int % len(_RAINBOW_COLORS)]


def get_style(node: Dict[str, Any], style: PydotStyleTemplate) -> Dict[str, str]:
"""
Generates a style dictionary for a graph node based on its metadata and the desired style.
:param node: Metadata of the node.
:param style: Style template to determine the node style.
:return: Dictionary containing style attributes for the node.
"""
if style == PydotStyleTemplate.disable:
return {}
meta = node["meta"]

if isinstance(meta, InOutMeta):
return {
"fillcolor": "#adadad",
"fontcolor": "#000000",
"shape": "record",
"style": '"filled,rounded"',
}
if isinstance(meta, ConstMeta):
return {
"fillcolor": "#ffffff",
"fontcolor": "#000000",
"shape": "record",
"style": '"filled,rounded"',
}
if isinstance(meta, FunctionMeta):
return {
"fillcolor": color_picker(meta.fn_name),
"fontcolor": "#000000",
"shape": "record",
"style": '"filled,rounded"',
}

raise ValueError(f"Unknown meta node {type(meta)}")


def to_pydot(nx_graph: nx.MultiDiGraph, style_template: PydotStyleTemplate = PydotStyleTemplate.full) -> pydot.Graph:
"""
Converts a NetworkX directed graph to a Pydot graph with specified styling.
:param nx_graph: Input NetworkX directed graph.
:param style_template: Style template to determine node and edge styles.
:return: Pydot graph representation of the input NetworkX graph.
"""
dot_graph = pydot.Dot("", rankdir="TB")

for key, data in nx_graph.nodes(data=True):
style = get_style(data, style_template)
dot_node = pydot.Node(key, label=get_label_from_node_data(data, style_template), **style)
dot_graph.add_node(dot_node)

for key_from, key_to, data in nx_graph.edges(data=True):
dot_edge = pydot.Edge(key_from, key_to, label=get_label_from_edge_data(data, style_template))
dot_graph.add_edge(dot_edge)

return dot_graph
Loading

0 comments on commit 94f1006

Please sign in to comment.