diff --git a/orangecanvas/application/canvasmain.py b/orangecanvas/application/canvasmain.py index cd044399a..898d8c2ff 100644 --- a/orangecanvas/application/canvasmain.py +++ b/orangecanvas/application/canvasmain.py @@ -51,6 +51,7 @@ from ..scheme import Scheme, IncompatibleChannelTypeError, SchemeNode from ..scheme import readwrite from ..scheme.readwrite import UnknownWidgetDefinition +from ..scheme.node import UserMessage from ..gui.dropshadow import DropShadowFrame from ..gui.dock import CollapsibleDockWidget from ..gui.quickhelp import QuickHelpTipEvent @@ -1231,15 +1232,85 @@ def new_scheme_from_contents_and_path( ------- workflow: Optional[Scheme] """ - new_scheme = config.workflow_constructor(parent=self) - new_scheme.set_runtime_env( - "basedir", os.path.abspath(os.path.dirname(path))) - errors = [] # type: List[Exception] - try: + def warn(warning): + if isinstance(warning, readwrite.PickleDataWarning): + raise warning + + def load(fileobj, warning_handler=None, + data_deserializer=readwrite.default_deserializer): + new_scheme = config.workflow_constructor() + new_scheme.set_runtime_env( + "basedir", os.path.abspath(os.path.dirname(path))) + errors = [] # type: List[Exception] new_scheme.load_from( fileobj, registry=self.widget_registry, - error_handler=errors.append + error_handler=errors.append, warning_handler=warning_handler, + data_deserializer=data_deserializer + ) + return new_scheme, errors + + basename = os.path.basename(path) + pos = -1 + try: + pos = fileobj.tell() + new_scheme, errors = load( + fileobj, warning_handler=warn, + data_deserializer=readwrite.default_deserializer + ) + except (readwrite.UnsupportedPickleFormatError, + readwrite.PickleDataWarning): + mbox = QMessageBox( + self, icon=QMessageBox.Warning, + windowTitle=self.tr("Security Warning"), + text=self.tr( + "The file {basename} contains pickled data that can run " + "arbitrary commands on this computer.\n" + "Would you like to load the unsafe content anyway?" + ).format(basename=basename), + informativeText=self.tr( + "Only select Load unsafe if you trust the source " + "of the file." + ), + textFormat=Qt.PlainText, + standardButtons=QMessageBox.Yes | QMessageBox.No | + QMessageBox.Abort ) + mbox.setDefaultButton(QMessageBox.Abort) + yes = mbox.button(QMessageBox.Yes) + yes.setText(self.tr("Load unsafe")) + yes.setToolTip(self.tr( + "Load the complete file. Only select this if you trust " + "the origin of the file." + )) + no = mbox.button(QMessageBox.No) + no.setText(self.tr("Load partial")) + no.setToolTip(self.tr( + "Load the file only partially, striping out all the " + "unsafe content." + )) + res = mbox.exec() + if res == QMessageBox.Abort: + return None + elif res == QMessageBox.Yes: # load with unsafe data + data_deserializer = readwrite.default_deserializer_with_pickle_fallback + elif res == QMessageBox.No: # load but discard unsafe data + data_deserializer = readwrite.default_deserializer + else: + assert False + fileobj.seek(pos, os.SEEK_SET) + new_scheme, errors = load( + fileobj, warning_handler=None, + data_deserializer=data_deserializer + ) + for e in list(errors): + if isinstance(e, readwrite.UnsupportedPickleFormatError): + if e.node is not None and e.node in new_scheme.nodes: + e.node.set_state_message( + UserMessage( + "Did not restore settings", UserMessage.Warning, + message_id="-properties-restore-error-data", + )) + errors.remove(e) except Exception: # pylint: disable=broad-except log.exception("") message_critical( @@ -1262,6 +1333,7 @@ def new_scheme_from_contents_and_path( details=details, parent=self, ) + new_scheme.setParent(self) return new_scheme def check_requires(self, fileobj: IO) -> bool: @@ -1506,9 +1578,55 @@ def save_scheme_to(self, scheme, filename): # First write the scheme to a buffer so we don't truncate an # existing scheme file if `scheme.save_to` raises an error. buffer = io.BytesIO() + scheme.set_runtime_env("basedir", os.path.abspath(dirname)) try: - scheme.set_runtime_env("basedir", os.path.abspath(dirname)) - scheme.save_to(buffer, pretty=True, pickle_fallback=True) + try: + scheme.save_to( + buffer, pretty=True, data_serializer=readwrite.default_serializer + ) + except (readwrite.UnserializableTypeError, + readwrite.UnserializableValueError): + mb = QMessageBox( + parent=self, windowTitle=self.tr("Unsafe contents"), + icon=QMessageBox.Warning, + text=self.tr( + "The workflow contains parameters that cannot be " + "safely deserialized.\n" + "Would you like to save a partial workflow anyway."), + informativeText=self.tr( + "Workflow structure will be saved but some node " + "parameters will be lost." + ), + standardButtons=QMessageBox.Discard | QMessageBox.Ignore | + QMessageBox.Abort + ) + mb.setEscapeButton(QMessageBox.Abort) + mb.setDefaultButton(QMessageBox.Discard) + b = mb.button(QMessageBox.Ignore) + b.setText(self.tr("Save anyway")) + b.setToolTip(self.tr( + "Loading such a workflow will require explicit user " + "confirmation.")) + b = mb.button(QMessageBox.Discard) + b.setText(self.tr("Discard unsafe content")) + b.setToolTip(self.tr("The saved workflow will not be complete. " + "Some parameters will not be restored")) + res = mb.exec() + buffer.truncate(0) + if res == QMessageBox.Abort: + return False + if res == QMessageBox.Discard: + def serializer(node): + try: + return readwrite.default_serializer(node) + except Exception: + return None + else: + serializer = readwrite.default_serializer_with_pickle_fallback + scheme.save_to( + buffer, pretty=True, + data_serializer=serializer + ) except Exception: log.error("Error saving %r to %r", scheme, filename, exc_info=True) message_critical( @@ -2624,7 +2742,7 @@ def collectall( IncompatibleChannelTypeError)) ) contents = [] - if missing_node_defs is not None: + if missing_node_defs: contents.extend([ "Missing node definitions:", *[" \N{BULLET} " + e.args[0] for e in missing_node_defs], diff --git a/orangecanvas/application/tests/test_mainwindow.py b/orangecanvas/application/tests/test_mainwindow.py index eb91ae456..0014250f2 100644 --- a/orangecanvas/application/tests/test_mainwindow.py +++ b/orangecanvas/application/tests/test_mainwindow.py @@ -1,5 +1,6 @@ import os import tempfile +from contextlib import contextmanager from unittest.mock import patch from AnyQt.QtGui import QWhatsThisClickedEvent @@ -7,7 +8,7 @@ from .. import addons from ..outputview import TextStream -from ...scheme import SchemeTextAnnotation, SchemeLink +from ...scheme import SchemeTextAnnotation, SchemeLink, SchemeNode, Scheme from ...gui.quickhelp import QuickHelpTipEvent, QuickHelp from ...utils.shtools import temp_named_file from ...utils.pickle import swp_name @@ -177,6 +178,68 @@ def test_save(self): w.save_scheme() self.assertEqual(w.current_document().path(), self.filename) + @contextmanager + def patch_messagebox_exec(self, return_value): + with patch("AnyQt.QtWidgets.QMessageBox.exec", + return_value=return_value) as f: + with patch("AnyQt.QtWidgets.QMessageBox.exec_", + f): + yield f + + def test_save_unsafe_warn(self): + w = self.w + doc = w.current_document() + doc.setPath(self.filename) + node = SchemeNode(self.registry.widget("one")) + node.properties = {"a": object()} + doc.addNode(node) + + def contents(): + with open(self.filename, "r", encoding="utf-8") as f: + return f.read() + with self.patch_messagebox_exec(QMessageBox.Abort) as f: + w.save_scheme() + f.assert_called_with() + self.assertEqual(contents(), "") + with self.patch_messagebox_exec(QMessageBox.Discard) as f: + w.save_scheme() + f.assert_called_with() + self.assertNotIn("pickle", contents()) + + with self.patch_messagebox_exec(QMessageBox.Ignore) as f: + w.save_scheme() + f.assert_called_with() + self.assertIn("pickle", contents()) + + def test_load_unsafe_ask(self): + w = self.w + workflow = Scheme() + node = SchemeNode(self.registry.widget("one")) + node.properties = {"a": object()} + workflow.add_node(node) + with open(self.filename, "wb") as f: + workflow.save_to(f, pickle_fallback=True) + + with self.patch_messagebox_exec(QMessageBox.Abort) as f: + w.load_scheme(self.filename) + f.assert_called_with() + self.assertEqual(len(w.current_document().scheme().nodes), 0) + self.assertTrue(w.is_transient()) + + with self.patch_messagebox_exec(return_value=QMessageBox.No) as f: + w.load_scheme(self.filename) + f.assert_called_with() + workflow = w.current_document().scheme() + self.assertEqual(len(workflow.nodes), 1) + self.assertEqual(workflow.nodes[0].properties, {}) + + with self.patch_messagebox_exec(QMessageBox.Yes) as f: + w.load_scheme(self.filename) + f.assert_called_with() + workflow = w.current_document().scheme() + self.assertEqual(len(workflow.nodes), 1) + self.assertEqual(workflow.nodes[0].properties["a"].__class__, object) + def test_save_swp(self): w = self.w swpname = swp_name(w) diff --git a/orangecanvas/scheme/readwrite.py b/orangecanvas/scheme/readwrite.py index b7d0a915e..4a9e5d6e4 100644 --- a/orangecanvas/scheme/readwrite.py +++ b/orangecanvas/scheme/readwrite.py @@ -2,20 +2,18 @@ Scheme save/load routines. """ +import io import numbers -import sys -import types import warnings import base64 import binascii -import itertools +import pickle +from functools import partial from xml.etree.ElementTree import TreeBuilder, Element, ElementTree, parse -from collections import defaultdict -from itertools import chain, count +from itertools import chain -import pickle import json import pprint @@ -25,7 +23,7 @@ import logging from typing import ( - NamedTuple, Dict, Tuple, List, Union, Any, Optional, AnyStr, IO + NamedTuple, Dict, Tuple, List, Union, Any, Optional, AnyStr, IO, Callable ) from . import SchemeNode, SchemeLink @@ -45,6 +43,14 @@ class UnknownWidgetDefinition(Exception): pass +class DeserializationWarning(UserWarning): + node = None # type: Optional[SchemeNode] + + +class PickleDataWarning(DeserializationWarning): + pass + + def _ast_parse_expr(source): # type: (str) -> ast.Expression node = ast.parse(source, "", mode="eval") @@ -255,7 +261,7 @@ def parse_ows_etree_v_2_0(tree): for annot in tree.findall("annotations/*"): if annot.tag == "text": - rect = tuple_eval(annot.get("rect", "(0.0, 0.0, 20.0, 20.0)")) + rect = tuple_eval(annot.get("rect", "0.0, 0.0, 20.0, 20.0")) font_family = annot.get("font-family", "").strip() font_size = annot.get("font-size", "").strip() @@ -275,8 +281,8 @@ def parse_ows_etree_v_2_0(tree): rect, annot.text or "", font, content_type), ) elif annot.tag == "arrow": - start = tuple_eval(annot.get("start", "(0, 0)")) - end = tuple_eval(annot.get("end", "(0, 0)")) + start = tuple_eval(annot.get("start", "0, 0")) + end = tuple_eval(annot.get("end", "0, 0")) color = annot.get("fill", "red") annotation = _annotation( id=annot.get("id"), @@ -401,7 +407,73 @@ def resolve_replaced(scheme_desc: _scheme, registry: WidgetRegistry) -> _scheme: return scheme_desc._replace(nodes=nodes, links=links) -def scheme_load(scheme, stream, registry=None, error_handler=None): +def default_error_handler(err: Exception): + raise err + + +def default_warning_handler(err: Warning): + warnings.warn(err, stacklevel=2) + + +def default_serializer(node: SchemeNode, data_format="literal") -> Optional[Tuple[AnyStr, str]]: + if node.properties: + return dumps(node.properties, format=data_format), data_format + else: + return None + + +def default_serializer_with_pickle_fallback( + node: SchemeNode, data_format="literal" +) -> Optional[Tuple[AnyStr, str]]: + try: + return default_serializer(node, data_format=data_format) + except (UnserializableTypeError, UnserializableValueError): + data = pickle.dumps(node.properties, protocol=PICKLE_PROTOCOL) + data = base64.encodebytes(data).decode("ascii") + return data, "pickle" + + +DataSerializerType = Callable[[SchemeNode], Optional[Tuple[AnyStr, str]]] + + +def default_deserializer(payload, format_): + return loads(payload, format_) + + +def default_deserializer_with_pickle_fallback( + payload, format_, *, unpickler_class=None +): + if format_ == "pickle": + if isinstance(payload, str): + payload = payload.encode("ascii") + if unpickler_class is None: + unpickler_class = pickle.Unpickler + unpickler = unpickler_class(io.BytesIO(base64.decodebytes(payload))) + return unpickler.load() + else: + return default_deserializer(payload, format_) + + +DataDeserializerType = Callable[[AnyStr, str], Any] + + +def scheme_load( + scheme, stream, registry=None, + error_handler=None, warning_handler=None, + data_deserializer: DataDeserializerType = None +): + """ + Populate a Scheme instance with workflow read from an ows data stream. + + Parameters + ---------- + scheme: Scheme + stream: typing.IO + registry: WidgetRegistry + error_handler: Callable[[Exception], None] + warning_handler: Callable[[Warning], None] + data_deserializer: Callable[[AnyStr, str], Any] + """ desc = parse_ows_stream(stream) # type: _scheme if registry is None: @@ -411,6 +483,12 @@ def scheme_load(scheme, stream, registry=None, error_handler=None): def error_handler(exc): raise exc + if warning_handler is None: + warning_handler = warnings.warn + + if data_deserializer is None: + data_deserializer = default_deserializer + desc = resolve_replaced(desc, registry) nodes_not_found = [] nodes = [] @@ -432,15 +510,35 @@ def error_handler(exc): w_desc, title=node_d.title, position=node_d.position) data = node_d.data - if data: + if data is not None: try: - properties = loads(data.data, data.format) - except Exception: - log.error("Could not load properties for %r.", node.title, - exc_info=True) + properties = data_deserializer(data.data, data.format) + except UnsupportedFormatError as err: + err.node = node + err.args = err.args + (node,) + error_handler(err) + if isinstance(err, UnsupportedPickleFormatError): + warning = PickleDataWarning( + "The file contains pickle data. The settings " + "for '{}' were not restored.".format(node_d.title) + ) + else: + warning = DeserializationWarning( + "Could not load properties for %r".format(node.title), + ) + warning.node = node + warning_handler(warning) + node.setProperty("__ows_data_deserialization_error", (type(err), err.args)) + except Exception as err: # pylint: disable=broad-except + error_handler(err) + warning = DeserializationWarning( + "Could not load properties for %r.", node.title + ) + warning.node = node + warning_handler(node) + node.setProperty("__ows_data_deserialization_error", (type(err), err.args)) else: node.properties = properties - nodes.append(node) nodes_by_id[node_d.id] = node @@ -499,127 +597,214 @@ def error_handler(exc): return scheme -def scheme_to_etree(scheme, data_format="literal", pickle_fallback=False): +def scheme_to_interm( + scheme: 'Scheme', + data_serializer: DataSerializerType = None, + error_handler: Callable[[Exception], None] = None, + warning_handler: Callable[[Warning], None] = None, +) -> _scheme: """ - Return an `xml.etree.ElementTree` representation of the `scheme`. + Return a workflow scheme in its intermediate representation for + serialization. """ + node_ids = {} # type: Dict[SchemeNode, str] + nodes = [] + links = [] + annotations = [] + window_presets = [] + + if warning_handler is None: + warning_handler = default_warning_handler + + if error_handler is None: + error_handler = default_error_handler + + if data_serializer is None: + data_serializer = default_serializer + + # Nodes + for node_id, node in enumerate(scheme.nodes): # type: SchemeNode + data_payload = None + try: + data_payload_ = data_serializer(node) + except Exception as err: + error_handler(err) + else: + if data_payload_ is not None: + assert len(data_payload_) == 2 + data_, format_ = data_payload_ + data_payload = _data(format_, data_) + desc = node.description + inode = _node( + id=str(node_id), + title=node.title, + name=node.description.name, + position=node.position, + qualified_name=desc.qualified_name, + project_name=desc.project_name or "", + version=desc.version or "", + data=data_payload, + ) + node_ids[node] = str(node_id) + nodes.append(inode) + + for link_id, link in enumerate(scheme.links): + ilink = _link( + id=str(link_id), + source_node_id=node_ids[link.source_node], + source_channel=link.source_channel.name, + sink_node_id=node_ids[link.sink_node], + sink_channel=link.sink_channel.name, + enabled=link.enabled, + ) + links.append(ilink) + + for annot_id, annot in enumerate(scheme.annotations): + if isinstance(annot, SchemeTextAnnotation): + atype = "text" + params = _text_params( + geometry=annot.geometry, + text=annot.text, + content_type=annot.content_type, + font={}, # deprecated. + ) + elif isinstance(annot, SchemeArrowAnnotation): + atype = "arrow" + params = _arrow_params( + geometry=annot.geometry, + color=annot.color, + ) + else: + assert False + + iannot = _annotation( + str(annot_id), type=atype, params=params, + ) + annotations.append(iannot) + + for preset in scheme.window_group_presets(): # type: Scheme.WindowGroup + state = [(node_ids[n], state) for n, state in preset.state] + window_presets.append( + _window_group(preset.name, preset.default, state) + ) + + return _scheme( + scheme.title, "2.0", scheme.description, nodes, links, annotations, + session_state=_session_data(window_presets), + ) + + +def scheme_to_etree_2_0( + scheme: 'Scheme', + data_serializer=None, + **kwargs +): + return interm_to_etree_2_0( + scheme_to_interm(scheme, data_serializer=data_serializer, **kwargs) + ) + + +def interm_to_etree_2_0(scheme: _scheme) -> ElementTree: builder = TreeBuilder(element_factory=Element) - builder.start("scheme", {"version": "2.0", - "title": scheme.title or "", - "description": scheme.description or ""}) + builder.start( + "scheme", { + "version": "2.0", + "title": scheme.title, + "description": scheme.description, + } + ) # Nodes - node_ids = defaultdict(lambda c=itertools.count(): next(c)) builder.start("nodes", {}) - for node in scheme.nodes: # type: SchemeNode - desc = node.description - attrs = {"id": str(node_ids[node]), - "name": desc.name, - "qualified_name": desc.qualified_name, - "project_name": desc.project_name or "", - "version": desc.version or "", - "title": node.title, - } - if node.position is not None: - attrs["position"] = str(node.position) - - if type(node) is not SchemeNode: - attrs["scheme_node_type"] = "%s.%s" % (type(node).__name__, - type(node).__module__) - builder.start("node", attrs) + for node in scheme.nodes: # type: _node + builder.start( + "node", { + "id": node.id, + "name": node.name, + "qualified_name": node.qualified_name, + "project_name": node.project_name, + "version": node.version, + "title": node.title, + "position": node.position, + } + ) builder.end("node") builder.end("nodes") # Links - link_ids = defaultdict(lambda c=itertools.count(): next(c)) builder.start("links", {}) for link in scheme.links: - source = link.source_node - sink = link.sink_node - source_id = node_ids[source] - sink_id = node_ids[sink] - attrs = {"id": str(link_ids[link]), - "source_node_id": str(source_id), - "sink_node_id": str(sink_id), - "source_channel": link.source_channel.name, - "sink_channel": link.sink_channel.name, - "enabled": "true" if link.enabled else "false", - } - builder.start("link", attrs) + builder.start( + "link", { + "id": link.id, + "source_node_id": link.source_node_id, + "sink_node_id": link.sink_node_id, + "source_channel": link.source_channel, + "sink_channel": link.sink_channel, + "enabled": "true" if link.enabled else "false", + } + ) builder.end("link") - builder.end("links") # Annotations - annotation_ids = defaultdict(lambda c=itertools.count(): next(c)) builder.start("annotations", {}) for annotation in scheme.annotations: - annot_id = annotation_ids[annotation] - attrs = {"id": str(annot_id)} - data = None - if isinstance(annotation, SchemeTextAnnotation): + attrs = {"id": annotation.id} + if annotation.type == "text": tag = "text" - attrs.update({"type": annotation.content_type}) - attrs.update({"rect": repr(annotation.rect)}) - - # Save the font attributes - font = annotation.font - attrs.update({"font-family": font.get("family", None), - "font-size": font.get("size", None)}) - attrs = [(key, value) for key, value in attrs.items() - if value is not None] - attrs = dict((key, str(value)) for key, value in attrs) - data = annotation.content - elif isinstance(annotation, SchemeArrowAnnotation): + params = annotation.params # type: _text_params + assert isinstance(params, _text_params) + attrs.update({ + "type": params.content_type, + "rect": "{!r}, {!r}, {!r}, {!r}".format(*params.geometry) + }) + data = params.text + elif annotation.type == "arrow": tag = "arrow" - attrs.update({"start": repr(annotation.start_pos), - "end": repr(annotation.end_pos), - "fill": annotation.color}) + params = annotation.params # type: _arrow_params + start, end = params.geometry + attrs.update({ + "start": "{!r}, {!r}".format(*start), + "end": "{!r}, {!r}".format(*end), + "fill": params.color + }) data = None else: log.warning("Can't save %r", annotation) continue - builder.start(tag, attrs) + builder.start(annotation.type, attrs) if data is not None: builder.data(data) builder.end(tag) builder.end("annotations") - builder.start("thumbnail", {}) - builder.end("thumbnail") - # Node properties/settings builder.start("node_properties", {}) for node in scheme.nodes: - data = None - if node.properties: - try: - data, format = dumps(node.properties, format=data_format, - pickle_fallback=pickle_fallback) - except Exception: - log.error("Error serializing properties for node %r", - node.title, exc_info=True) - if data is not None: - builder.start("properties", - {"node_id": str(node_ids[node]), - "format": format}) - builder.data(data) - builder.end("properties") + if node.data is not None: + data = node.data + builder.start( + "properties", { + "node_id": node.id, + "format": data.format + } + ) + builder.data(data.data) + builder.end("properties") builder.end("node_properties") builder.start("session_state", {}) builder.start("window_groups", {}) - for g in scheme.window_group_presets(): + for g in scheme.session_state.groups: # type: _window_group builder.start( "group", {"name": g.name, "default": str(g.default).lower()} ) - for node, data in g.state: - if node not in node_ids: - continue - builder.start("window_state", {"node_id": str(node_ids[node])}) + for node_id, data in g.state: + builder.start("window_state", {"node_id": node_id}) builder.data(base64.encodebytes(data).decode("ascii")) builder.end("window_state") builder.end("group") @@ -631,7 +816,21 @@ def scheme_to_etree(scheme, data_format="literal", pickle_fallback=False): return tree -def scheme_to_ows_stream(scheme, stream, pretty=False, pickle_fallback=False): +def scheme_to_etree(scheme, data_format="literal", pickle_fallback=False): + """ + Return an `xml.etree.ElementTree` representation of the `scheme`. + """ + if pickle_fallback: + data_serializer = default_serializer_with_pickle_fallback + else: + data_serializer = default_serializer + data_serializer = partial(data_serializer, data_format=data_format) + return interm_to_etree_2_0( + scheme_to_interm(scheme, data_serializer=data_serializer) + ) + + +def scheme_to_ows_stream(scheme, stream, pretty=False, pickle_fallback=False, data_serializer=None, **kwargs): """ Write scheme to a a stream in Orange Scheme .ows (v 2.0) format. @@ -644,13 +843,21 @@ def scheme_to_ows_stream(scheme, stream, pretty=False, pickle_fallback=False): pretty : bool, optional If `True` the output xml will be pretty printed (indented). pickle_fallback : bool, optional - If `True` allow scheme node properties to be saves using pickle + If `True` allow scheme node properties to be saved using pickle protocol if properties cannot be saved using the default notation. """ - tree = scheme_to_etree(scheme, data_format="literal", - pickle_fallback=pickle_fallback) + if pickle_fallback is not False and data_serializer is not None: + raise TypeError("pickle_fallback and data_serializer are mutually " + "exclusive parameters") + if data_serializer is None: + if pickle_fallback: + data_serializer = default_serializer_with_pickle_fallback + else: + data_serializer = default_serializer + + tree = scheme_to_etree_2_0(scheme, data_serializer=data_serializer, **kwargs) if pretty: indent(tree.getroot(), 0) tree.write(stream, encoding="utf-8", xml_declaration=True) @@ -685,51 +892,46 @@ def indent_(element, level, last): return indent_(element, level, True) -def dumps(obj, format="literal", prettyprint=False, pickle_fallback=False): - """ - Serialize `obj` using `format` ('json' or 'literal') and return its - string representation and the used serialization format ('literal', - 'json' or 'pickle'). +class UnsupportedFormatError(ValueError): + node = None # type: Optional[SchemeNode] - If `pickle_fallback` is True and the serialization with `format` - fails object's pickle representation will be returned - """ - if format == "literal": - try: - return (literal_dumps(obj, indent=1 if prettyprint else None), - "literal") - except (ValueError, TypeError) as ex: - if not pickle_fallback: - raise +class UnsupportedPickleFormatError(UnsupportedFormatError): ... - log.warning("Could not serialize to a literal string", - exc_info=True) - elif format == "json": - try: - return (json.dumps(obj, indent=1 if prettyprint else None), - "json") - except (ValueError, TypeError): - if not pickle_fallback: - raise +class UnserializableValueError(ValueError): + node = None # type: Optional[SchemeNode] - log.warning("Could not serialize to a json string", - exc_info=True) - elif format == "pickle": - return base64.encodebytes(pickle.dumps(obj, protocol=PICKLE_PROTOCOL)). \ - decode('ascii'), "pickle" +class UnserializableTypeError(TypeError): + node = None # type: Optional[SchemeNode] - else: - raise ValueError("Unsupported format %r" % format) - if pickle_fallback: - log.warning("Using pickle fallback") - return base64.encodebytes(pickle.dumps(obj, protocol=PICKLE_PROTOCOL)). \ - decode('ascii'), "pickle" +def dumps(obj, format="literal", indent=4): + """ + Serialize `obj` using `format` ('json' or 'literal') and return its + string representation. + + Raises + ------ + TypeError + If object is not a supported type for serialization format + ValueError + If object is a recursive structure + """ + if format == "literal": + return literal_dumps(obj, indent=indent) + elif format == "json": + try: + return json.dumps(obj, indent=indent) + except TypeError as e: + raise UnserializableTypeError(*e.args) from e + except ValueError as e: + raise UnserializableValueError(*e.args) from e + elif format == "pickle": + raise UnsupportedPickleFormatError() else: - raise Exception("Something strange happened.") + raise UnsupportedFormatError("Unsupported format %r" % format) def loads(string, format): @@ -738,9 +940,9 @@ def loads(string, format): elif format == "json": return json.loads(string) elif format == "pickle": - return pickle.loads(base64.decodebytes(string.encode('ascii'))) + raise UnsupportedPickleFormatError() else: - raise ValueError("Unknown format") + raise UnsupportedFormatError("Unsupported format %r" % format) # This is a subset of PyON serialization. @@ -759,8 +961,8 @@ def literal_dumps(obj, indent=None, relaxed_types=True): indent : Optional[int] If not None then it is the indent for the pretty printer. relaxed_types : bool - Relaxed type checking. In addition to exact builtin numberic types, - the numbers.Integer, numbers.Real are checked and alowed if their + Relaxed type checking. In addition to exact builtin numeric types, + the numbers.Integer, numbers.Real are checked and allowed if their repr matches that of the builtin. .. warning:: The exact type of the values will be lost. @@ -804,16 +1006,16 @@ def check(obj): return all(map(check, obj)) elif type(obj) in builtins_mapping: return all(map(check, chain(obj.keys(), obj.values()))) - else: - raise TypeError("{0} can not be serialized as a python " - "literal".format(type(obj))) + + raise UnserializableTypeError("{0} can not be serialized as a python" + "literal".format(type(obj))) def check_relaxed(obj): if type(obj) in builtins: return True if id(obj) in memo: - raise ValueError("{0} is a recursive structure".format(obj)) + raise UnserializableValueError("{0} is a recursive structure".format(obj)) memo[id(obj)] = obj @@ -831,8 +1033,8 @@ def check_relaxed(obj): if repr(obj) == repr(float(obj)): return True - raise TypeError("{0} can not be serialized as a python " - "literal".format(type(obj))) + raise UnserializableTypeError("{0} can not be serialized as a python " + "literal".format(type(obj))) if relaxed_types: check_relaxed(obj) @@ -847,4 +1049,4 @@ def check_relaxed(obj): literal_loads = literal_eval -from .scheme import Scheme # pylint: disable=all +from .scheme import Scheme # pylint: disable=wrong-import-position diff --git a/orangecanvas/scheme/tests/test_readwrite.py b/orangecanvas/scheme/tests/test_readwrite.py index c94b27a87..2b9ec208b 100644 --- a/orangecanvas/scheme/tests/test_readwrite.py +++ b/orangecanvas/scheme/tests/test_readwrite.py @@ -2,16 +2,21 @@ Test read write """ import io +from functools import partial + from xml.etree import ElementTree as ET from ...gui import test from ...registry import WidgetRegistry, WidgetDescription, CategoryDescription from ...registry import tests as registry_tests +from ...registry import OutputSignal, InputSignal from .. import Scheme, SchemeNode, SchemeLink, \ SchemeArrowAnnotation, SchemeTextAnnotation from .. import readwrite +from ..readwrite import scheme_to_interm +from ...registry.tests import small_testing_registry class TestReadWrite(test.QAppTestCase): @@ -139,8 +144,102 @@ def test_resolve_replaced(self): projects = [node.project_name for node in parsed.nodes] self.assertSetEqual(set(projects), set(["Foo", "Bar"])) + def test_scheme_to_interm(self): + workflow = Scheme() + workflow.load_from( + io.BytesIO(FOOBAR_v20.encode()), + registry=foo_registry(with_replaces=False), + ) + + tree = ET.parse(io.BytesIO(FOOBAR_v20.encode())) + parsed = readwrite.parse_ows_etree_v_2_0(tree) + + interm = scheme_to_interm(workflow) + self.assertEqual(parsed, interm) + + def test_properties_serialize(self): + workflow = Scheme() + workflow.load_from( + io.BytesIO(FOOBAR_v20.encode()), + registry=foo_registry(with_replaces=True), + ) + n1, n2 = workflow.nodes + self.assertEqual(n1.properties, {"a": 1, "b": 2}) + self.assertEqual(n2.properties, {"a": 1, "b": 2}) + f = io.BytesIO() + workflow.save_to( + f, data_serializer=partial(readwrite.default_serializer, + data_format="json")) + f.seek(0) + scheme = readwrite.parse_ows_stream(f) + self.assertEqual(scheme.nodes[0].data.format, "json") + f.seek(0) + rl = [] + rl.append(rl) + n2.properties = {"a": {"b": rl}} + with self.assertRaises(readwrite.UnserializableValueError): + workflow.save_to(f) + + n2.properties = {"a": {"b": Obj()}} + + with self.assertRaises(readwrite.UnserializableTypeError): + workflow.save_to(f) + + def test_properties_serialize_pickle_fallback(self): + reg = small_testing_registry() + workflow = Scheme() + node = SchemeNode(reg.widget("one")) + workflow.add_node(node) + rl = [] + rl.append(rl) + node.properties = {"a": {"b": Obj()}} + f = io.BytesIO() + workflow.save_to(f, pickle_fallback=True) + contents = f.getvalue() + w1 = Scheme() + + with self.assertRaises(readwrite.UnsupportedPickleFormatError): + w1.load_from(io.BytesIO(contents), registry=reg) + + w1.clear() + w1.load_from( + io.BytesIO(contents), registry=reg, + data_deserializer=readwrite.default_deserializer_with_pickle_fallback + ) + self.assertEqual(node.properties, w1.nodes[0].properties) + + def test_properties_deserialize_error_handler(self): + reg = small_testing_registry() + workflow = Scheme() + node = SchemeNode(reg.widget("one")) + workflow.add_node(node) + node.properties = {"a": {"b": Obj()}} + f = io.BytesIO() + workflow.save_to(f, data_serializer=readwrite.default_serializer_with_pickle_fallback) + contents = f.getvalue() + workflow = Scheme() + + errors = [] + warnings = [] + workflow.load_from( + io.BytesIO(contents), registry=reg, error_handler=errors.append, + warning_handler=warnings.append, + ) + self.assertEqual(len(errors), 1) + self.assertEqual(len(warnings), 1) + + self.assertIsInstance(errors[0], readwrite.UnsupportedPickleFormatError) + self.assertIs(errors[0].node, workflow.nodes[0]) + self.assertIsInstance(warnings[0], readwrite.DeserializationWarning) + self.assertIs(warnings[0].node, workflow.nodes[0]) + -def foo_registry(): +class Obj: + def __eq__(self, other): + return isinstance(other, Obj) + + +def foo_registry(with_replaces=True): reg = WidgetRegistry() reg.register_category(CategoryDescription("Quack")) reg.register_widget( @@ -150,44 +249,36 @@ def foo_registry(): qualified_name="package.foo", project_name="Foo", category="Quack", + outputs=[ + OutputSignal("foo", "str"), + OutputSignal("foo1", "int"), + ] ) ) reg.register_widget( WidgetDescription( name="Bar", id="barrr", - qualified_name="frob.bar", + qualified_name="frob.bar" if with_replaces else "package.bar", project_name="Bar", - replaces=["package.bar"], + replaces=["package.bar"] if with_replaces else [], category="Quack", + inputs=[ + InputSignal("bar", "str", "bar"), + InputSignal("bar1", "int", "bar1"), + ] ) ) return reg -FOOBAR_v10 = """ - - - - - - - - - - - -""" - FOOBAR_v20 = """ - + qualified_name="package.foo" name="Foo" /> + + + Hello World + + + + {'a': 1, 'b': 2} + {'a': 1, 'b': 2} + """