diff --git a/src/aiida/storage/sqlite_zip/utils.py b/src/aiida/storage/sqlite_zip/utils.py index 2027685fd..c2a2a0ace 100644 --- a/src/aiida/storage/sqlite_zip/utils.py +++ b/src/aiida/storage/sqlite_zip/utils.py @@ -49,79 +49,31 @@ def sqlite_case_sensitive_like(dbapi_connection, _): cursor.close() -# def _contains(lhs: Union[dict, list], rhs: Union[dict, list]): -# if isinstance(lhs, dict) and isinstance(rhs, dict): -# for key in rhs: -# if key not in lhs or not _contains(lhs[key], rhs[key]): -# return False -# return True -# elif isinstance(lhs, list) and isinstance(rhs, list): -# for item in rhs: -# if not any(_contains(element, item) for element in lhs): -# return False -# return True -# else: -# return lhs == rhs - - -# def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], rhs: Union[str, bytes, bytearray, dict, list]): -# try: -# if isinstance(lhs, (str, bytes, bytearray)): -# lhs = json.loads(lhs) -# if isinstance(rhs, (str, bytes, bytearray)): -# rhs = json.loads(rhs) -# except json.JSONDecodeError: -# return 0 -# return int(_contains(lhs, rhs)) - -def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], - rhs: Union[str, bytes, bytearray, dict, list]) -> int: - if isinstance(lhs, (str, bytes, bytearray)): - try: - lhs = json.loads(lhs) - except json.JSONDecodeError: - return 0 - if isinstance(rhs, (str, bytes, bytearray)): - try: - rhs = json.loads(rhs) - except json.JSONDecodeError: - return 0 - - stack = deque() - stack.append((lhs, rhs)) - while stack.count() > 0: - l, r = stack.popleft() - if isinstance(l, dict): - if not isinstance(r, dict): - return 0 - for key, value in r.items(): - if key not in l: return 0 - stack.append((l[key], value)) - elif isinstance(l, list): - if not isinstance(r, list): - return 0 - lp, lo = set(), [] - for e in l: - if isinstance(e, (dict, list)): - lo.append(e) - else: - lp.add(e) - rp, ro = set(), [] - for e in r: - if isinstance(e, (dict, list)): - ro.append(e) - else: - rp.add(e) - if not lp.issuperset(rp): - return 0 - for le in lo: - for re in ro: - stack.append((le, re)) - else: - return int(l == r) - return 1 +def _contains(lhs: Union[dict, list], rhs: Union[dict, list]): + if isinstance(lhs, dict) and isinstance(rhs, dict): + for key in rhs: + if key not in lhs or not _contains(lhs[key], rhs[key]): + return False + return True + + elif isinstance(lhs, list) and isinstance(rhs, list): + for item in rhs: + if not any(_contains(e, item) for e in lhs): + return False + return True + else: + return lhs == rhs +def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], rhs: Union[str, bytes, bytearray, dict, list]): + try: + if isinstance(lhs, (str, bytes, bytearray)): + lhs = json.loads(lhs) + if isinstance(rhs, (str, bytes, bytearray)): + rhs = json.loads(rhs) + except json.JSONDecodeError: + return 0 + return int(_contains(lhs, rhs)) def register_json_contains(dbapi_connection, _): diff --git a/tests/benchmark/test_json_contains.py b/tests/benchmark/test_json_contains.py index 4d6f0e200..40afd515b 100644 --- a/tests/benchmark/test_json_contains.py +++ b/tests/benchmark/test_json_contains.py @@ -1,5 +1,3 @@ -import functools - import pytest import random import string @@ -56,10 +54,33 @@ def extract_component(data, p: float = -1): @pytest.mark.benchmark(group=GROUP_NAME) -@pytest.mark.parametrize('depth', COMPLEX_JSON_DEPTH_RANGE) -@pytest.mark.parametrize('breadth', COMPLEX_JSON_BREADTH_RANGE) +@pytest.mark.parametrize('depth', [1, 2, 4, 8]) +@pytest.mark.parametrize('breadth', [1, 2, 4]) +@pytest.mark.usefixtures('aiida_profile_clean') +def test_deep_json(benchmark, depth, breadth): + lhs = gen_json(depth, breadth) + rhs = extract_component(lhs, p=1./depth) + assert 0 == len(QueryBuilder().append(orm.Dict).all()) + + orm.Dict({ + 'id': f'{depth}-{breadth}', + 'data': lhs, + }).store() + qb = QueryBuilder().append(orm.Dict, filters={ + 'attributes.data': {'contains': rhs}, + }, project=[ + 'attributes.id' + ]) + qb.all() + result = benchmark(qb.all) + assert len(result) == 1 + + +@pytest.mark.benchmark(group=GROUP_NAME) +@pytest.mark.parametrize('depth', [2]) +@pytest.mark.parametrize('breadth', [1, 10, 100]) @pytest.mark.usefixtures('aiida_profile_clean') -def test_complex_json(benchmark, depth, breadth): +def test_wide_json(benchmark, depth, breadth): lhs = gen_json(depth, breadth) rhs = extract_component(lhs, p=1./depth) assert 0 == len(QueryBuilder().append(orm.Dict).all())