Skip to content

Commit

Permalink
add benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
rabbull committed Nov 23, 2024
1 parent 598f821 commit 93ad037
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 76 deletions.
94 changes: 23 additions & 71 deletions src/aiida/storage/sqlite_zip/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,79 +49,31 @@ def sqlite_case_sensitive_like(dbapi_connection, _):
cursor.close()


# def _contains(lhs: Union[dict, list], rhs: Union[dict, list]):
# if isinstance(lhs, dict) and isinstance(rhs, dict):
# for key in rhs:
# if key not in lhs or not _contains(lhs[key], rhs[key]):
# return False
# return True
# elif isinstance(lhs, list) and isinstance(rhs, list):
# for item in rhs:
# if not any(_contains(element, item) for element in lhs):
# return False
# return True
# else:
# return lhs == rhs


# def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], rhs: Union[str, bytes, bytearray, dict, list]):
# try:
# if isinstance(lhs, (str, bytes, bytearray)):
# lhs = json.loads(lhs)
# if isinstance(rhs, (str, bytes, bytearray)):
# rhs = json.loads(rhs)
# except json.JSONDecodeError:
# return 0
# return int(_contains(lhs, rhs))

def _json_contains(lhs: Union[str, bytes, bytearray, dict, list],
                   rhs: Union[str, bytes, bytearray, dict, list]) -> int:
    """Return 1 if the JSON value ``lhs`` contains the JSON value ``rhs``, else 0.

    Containment rules:

    * dict vs dict: every key of ``rhs`` must exist in ``lhs`` and the value
      under that key in ``lhs`` must contain the value in ``rhs``;
    * list vs list: every item of ``rhs`` must be contained in at least one
      element of ``lhs``;
    * anything else: plain equality (so mismatched kinds never match).

    :param lhs: the containing value, either already decoded or as JSON text.
    :param rhs: the value searched for, either already decoded or as JSON text.
    :return: ``1`` when ``lhs`` contains ``rhs``; ``0`` otherwise, including
        when either textual argument is not valid JSON.
    """
    if isinstance(lhs, (str, bytes, bytearray)):
        try:
            lhs = json.loads(lhs)
        except json.JSONDecodeError:
            return 0
    if isinstance(rhs, (str, bytes, bytearray)):
        try:
            rhs = json.loads(rhs)
        except json.JSONDecodeError:
            return 0

    def contains(container, candidate) -> bool:
        # Every requested key must be present and its value contained.
        if isinstance(container, dict) and isinstance(candidate, dict):
            return all(
                key in container and contains(container[key], value)
                for key, value in candidate.items()
            )
        # Each requested item needs *some* matching element, not all of them.
        if isinstance(container, list) and isinstance(candidate, list):
            return all(
                any(contains(element, item) for element in container)
                for item in candidate
            )
        # Scalars, or mismatched kinds (which compare unequal).
        return container == candidate

    return int(contains(lhs, rhs))
def _contains(lhs: Union[dict, list], rhs: Union[dict, list]):
    """Tell whether the decoded JSON value ``lhs`` structurally contains ``rhs``.

    Two dicts match when every key of ``rhs`` is present in ``lhs`` with a
    value that itself contains the ``rhs`` value. Two lists match when each
    item of ``rhs`` is contained in at least one element of ``lhs``. Any
    other pairing is decided by ``lhs == rhs``.
    """
    if isinstance(lhs, dict) and isinstance(rhs, dict):
        return all(
            name in lhs and _contains(lhs[name], rhs[name])
            for name in rhs
        )

    if isinstance(lhs, list) and isinstance(rhs, list):
        return all(
            any(_contains(candidate, wanted) for candidate in lhs)
            for wanted in rhs
        )

    return lhs == rhs


def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], rhs: Union[str, bytes, bytearray, dict, list]):
    """Return ``1`` if ``lhs`` contains ``rhs`` according to ``_contains``, else ``0``.

    Arguments given as ``str``, ``bytes`` or ``bytearray`` are decoded with
    ``json.loads`` first; other values are used as they are. If decoding
    raises ``json.JSONDecodeError`` the result is ``0``.
    """
    encoded_types = (str, bytes, bytearray)
    try:
        # Decode left before right, so a bad ``lhs`` is reported first.
        container = json.loads(lhs) if isinstance(lhs, encoded_types) else lhs
        candidate = json.loads(rhs) if isinstance(rhs, encoded_types) else rhs
    except json.JSONDecodeError:
        return 0
    return int(_contains(container, candidate))


def register_json_contains(dbapi_connection, _):
Expand Down
31 changes: 26 additions & 5 deletions tests/benchmark/test_json_contains.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import functools

import pytest
import random
import string
Expand Down Expand Up @@ -56,10 +54,33 @@ def extract_component(data, p: float = -1):


@pytest.mark.benchmark(group=GROUP_NAME)
@pytest.mark.parametrize('depth', [1, 2, 4, 8])
@pytest.mark.parametrize('breadth', [1, 2, 4])
@pytest.mark.usefixtures('aiida_profile_clean')
def test_deep_json(benchmark, depth, breadth):
    """Benchmark a ``contains`` filter on one stored ``Dict`` node across JSON depths.

    ``depth`` and ``breadth`` are each parametrized exactly once; pytest
    rejects a test that parametrizes the same argument name twice.
    """
    lhs = gen_json(depth, breadth)
    rhs = extract_component(lhs, p=1./depth)
    # Guard: no Dict nodes may exist before this run stores its own.
    assert 0 == len(QueryBuilder().append(orm.Dict).all())

    orm.Dict({
        'id': f'{depth}-{breadth}',
        'data': lhs,
    }).store()
    qb = QueryBuilder().append(orm.Dict, filters={
        'attributes.data': {'contains': rhs},
    }, project=[
        'attributes.id'
    ])
    # NOTE(review): presumably a warm-up run so one-off setup cost is not timed - confirm.
    qb.all()
    result = benchmark(qb.all)
    # The single stored node must be the only match.
    assert len(result) == 1


@pytest.mark.benchmark(group=GROUP_NAME)
@pytest.mark.parametrize('depth', [2])
@pytest.mark.parametrize('breadth', [1, 10, 100])
@pytest.mark.usefixtures('aiida_profile_clean')
def test_complex_json(benchmark, depth, breadth):
def test_wide_json(benchmark, depth, breadth):
lhs = gen_json(depth, breadth)
rhs = extract_component(lhs, p=1./depth)
assert 0 == len(QueryBuilder().append(orm.Dict).all())
Expand Down

0 comments on commit 93ad037

Please sign in to comment.