From 4b6a44f002ea2c09f4af991e7ebb82561803dd02 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 21 May 2020 16:15:30 +0300 Subject: [PATCH 1/7] Failed tests for coalesce & sum queries added --- pony/orm/tests/test_prop_sum_orderby.py | 181 ++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 pony/orm/tests/test_prop_sum_orderby.py diff --git a/pony/orm/tests/test_prop_sum_orderby.py b/pony/orm/tests/test_prop_sum_orderby.py new file mode 100644 index 000000000..7b36531aa --- /dev/null +++ b/pony/orm/tests/test_prop_sum_orderby.py @@ -0,0 +1,181 @@ +from __future__ import absolute_import, print_function, division + +import unittest + +from pony.orm.core import * +from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() + +db = Database('sqlite', ':memory:') + +class Product(db.Entity): + id = PrimaryKey(int) + name = Required(str) + comments = Set('Comment') + + @property + def sum_01(self): + return coalesce(select(c.points for c in self.comments).sum(), 0) + + @property + def sum_02(self): + return coalesce(select(c.points for c in self.comments).sum(), 0.0) + + @property + def sum_03(self): + return coalesce(select(sum(c.points) for c in self.comments), 0) + + @property + def sum_04(self): + return coalesce(select(sum(c.points) for c in self.comments), 0.0) + + @property + def sum_05(self): + return sum(c.points for c in self.comments) + + @property + def sum_06(self): + return coalesce(sum(c.points for c in self.comments), 0) + + @property + def sum_07(self): + return coalesce(sum(c.points for c in self.comments), 0.0) + + @property + def sum_08(self): + return select(sum(c.points) for c in self.comments) + + @property + def sum_09(self): + return select(coalesce(sum(c.points), 0) for c in self.comments) + + @property + def sum_10(self): + return select(coalesce(sum(c.points), 0.0) for c in self.comments) + + @property + def sum_11(self): + return select(sum(c.points) for c in self.comments) + + @property + def sum_12(self): + return sum(self.comments.points) + + @property + def sum_13(self): + return coalesce(sum(self.comments.points), 0) + + @property + def sum_14(self): + return coalesce(sum(self.comments.points), 0.0) + + +class Comment(db.Entity): + id = PrimaryKey(int) + points = Required(int) + product = Optional('Product') + + +class TestQuerySetMonad(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + p1 = Product(id=1, name='P1') + p2 = Product(id=2, name='P1', comments=[ + Comment(id=201, points=5) + ]) + p3 = Product(id=3, name='P1', comments=[ + Comment(id=301, points=1), Comment(id=302, points=2) + ]) + p4 = Product(id=4, name='P1', comments=[ + Comment(id=401, points=1), Comment(id=402, points=5), Comment(id=403, points=1) + ]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_sum_01(self): + q = list(Product.select().sort_by(lambda p: p.sum_01)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_02(self): + q = list(Product.select().sort_by(lambda p: p.sum_02)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_03(self): + q = list(Product.select().sort_by(lambda p: p.sum_03)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_04(self): + q = list(Product.select().sort_by(lambda p: p.sum_04)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_05(self): + q = list(Product.select().sort_by(lambda p: p.sum_05)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_06(self): + q = list(Product.select().sort_by(lambda p: p.sum_06)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_07(self): + q = list(Product.select().sort_by(lambda p: p.sum_07)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_08(self): + q = list(Product.select().sort_by(lambda p: p.sum_08)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_09(self): + q = list(Product.select().sort_by(lambda p: p.sum_09)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_10(self): + q = list(Product.select().sort_by(lambda p: p.sum_10)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_11(self): + q = list(Product.select().sort_by(lambda p: p.sum_11)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_12(self): + q = list(Product.select().sort_by(lambda p: p.sum_12)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_13(self): + q = list(Product.select().sort_by(lambda p: p.sum_13)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + def test_sum_14(self): + q = list(Product.select().sort_by(lambda p: p.sum_14)) + result = [p.id for p in q] + self.assertEqual(result, [1, 3, 2, 4]) + + +if __name__ == "__main__": + unittest.main() From 02618a4425808c169abb17867ee581b58a4c6fef Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 May 2020 13:37:05 +0300 Subject: [PATCH 2/7] Fix incorrect namespace stack handling --- pony/orm/sqltranslation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 4ad1a2af7..a4596fc55 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2569,9 +2569,9 @@ def __call__(monad, *args, **kwargs): root_translator.vartypes.update(func_vartypes) root_translator.vars.update(func_vars) + func_ast = copy_ast(func_ast) stack = translator.namespace_stack stack.append(name_mapping) - func_ast = copy_ast(func_ast) try: prev_code_key = translator.code_key translator.code_key = func_id @@ -2584,7 +2584,8 @@ def __call__(monad, *args, **kwargs): msg = e.args[0] + ' (inside %s)' % (monad.func_name) e.args = (msg,) raise - stack.pop() + finally: + stack.pop() return func_ast.monad class HybridMethodMonad(HybridFuncMonad): From bcfec2cc50d5725a54643610453577e14e4c3d8e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 May 2020 14:09:10 +0300 Subject: [PATCH 3/7] Allow mixing compatible types in coalesce arguments --- pony/orm/sqltranslation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a4596fc55..fc801f4c2 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2845,7 +2845,10 @@ def call(monad, *args): t = arg.type result = [ [ sql ] for sql in arg.getsql() ] for arg in args[1:]: - if arg.type is not t: throw(TypeError, 'All arguments of coalesce() function should have the same type') + if arg.type is not t: + t = coerce_types(t, arg.type) + if t is None: + throw(TypeError, 'All arguments of coalesce() function should have the same type') for i, sql in enumerate(arg.getsql()): result[i].append(sql) sql = [ [ 'COALESCE' ] + coalesce_args for coalesce_args in result ] From 194fd4b7c8b972e5dec82ca08aeb98785dd12391 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 May 2020 14:10:31 +0300 Subject: [PATCH 4/7] Micro refactoring --- pony/orm/sqltranslation.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index fc801f4c2..cdba84d47 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -901,9 +901,8 @@ def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames, namespace = None if namespace is not None: translator.namespace_stack.append(namespace) - - with translator: - try: + try: + with translator: translator.dispatch(func_ast) if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes else: nodes = (func_ast,) @@ -929,10 +928,10 @@ def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames, else: translator.having_conditions.extend(m.getsql()) translator.vars = None return translator - finally: - if namespace is not None: - ns = translator.namespace_stack.pop() - assert ns is namespace + finally: + if namespace is not None: + ns = translator.namespace_stack.pop() + assert ns is namespace def preGenExpr(translator, node): inner_tree = node.code translator_cls = translator.__class__ From ed379264afb23e6882ea6b030ea5ef389e6e5a36 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 May 2020 14:58:22 +0300 Subject: [PATCH 5/7] Fix using subqueries in coalesce --- pony/orm/sqltranslation.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index cdba84d47..3c7dab4da 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1427,6 +1427,8 @@ def __init__(monad, type, nullable=True): monad.mixin_init() def mixin_init(monad): pass + def to_single_cell_value(monad): + return monad def cmp(monad, op, monad2): return CmpMonad(op, monad, monad2) def contains(monad, item, not_in=False): throw(TypeError) @@ -2840,14 +2842,16 @@ class FuncCoalesceMonad(FuncMonad): func = coalesce def call(monad, *args): if len(args) < 2: throw(TranslationError, 'coalesce() function requires at least two arguments') - arg = args[0] + arg = args[0].to_single_cell_value() t = arg.type result = [ [ sql ] for sql in arg.getsql() ] for arg in args[1:]: + arg = arg.to_single_cell_value() if arg.type is not t: - t = coerce_types(t, arg.type) - if t is None: + t2 = coerce_types(t, arg.type) + if t2 is None: throw(TypeError, 'All arguments of coalesce() function should have the same type') + t = t2 for i, sql in enumerate(arg.getsql()): result[i].append(sql) sql = [ [ 'COALESCE' ] + coalesce_args for coalesce_args in result ] @@ -3357,6 +3361,8 @@ def __init__(monad, subtranslator): monad.subtranslator = subtranslator monad.item_type = item_type monad.limit = monad.offset = None + def to_single_cell_value(monad): + return ExprMonad.new(monad.item_type, monad.getsql()[0]) def requires_distinct(monad, joined=False): assert False def call_limit(monad, limit=None, offset=None): @@ -3564,7 +3570,7 @@ def call_group_concat(monad, sep=None, distinct=None): throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % type(sep).__name__) return monad.aggregate('GROUP_CONCAT', distinct, sep=sep) def getsql(monad): - return monad.subtranslator.construct_subquery_ast(monad.limit, monad.offset) + return [ monad.subtranslator.construct_subquery_ast(monad.limit, monad.offset) ] def find_or_create_having_ast(sections): groupby_offset = None From 2f635b667031ab374adbedc4834dccdc95dc65d7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 May 2020 19:14:57 +0300 Subject: [PATCH 6/7] Fix using aggregated subqueries in `order by` section --- pony/orm/sqltranslation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 3c7dab4da..7f11720ff 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -910,8 +910,9 @@ def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames, translator.inside_order_by = True new_order = [] for node in nodes: - if isinstance(node.monad, SetMixin): - t = node.monad.type.item_type + monad = node.monad.to_single_cell_value() + if isinstance(monad, SetMixin): + t = monad.type.item_type if isinstance(type(t), type): t = t.__name__ throw(TranslationError, 'Set of %s (%s) cannot be used for ordering' % (t, ast2src(node))) From 09ed56fa2ed36e5deae1cf1aeb17bdecc5191816 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 May 2020 22:42:45 +0300 Subject: [PATCH 7/7] Fix handling translator stack when exception is occurred --- pony/orm/sqltranslation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 7f11720ff..18a48624a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -221,11 +221,9 @@ def __init__(translator, tree, parent_translator, code_key=None, filter_num=None try: translator.init(tree, parent_translator, code_key, filter_num, extractors, vars, vartypes, left_join, optimize) except UseAnotherTranslator as e: - assert local.translators - t = local.translators.pop() - assert t is e.translator + translator = e.translator raise - else: + finally: assert local.translators t = local.translators.pop() assert t is translator