From 9f9fe0a5f45c49c0b9f59b647c9d7ee9ff59837c Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Thu, 5 Dec 2024 19:39:25 +0100 Subject: [PATCH] [graphIO] Introduce graph serializer classes Move the serialization logic to dedicated serializer classes. Implement both `GraphSerializer` and `TemplateGraphSerializer` to cover for the existing serialization use-cases. --- meshroom/core/graph.py | 43 +++++---------- meshroom/core/graphIO.py | 116 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 128 insertions(+), 31 deletions(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 9c9ba8eebb..2b4512c7b2 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -17,7 +17,7 @@ from meshroom.core import Version from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit -from meshroom.core.graphIO import GraphIO +from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer from meshroom.core.node import Status, Node, CompatibilityNode from meshroom.core.nodeFactory import nodeFactory from meshroom.core.typing import PathLike @@ -1380,39 +1380,24 @@ def toDict(self): def asString(self): return str(self.toDict()) + def serialize(self, asTemplate: bool = False) -> dict: + """Serialize this Graph instance. + + Args: + asTemplate: Whether to use the template serialization. + + Returns: + The serialized graph data. + """ + SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer + return SerializerClass(self).serialize() + def save(self, filepath=None, setupProjectFile=True, template=False): path = filepath or self._filepath if not path: raise ValueError("filepath must be specified for unsaved files.") - self.header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__ - self.header[GraphIO.Keys.FileVersion] = GraphIO.__version__ - - # Store versions of node types present in the graph (excluding CompatibilityNode instances) - # and remove duplicates - usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)]) - # Convert to node types to "name: version" - nodesVersions = { - "{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0") - for p in usedNodeTypes - } - # Sort them by name (to avoid random order changing from one save to another) - nodesVersions = dict(sorted(nodesVersions.items())) - # Add it the header - self.header[GraphIO.Keys.NodesVersions] = nodesVersions - self.header["template"] = template - - data = {} - if template: - data = { - GraphIO.Keys.Header: self.header, - GraphIO.Keys.Graph: self.getNonDefaultInputAttributes() - } - else: - data = { - GraphIO.Keys.Header: self.header, - GraphIO.Keys.Graph: self.toDict() - } + data = self.serialize(template) with open(path, 'w') as jsonFile: json.dump(data, jsonFile, indent=4) diff --git a/meshroom/core/graphIO.py b/meshroom/core/graphIO.py index b7f7ad5a12..bc65629212 100644 --- a/meshroom/core/graphIO.py +++ b/meshroom/core/graphIO.py @@ -1,7 +1,12 @@ from enum import Enum -from typing import Union +from typing import Any, TYPE_CHECKING, Union +import meshroom from meshroom.core import Version +from meshroom.core.node import Node + +if TYPE_CHECKING: + from meshroom.core.graph import Graph class GraphIO: @@ -29,7 +34,7 @@ class Features(Enum): NodesPositions = "nodesPositions" @staticmethod - def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features",...]: + def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]: """Return the list of supported features based on a file version. Args: @@ -54,3 +59,110 @@ def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Fe return tuple(features) + +class GraphSerializer: + """Standard Graph serializer.""" + + def __init__(self, graph: "Graph") -> None: + self._graph = graph + + def serialize(self) -> dict: + """ + Serialize the Graph. + """ + return { + GraphIO.Keys.Header: self.serializeHeader(), + GraphIO.Keys.Graph: self.serializeContent(), + } + + @property + def nodes(self) -> list[Node]: + return self._graph.nodes + + def serializeHeader(self) -> dict: + """Build and return the graph serialization header. + + The header contains metadata about the graph, such as the: + - version of the software used to create it. + - version of the file format. + - version of the nodes types used in the graph. + - template flag. + + Args: + nodes: (optional) The list of nodes to consider for node types versions - use all nodes if not specified. + template: Whether the graph is going to be serialized as a template. + """ + header: dict[str, Any] = {} + header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__ + header[GraphIO.Keys.FileVersion] = GraphIO.__version__ + header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions() + return header + + def _getNodeTypesVersions(self) -> dict[str, str]: + """Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances.""" + nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)]) + nodeTypesVersions = { + nodeType.__name__: meshroom.core.nodeVersion(nodeType, "0.0") for nodeType in nodeTypes + } + # Sort them by name (to avoid random order changing from one save to another). + return dict(sorted(nodeTypesVersions.items())) + + def serializeContent(self) -> dict: + """Graph content serialization logic.""" + return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)} + + def serializeNode(self, node: Node) -> dict: + """Node serialization logic.""" + return node.toDict() + + +class TemplateGraphSerializer(GraphSerializer): + """Serializer for serializing a graph as a template.""" + + def serializeHeader(self) -> dict: + header = super().serializeHeader() + header["template"] = True + return header + + def serializeNode(self, node: Node) -> dict: + """Adapt node serialization to template graphs. + + Instead of getting all the inputs and internal attribute keys, only get the keys of + the attributes whose value is not the default one. + The output attributes, UIDs, parallelization parameters and internal folder are + not relevant for templates, so they are explicitly removed from the returned dictionary. + """ + # For now, implemented as a post-process to update the default serialization. + nodeData = super().serializeNode(node) + + inputKeys = list(nodeData["inputs"].keys()) + + internalInputKeys = [] + internalInputs = nodeData.get("internalInputs", None) + if internalInputs: + internalInputKeys = list(internalInputs.keys()) + + for attrName in inputKeys: + attribute = node.attribute(attrName) + # check that attribute is not a link for choice attributes + if attribute.isDefault and not attribute.isLink: + del nodeData["inputs"][attrName] + + for attrName in internalInputKeys: + attribute = node.internalAttribute(attrName) + # check that internal attribute is not a link for choice attributes + if attribute.isDefault and not attribute.isLink: + del nodeData["internalInputs"][attrName] + + # If all the internal attributes are set to their default values, remove the entry + if len(nodeData["internalInputs"]) == 0: + del nodeData["internalInputs"] + + del nodeData["outputs"] + del nodeData["uid"] + del nodeData["internalFolder"] + del nodeData["parallelization"] + + return nodeData + +