Skip to content

Commit

Permalink
Implement initial support for *Graph Editor*.
Browse files Browse the repository at this point in the history
  • Loading branch information
KelSolaar committed Dec 29, 2024
1 parent dbda4a1 commit 616a3b0
Show file tree
Hide file tree
Showing 9 changed files with 36,330 additions and 0 deletions.
2 changes: 2 additions & 0 deletions colour_hdri/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
NodeDownsample,
NodeMergeImageStack,
NodeNormaliseExposure,
NodePassthrough,
NodeProcessingMetadata,
NodeProcessRawFileRawpy,
NodeReadFileMetadataDNG,
Expand Down Expand Up @@ -307,6 +308,7 @@ def __getattr__(self, attribute: str) -> Any:
"NodeDownsample",
"NodeMergeImageStack",
"NodeNormaliseExposure",
"NodePassthrough",
"NodeProcessRawFileRawpy",
"NodeProcessingMetadata",
"NodeReadFileMetadataDNG",
Expand Down
2 changes: 2 additions & 0 deletions colour_hdri/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .nodes import (
InputTransform,
NodePassthrough,
NodeConvertRawFileToDNGFile,
NodeReadImage,
NodeWriteImage,
Expand Down Expand Up @@ -35,6 +36,7 @@

__all__ = [
"InputTransform",
"NodePassthrough",
"NodeConvertRawFileToDNGFile",
"NodeReadImage",
"NodeWriteImage",
Expand Down
259 changes: 259 additions & 0 deletions colour_hdri/network/graph_editor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import json
import logging
import re
from collections import defaultdict
from pathlib import Path

from colour.utilities import (
ExecutionNode,
For,
ParallelForMultiprocess,
ParallelForThread,
PortGraph,
PortNode,
as_float_array,
)
from PySide6.QtCore import QSettings, Qt, QUrl
from PySide6.QtWebEngineCore import QWebEngineSettings
from PySide6.QtWebEngineWidgets import QWebEngineView
from PySide6.QtWidgets import (
QApplication,
QDockWidget,
QMainWindow,
QPushButton,
QVBoxLayout,
QWidget,
)

import colour_hdri.network.nodes

LOGGER = logging.getLogger(__name__)

HTML_INDEX = Path(__file__).parent / "resources" / "index.html"


def collect_colourscience_nodes() -> dict[str, ExecutionNode | PortNode]:
nodes = {}

nodes["For"] = For
nodes["ParallelForThread"] = ParallelForThread
nodes["ParallelForMultiprocess"] = ParallelForMultiprocess

for name in colour_hdri.network.nodes.__all__:
object_ = getattr(colour_hdri.network.nodes, name)

if issubclass(object_, (ExecutionNode, PortNode)):
nodes[name] = object_

return nodes


COLOUR_SCIENCE_NODES: dict[str, ExecutionNode | PortNode] = (
collect_colourscience_nodes()
)


class GraphEditor(QMainWindow):
def __init__(self, developer_tools: bool = False) -> None:
super().__init__()

self._settings = QSettings("colour-science", "GraphEditor")
LOGGER.info("Settings location: %s", self._settings.fileName())

self._lg_node_to_node = {}

self.setWindowTitle("Graph Editor")
self.resize(1280, 720)

self._central_widget_QWidget = QWidget()
self.setCentralWidget(self._central_widget_QWidget)

self._webview_QWebEngineView = QWebEngineView()
self._stash_graph_QPushButton = QPushButton("Stash Graph")
self._unstash_graph_QPushButton = QPushButton("Unstash Graph")
self._evaluate_graph_QPushButton = QPushButton("Evaluate Graph")

layout = QVBoxLayout()
layout.addWidget(self._webview_QWebEngineView)
layout.addWidget(self._stash_graph_QPushButton)
layout.addWidget(self._unstash_graph_QPushButton)
layout.addWidget(self._evaluate_graph_QPushButton)
self._central_widget_QWidget.setLayout(layout)

if developer_tools:
self._developer_tools_QWebEngineView = QWebEngineView()
self._webview_QWebEngineView.page().setDevToolsPage(
self._developer_tools_QWebEngineView.page()
)
self._developer_tools_QDockWidget = QDockWidget("Developer Tools", self)
self._developer_tools_QDockWidget.setWidget(
self._developer_tools_QWebEngineView
)
self.addDockWidget(
Qt.RightDockWidgetArea, self._developer_tools_QDockWidget
)

self._setup_views()
self._setup_signals()

def _setup_views(self) -> None:
self._webview_QWebEngineView.settings().setAttribute(
QWebEngineSettings.LocalContentCanAccessRemoteUrls, True
)
self._webview_QWebEngineView.load(QUrl.fromLocalFile(HTML_INDEX.resolve()))

def _setup_signals(self) -> None:
self._webview_QWebEngineView.loadFinished.connect(
self._webview_QWebEngineView_on_page_loaded
)

self._stash_graph_QPushButton.clicked.connect(
self._stash_graph_QPushButton_clicked
)
self._unstash_graph_QPushButton.clicked.connect(
self._unstash_graph_QPushButton_clicked
)
self._evaluate_graph_QPushButton.clicked.connect(
self._evaluate_graph_QPushButton_clicked
)

def closeEvent(self, event):
self.save_settings()

super().closeEvent(event)

def load_settings(self): ...

def save_settings(self): ...

def _webview_QWebEngineView_on_page_loaded(self) -> None:
self._register_colourscience_nodes()

def _stash_graph_QPushButton_clicked(self) -> None:
self._webview_QWebEngineView.page().runJavaScript(
"serializeGraph();", self._stash_graph_callback
)

def _unstash_graph_QPushButton_clicked(self) -> None:
if (graph_stash := self._settings.value("graph_stash")) is not None:
self._webview_QWebEngineView.page().runJavaScript(
f"deserializeGraph('{graph_stash}');"
)

def _evaluate_graph_QPushButton_clicked(self) -> None:
self._webview_QWebEngineView.page().runJavaScript(
"serializeGraph();", self._evaluate_graph_callback
)

def _stash_graph_callback(self, result: str) -> None:
self._settings.setValue("graph_stash", result)

def _evaluate_graph_callback(self, result: str) -> None:
try:
graph = self._build_graph(json.loads(result))
graph.process()
except Exception as exception:
LOGGER.critical(str(exception))

def _register_colourscience_nodes(self) -> None:
registry = ""
for node_class in COLOUR_SCIENCE_NODES.values():
node = node_class()
class_name = node.__class__.__name__
node_name = re.sub("^Node", "", class_name)
title = " " if node_name == "Passthrough" else node_name
description = node.description.replace('"', "'")

registry += f"function colourscience_{node_name}() {{"
for input_port in node.input_ports:
registry += f'\tthis.addInput("{input_port}");'
for output_port in node.output_ports:
registry += f'\tthis.addOutput("{output_port}");'
registry += "\n"
registry += "\tthis.properties = {"
registry += f'\t\tpythonClassName : "{class_name}"'
registry += "\t};"
registry += "\n"
if node_name == "Passthrough":
registry += "\tthis.flags = {"
registry += "\t\tcollapsed : true"
registry += "\t};"
registry += "\n"
registry += "};"
registry += "\n"
registry += f'colourscience_{node_name}.title = "{title}";'
registry += f'colourscience_{node_name}.desc = "{description}";'
registry += "\n"
registry += (
f'LiteGraph.registerNodeType("colourscience/{node_name}", '
f"colourscience_{node_name});"
)
registry += "\n"

self._webview_QWebEngineView.page().runJavaScript(registry)

def _build_graph(self, lg_graph: dict) -> PortGraph:
self._lg_node_to_node = {}

edges = defaultdict(dict)
constants = {}

graph = PortGraph()

# Nodes
for lg_node in lg_graph["nodes"]:
if (name := lg_node["properties"].get("pythonClassName")) is not None:
node = COLOUR_SCIENCE_NODES[name](
f'{lg_node.get("title", name)} ({lg_node["id"]})'
)
self._lg_node_to_node[lg_node["id"]] = node
graph.add_node(node)
elif lg_node["type"].startswith("basic/"):
value = lg_node["properties"]["value"]

if lg_node["type"] == "basic/array":
value = as_float_array(eval(value))

constants[lg_node["id"]] = value

for input_ in lg_node.get("inputs", []):
if input_.get("link") is not None:
edges[input_["link"]]["input"] = (lg_node["id"], input_["name"])

for output in lg_node.get("outputs", []):
if output.get("links") is not None:
for link in output["links"]:
edges[link]["output"] = (lg_node["id"], output["name"])

# Edges
for group in edges.values():
input_, output = group["input"], group["output"]
if (output_node := self._lg_node_to_node.get(output[0])) is not None and (
input_node := self._lg_node_to_node.get(input_[0])
) is not None:
output_node.connect(output[1], input_node, input_[1])

# Constants
for id_, value in constants.items():
for group in edges.values():
input_, output = group["input"], group["output"]
if output[0] == id_:
if (input_node := self._lg_node_to_node.get(input_[0])) is not None:
input_node.set_input(input_[1], value)
break

graph.to_graphviz().write_svg("graph.svg")

return graph


if __name__ == "__main__":
import sys

logging.basicConfig(level=logging.INFO)

application = QApplication(sys.argv)
graph_editor = GraphEditor(False)
graph_editor.show()

sys.exit(application.exec())
30 changes: 30 additions & 0 deletions colour_hdri/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from colour.utilities import (
CanonicalMapping,
ExecutionNode,
PortNode,
as_float_array,
batch,
ones,
Expand Down Expand Up @@ -85,6 +86,7 @@
__all__ = [
"JSONEncoderEXRAttribute",
"InputTransform",
"NodePassthrough",
"NodeConvertRawFileToDNGFile",
"NodeReadImage",
"NodeWriteImage",
Expand Down Expand Up @@ -179,6 +181,34 @@ def __eq__(self, other: object) -> bool:
return np.all(self.M == other.M) and np.all(self.RGB_w == other.RGB_w) # pyright: ignore


class NodePassthrough(PortNode):
"""
Pass the input data through.
Methods
-------
- :meth:`~colour_hdri.NodePassthrough.__init__`
- :meth:`~colour_hdri.NodePassthrough.process`
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self.description = "Pass the input data through"

self.add_input_port("input")
self.add_output_port("output")

def process(self, **kwargs: Any) -> None: # noqa: ARG002
"""
Process the node.
"""

self.set_output("output", self.get_input("input"))

self.dirty = False


class NodeConvertRawFileToDNGFile(ExecutionNode):
"""
Convert given raw file, e.g., *CR2*, *CR3*, *NEF*, to *DNG*.
Expand Down
Loading

0 comments on commit 616a3b0

Please sign in to comment.