Skip to content

Commit

Permalink
Python: Add a Node class instead of using the NodeProto everywhere.
Browse files Browse the repository at this point in the history
One major drawback of using `NodeProto` everywhere is that we can only get
`TensorProto` from it, whereas what we usually want is `Tensor`. This adds a
`Node` class in which we can freely store `Tensor`s as inputs/outputs.
`NodeProto` creation is deferred until the user requests serialization.

Using `Node` makes some processes like _insert_switch_nodes_to_graph()
surprisingly cleaner.

This also makes the code more idiomatic Python.

TESTED=all tests are passing.
  • Loading branch information
yaoyuannnn committed Jul 29, 2020
1 parent 1c4e84f commit fd00c97
Show file tree
Hide file tree
Showing 11 changed files with 450 additions and 294 deletions.
128 changes: 81 additions & 47 deletions smaug/python/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@
from smaug.core import types_pb2
from smaug.core import tensor_pb2
from smaug.python import global_vars
from smaug.python.node import Node
from smaug.python.tensor import Tensor

class Graph:
def __init__(
self, name="DefaultGraph", backend="Reference",
mem_policy=types_pb2.AllDma):
assert (backend in global_vars.backend_alignment)
self.graph = graph_pb2.GraphProto()
if backend not in global_vars.backend_alignment:
raise ValueError("An unknown backend %s is used!" % backend)
self._name = name
self._backend = backend
self._mem_policy = mem_policy
self._nodes = []
self._node_names = {}
self.graph.name = name
self.graph.backend = backend
self.graph.mem_policy = mem_policy
self.alignment = global_vars.backend_alignment[backend]
self._alignment = global_vars.backend_alignment[self._backend]
# Layout transformation is enabled by default.
self._layout_trans_enabled = True
# This proto stores all the parameters in the network.
self.tensor_data_array = tensor_pb2.TensorDataArray()
self._parent_graph = None

def __enter__(self):
Expand All @@ -39,20 +39,28 @@ def __exit__(self, *args):

@property
def backend(self):
return self.graph.backend
return self._backend

@property
def mem_policy(self):
return self.graph.mem_policy
return self._mem_policy

@property
def alignment(self):
    """Alignment looked up in `global_vars.backend_alignment` for this graph's backend."""
    return self._alignment

@property
def layout_trans_enabled(self):
    """Whether automatic layout transformation is enabled (it is on by default)."""
    return self._layout_trans_enabled

def merge(self, other):
"""Merge another graph into this."""
for node in other.get_nodes():
if self.get_node(node.name) is not None:
raise ValueError(
"The graph to be merged contains a node with the same name as one "
"in the current graph. Possibly merging a graph more than once?")
self.get_nodes().extend(other.get_nodes())
self.tensor_data_array.data_array.extend(other.tensor_data_array.data_array)

