Skip to content

Commit

Permalink
[graphIO] Introduce graph serializer classes
Browse files Browse the repository at this point in the history
Move the serialization logic to dedicated serializer classes.
Implement both `GraphSerializer` and `TemplateGraphSerializer`
to cover for the existing serialization use-cases.
  • Loading branch information
yann-lty committed Dec 5, 2024
1 parent 7d4f353 commit 9f9fe0a
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 31 deletions.
43 changes: 14 additions & 29 deletions meshroom/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from meshroom.core import Version
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit
from meshroom.core.graphIO import GraphIO
from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer
from meshroom.core.node import Status, Node, CompatibilityNode
from meshroom.core.nodeFactory import nodeFactory
from meshroom.core.typing import PathLike
Expand Down Expand Up @@ -1380,39 +1380,24 @@ def toDict(self):
def asString(self):
return str(self.toDict())

def serialize(self, asTemplate: bool = False) -> dict:
"""Serialize this Graph instance.
Args:
asTemplate: Whether to use the template serialization.
Returns:
The serialized graph data.
"""
SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer
return SerializerClass(self).serialize()

def save(self, filepath=None, setupProjectFile=True, template=False):
path = filepath or self._filepath
if not path:
raise ValueError("filepath must be specified for unsaved files.")

self.header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
self.header[GraphIO.Keys.FileVersion] = GraphIO.__version__

# Store versions of node types present in the graph (excluding CompatibilityNode instances)
# and remove duplicates
usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)])
# Convert to node types to "name: version"
nodesVersions = {
"{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0")
for p in usedNodeTypes
}
# Sort them by name (to avoid random order changing from one save to another)
nodesVersions = dict(sorted(nodesVersions.items()))
# Add it the header
self.header[GraphIO.Keys.NodesVersions] = nodesVersions
self.header["template"] = template

data = {}
if template:
data = {
GraphIO.Keys.Header: self.header,
GraphIO.Keys.Graph: self.getNonDefaultInputAttributes()
}
else:
data = {
GraphIO.Keys.Header: self.header,
GraphIO.Keys.Graph: self.toDict()
}
data = self.serialize(template)

with open(path, 'w') as jsonFile:
json.dump(data, jsonFile, indent=4)
Expand Down
116 changes: 114 additions & 2 deletions meshroom/core/graphIO.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from enum import Enum
from typing import Union
from typing import Any, TYPE_CHECKING, Union

import meshroom
from meshroom.core import Version
from meshroom.core.node import Node

if TYPE_CHECKING:
from meshroom.core.graph import Graph


class GraphIO:
Expand Down Expand Up @@ -29,7 +34,7 @@ class Features(Enum):
NodesPositions = "nodesPositions"

@staticmethod
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features",...]:
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]:
"""Return the list of supported features based on a file version.
Args:
Expand All @@ -54,3 +59,110 @@ def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Fe

return tuple(features)


class GraphSerializer:
"""Standard Graph serializer."""

def __init__(self, graph: "Graph") -> None:
self._graph = graph

def serialize(self) -> dict:
"""
Serialize the Graph.
"""
return {
GraphIO.Keys.Header: self.serializeHeader(),
GraphIO.Keys.Graph: self.serializeContent(),
}

@property
def nodes(self) -> list[Node]:
return self._graph.nodes

def serializeHeader(self) -> dict:
"""Build and return the graph serialization header.
The header contains metadata about the graph, such as the:
- version of the software used to create it.
- version of the file format.
- version of the nodes types used in the graph.
- template flag.
Args:
nodes: (optional) The list of nodes to consider for node types versions - use all nodes if not specified.
template: Whether the graph is going to be serialized as a template.
"""
header: dict[str, Any] = {}
header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
header[GraphIO.Keys.FileVersion] = GraphIO.__version__
header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions()
return header

def _getNodeTypesVersions(self) -> dict[str, str]:
"""Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances."""
nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)])
nodeTypesVersions = {
nodeType.__name__: meshroom.core.nodeVersion(nodeType, "0.0") for nodeType in nodeTypes
}
# Sort them by name (to avoid random order changing from one save to another).
return dict(sorted(nodeTypesVersions.items()))

def serializeContent(self) -> dict:
"""Graph content serialization logic."""
return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)}

def serializeNode(self, node: Node) -> dict:
"""Node serialization logic."""
return node.toDict()


class TemplateGraphSerializer(GraphSerializer):
"""Serializer for serializing a graph as a template."""

def serializeHeader(self) -> dict:
header = super().serializeHeader()
header["template"] = True
return header

def serializeNode(self, node: Node) -> dict:
"""Adapt node serialization to template graphs.
Instead of getting all the inputs and internal attribute keys, only get the keys of
the attributes whose value is not the default one.
The output attributes, UIDs, parallelization parameters and internal folder are
not relevant for templates, so they are explicitly removed from the returned dictionary.
"""
# For now, implemented as a post-process to update the default serialization.
nodeData = super().serializeNode(node)

inputKeys = list(nodeData["inputs"].keys())

internalInputKeys = []
internalInputs = nodeData.get("internalInputs", None)
if internalInputs:
internalInputKeys = list(internalInputs.keys())

for attrName in inputKeys:
attribute = node.attribute(attrName)
# check that attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del nodeData["inputs"][attrName]

for attrName in internalInputKeys:
attribute = node.internalAttribute(attrName)
# check that internal attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del nodeData["internalInputs"][attrName]

# If all the internal attributes are set to their default values, remove the entry
if len(nodeData["internalInputs"]) == 0:
del nodeData["internalInputs"]

del nodeData["outputs"]
del nodeData["uid"]
del nodeData["internalFolder"]
del nodeData["parallelization"]

return nodeData


0 comments on commit 9f9fe0a

Please sign in to comment.