diff --git a/src/aiida/storage/sqlite_zip/utils.py b/src/aiida/storage/sqlite_zip/utils.py index 304ef7fa0..2027685fd 100644 --- a/src/aiida/storage/sqlite_zip/utils.py +++ b/src/aiida/storage/sqlite_zip/utils.py @@ -11,7 +11,8 @@ import json import zipfile from pathlib import Path -from typing import Any, AnyStr, Dict, Optional, Union +from typing import Any, Dict, Optional, Union +from collections import deque from sqlalchemy import event from sqlalchemy.future.engine import Engine, create_engine @@ -48,30 +49,79 @@ def sqlite_case_sensitive_like(dbapi_connection, _): cursor.close() -def _contains(lhs: Union[dict, list], rhs: Union[dict, list]): - if isinstance(lhs, dict) and isinstance(rhs, dict): - for key in rhs: - if key not in lhs or not _contains(lhs[key], rhs[key]): - return False - return True - elif isinstance(lhs, list) and isinstance(rhs, list): - for item in rhs: - if not any(_contains(element, item) for element in lhs): - return False - return True - else: - return lhs == rhs - - -def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], rhs: Union[str, bytes, bytearray, dict, list]): - try: - if isinstance(lhs, (str, bytes, bytearray)): +# def _contains(lhs: Union[dict, list], rhs: Union[dict, list]): +# if isinstance(lhs, dict) and isinstance(rhs, dict): +# for key in rhs: +# if key not in lhs or not _contains(lhs[key], rhs[key]): +# return False +# return True +# elif isinstance(lhs, list) and isinstance(rhs, list): +# for item in rhs: +# if not any(_contains(element, item) for element in lhs): +# return False +# return True +# else: +# return lhs == rhs + + +# def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], rhs: Union[str, bytes, bytearray, dict, list]): +# try: +# if isinstance(lhs, (str, bytes, bytearray)): +# lhs = json.loads(lhs) +# if isinstance(rhs, (str, bytes, bytearray)): +# rhs = json.loads(rhs) +# except json.JSONDecodeError: +# return 0 +# return int(_contains(lhs, rhs)) + +def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], + rhs: Union[str, bytes, bytearray, dict, list]) -> int: + if isinstance(lhs, (str, bytes, bytearray)): + try: lhs = json.loads(lhs) - if isinstance(rhs, (str, bytes, bytearray)): + except json.JSONDecodeError: + return 0 + if isinstance(rhs, (str, bytes, bytearray)): + try: rhs = json.loads(rhs) - except json.JSONDecodeError: - return 0 - return int(_contains(lhs, rhs)) + except json.JSONDecodeError: + return 0 + + stack = deque() + stack.append((lhs, rhs)) + while stack.count() > 0: + l, r = stack.popleft() + if isinstance(l, dict): + if not isinstance(r, dict): + return 0 + for key, value in r.items(): + if key not in l: return 0 + stack.append((l[key], value)) + elif isinstance(l, list): + if not isinstance(r, list): + return 0 + lp, lo = set(), [] + for e in l: + if isinstance(e, (dict, list)): + lo.append(e) + else: + lp.add(e) + rp, ro = set(), [] + for e in r: + if isinstance(e, (dict, list)): + ro.append(e) + else: + rp.add(e) + if not lp.issuperset(rp): + return 0 + for le in lo: + for re in ro: + stack.append((le, re)) + else: + return int(l == r) + return 1 + + def register_json_contains(dbapi_connection, _): diff --git a/tests/benchmark/test_json_contains.py b/tests/benchmark/test_json_contains.py index e95bd68b7..4d6f0e200 100644 --- a/tests/benchmark/test_json_contains.py +++ b/tests/benchmark/test_json_contains.py @@ -99,4 +99,3 @@ def test_large_table(benchmark, num_entries): qb.all() result = benchmark(qb.all) assert len(result) == num_entries -