def add_node(
self, name, op, input_tensors, output_tensors_dims,
Expand All @@ -75,24 +83,16 @@ def add_node(
Returns:
The output tensor of the added node.
"""
node = self.graph.nodes.add()
node.name = self.create_unique_name(name)
node.op = op

# Add the parameters to the node.
if params != None:
node.params.CopyFrom(params)
name = self.create_unique_name(name)
node = Node(name, op, params)
self._nodes.append(node)

# Update the node's parents field, and add every input tensor to the node.
# Add every input tensor to the node.
for i,tensor in enumerate(input_tensors):
if tensor.name == None:
tensor.name = node.name + "/input%d" % i
if tensor.source is not None:
node.parents.append(tensor.source[0].name)
node.src_tensors_indices.append(tensor.source[1])
node.add_input(tensor)
tensor.targets.append(node)
input_tensor_proto = node.input_tensors.add()
tensor.to_tensor_proto(input_tensor_proto, self.tensor_data_array)

# Create the output tensor (with the node as its source), and add it to the
# node.
Expand All @@ -101,33 +101,33 @@ def add_node(
output_tensor = Tensor(
dims=d, name="%s/output%d" % (node.name, i),
data_layout=output_tensor_layout, data_type=output_tensor_dtype,
data_format=output_tensor_dformat, source=(node, i),
alignment=self.alignment)
output_tensor_proto = node.output_tensors.add()
output_tensor.to_tensor_proto(output_tensor_proto, self.tensor_data_array)
data_format=output_tensor_dformat, source=node, source_index=i,
alignment=self._alignment)
node.add_output(output_tensor)
output_tensors.append(output_tensor)

return output_tensors

def get_node(self, node_name, recursive=False):
"""Return a node in the graph proto by its name.
"""Return a node in the graph by its name.
Args:
node_name: Node name.
recursive: If true, recursively search the node in the parent graphs.
Returns:
A NodeProto if we find the node.
A `Node` if we find the node or None is returned.
"""
for i in range(len(self.graph.nodes)):
if self.graph.nodes[i].name == node_name:
return self.graph.nodes[i]
for node in self._nodes:
if node.name == node_name:
return node
if recursive and self._parent_graph is not None:
return self._parent_graph.get_node(node_name, True)
return None

def get_nodes(self):
"""Return nodes in the graph proto."""
return self.graph.nodes
"""Return nodes in the graph."""
return self._nodes

def get_root_graph(self):
"""Return the root graph."""
Expand Down Expand Up @@ -167,19 +167,36 @@ def enable_layout_transform(self):
"""Enable automatic layout transformation."""
self._layout_trans_enabled = True

def to_proto(self):
    """Build the protobuf representation of this graph.

    A fresh `GraphProto` and a fresh `TensorDataArray` are created on every
    call; each node of the graph is serialized via its own `to_proto()` and
    added to the graph proto in the order the nodes are stored.

    Returns:
      A tuple of (`GraphProto`, `TensorDataArray`).
    """
    # Every node serializes its tensors into this one shared array.
    params_proto = tensor_pb2.TensorDataArray()

    topo_proto = graph_pb2.GraphProto()
    topo_proto.name = self._name
    topo_proto.backend = self._backend
    topo_proto.mem_policy = self._mem_policy
    topo_proto.nodes.extend(
        [graph_node.to_proto(params_proto) for graph_node in self._nodes])
    return topo_proto, params_proto

def write_graph(self, name=None):
"""Serialize the graph to a protobuf file.
Args:
name: Name of the output protobuf file. If not specified, use the graph's
name instead.
"""
if name == None:
topo_name = self.graph.name + "_topo.pbtxt"
params_name = self.graph.name + "_params.pb"
graph_proto, tensor_data_array = self.to_proto()
if name is None:
name = self._name
topo_name = name + "_topo.pbtxt"
params_name = name + "_params.pb"
with open(topo_name, "w") as f_topo, open(params_name, "wb") as f_params:
f_topo.write(text_format.MessageToString(self.graph))
f_params.write(self.tensor_data_array.SerializeToString())
f_topo.write(text_format.MessageToString(graph_proto))
f_params.write(tensor_data_array.SerializeToString())

def print_summary(self):
"""Print the summary of the graph.
Expand All @@ -189,28 +206,45 @@ def print_summary(self):
input/output tensors.
"""
print("=================================================================")
print(" Summary of the network: %s (%s)" % (self.graph.name,
self.graph.backend))
print(" Summary of the network: %s (%s)" % (self._name, self._backend))
print("=================================================================")
print(
"Host memory access policy: %s." %
types_pb2.HostMemoryAccessPolicy.Name(self.graph.mem_policy))
types_pb2.HostMemoryAccessPolicy.Name(self._mem_policy))
print("-----------------------------------------------------------------")
for node in self.graph.nodes:
for node in self._nodes:
print("Name: %s (%s)" % (node.name, types_pb2.OpType.Name(node.op)))
print("Parents:", end = '')
for i in node.parents:
for i in node.get_parents():
print(i, end = ' ')
print("\nChildren:", end = '')
for i in node.get_children():
print(i, end = ' ')
print("\nInput tensors:")
for t in node.input_tensors:
for t in node.inputs:
print(
" ", t.name, types_pb2.DataType.Name(t.data_type), t.shape.dims,
types_pb2.DataLayout.Name(t.shape.layout),
"alignment(%d)" % t.shape.alignment)
print("Output tensors:")
for t in node.output_tensors:
for t in node.outputs:
print(
" ", t.name, types_pb2.DataType.Name(t.data_type), t.shape.dims,
types_pb2.DataLayout.Name(t.shape.layout),
"alignment(%d)" % t.shape.alignment)
print("-----------------------------------------------------------------")

def get_node_proto(graph_proto, node_name):
    """Look up a `NodeProto` in a `GraphProto` by node name.

    Args:
      graph_proto: A `GraphProto`; its `nodes` are scanned in order.
      node_name: Name of the node to look for.

    Returns:
      The first `NodeProto` whose `name` equals `node_name`, or None when no
      node in the graph has that name.
    """
    matching = (candidate for candidate in graph_proto.nodes
                if candidate.name == node_name)
    return next(matching, None)
122 changes: 122 additions & 0 deletions smaug/python/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import sys
import numpy as np

from smaug.core import node_pb2
from smaug.core import types_pb2
from smaug.python import global_vars
from smaug.python import datatypes

class Node:
    """An operation in a `Graph`, holding `Tensor` inputs and outputs.

    A `Node` records the operator type, its parameters and the tensors it
    consumes and produces. A `NodeProto` is only built when `to_proto()` is
    called.
    """

    def __init__(self, name, op, params=None, inputs=None, outputs=None):
        """Create a node.

        Args:
          name: Name of the node.
          op: `OpType` representing the operation type of the node.
          params: `Params` used by the operator (optional).
          inputs: An iterable of `Tensor` (optional).
          outputs: An iterable of `Tensor` (optional).
        """
        self._name = name
        self._op = op
        self._params = params
        # Copy the given sequences so that add_input()/add_output()/
        # update_input() never mutate a list owned by the caller, and so that
        # any iterable (e.g. a tuple) is accepted.
        self._inputs = [] if inputs is None else list(inputs)
        self._outputs = [] if outputs is None else list(outputs)

    @property
    def name(self):
        """Name of the node."""
        return self._name

    @property
    def op(self):
        """Operation type of the node."""
        return self._op

    @property
    def inputs(self):
        """The node's list of input tensors."""
        return self._inputs

    @property
    def outputs(self):
        """The node's list of output tensors."""
        return self._outputs

    def add_input(self, tensor):
        """Append an input tensor to the node.

        Args:
          tensor: A `Tensor`.
        """
        self._inputs.append(tensor)

    def add_output(self, tensor):
        """Append an output tensor to the node.

        Args:
          tensor: A `Tensor`.
        """
        self._outputs.append(tensor)

    def update_input(self, tensor, index):
        """Replace the `index`th input with `tensor`.

        Args:
          tensor: A `Tensor` representing the new input.
          index: The input index.

        Raises:
          IndexError: If `index` is out of range of the current inputs.
        """
        self._inputs[index] = tensor

    def get_parents(self):
        """Get the parents of the node.

        Returns:
          A list of strings with the names of the source nodes of the inputs,
          in input order. Inputs without a source are skipped; a parent
          feeding several inputs appears once per input.
        """
        return [
            tensor.source.name
            for tensor in self._inputs
            if tensor.source is not None
        ]

    def get_children(self):
        """Get the children of the node.

        Returns:
          A list of strings with the names of every target node of every
          output tensor, in output order.
        """
        return [
            target.name
            for tensor in self._outputs
            for target in tensor.targets
        ]

    def to_proto(self, tensor_data_array):
        """Serialize `Node` into `NodeProto`.

        Args:
          tensor_data_array: `TensorDataArray` that tensor data gets
            serialized into.

        Returns:
          A `NodeProto`.
        """
        node_proto = node_pb2.NodeProto()
        node_proto.name = self._name
        node_proto.op = self._op
        if self._params is not None:
            node_proto.params.CopyFrom(self._params)
        for tensor in self._inputs:
            # Only inputs produced by another node contribute a parent entry
            # and the index of the producing output.
            if tensor.source is not None:
                node_proto.parents.append(tensor.source.name)
                node_proto.src_tensors_indices.append(tensor.source_index)
            tensor_proto = node_proto.input_tensors.add()
            tensor.to_tensor_proto(tensor_proto, tensor_data_array)
        for tensor in self._outputs:
            tensor_proto = node_proto.output_tensors.add()
            tensor.to_tensor_proto(tensor_proto, tensor_data_array)
        return node_proto
Loading

0 comments on commit fd00c97

Please sign in to comment.