Skip to content

Commit

Permalink
Merge branch master to orm: bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Jun 1, 2020
2 parents 2e83a3d + 09ed56f commit c9ca48a
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 18 deletions.
44 changes: 26 additions & 18 deletions pony/orm/sqltranslation.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,9 @@ def __init__(translator, tree, parent_translator, code_key=None, filter_num=None
try:
translator.init(tree, parent_translator, code_key, filter_num, extractors, vars, vartypes, left_join, optimize)
except UseAnotherTranslator as e:
assert local.translators
t = local.translators.pop()
assert t is e.translator
translator = e.translator
raise
else:
finally:
assert local.translators
t = local.translators.pop()
assert t is translator
Expand Down Expand Up @@ -901,18 +899,18 @@ def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames,
namespace = None
if namespace is not None:
translator.namespace_stack.append(namespace)

with translator:
try:
try:
with translator:
translator.dispatch(func_ast)
if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes
else: nodes = (func_ast,)
if order_by:
translator.inside_order_by = True
new_order = []
for node in nodes:
if isinstance(node.monad, SetMixin):
t = node.monad.type.item_type
monad = node.monad.to_single_cell_value()
if isinstance(monad, SetMixin):
t = monad.type.item_type
if isinstance(type(t), type): t = t.__name__
throw(TranslationError, 'Set of %s (%s) cannot be used for ordering'
% (t, ast2src(node)))
Expand All @@ -929,10 +927,10 @@ def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames,
else: translator.having_conditions.extend(m.getsql())
translator.vars = None
return translator
finally:
if namespace is not None:
ns = translator.namespace_stack.pop()
assert ns is namespace
finally:
if namespace is not None:
ns = translator.namespace_stack.pop()
assert ns is namespace
def preGenExpr(translator, node):
inner_tree = node.code
translator_cls = translator.__class__
Expand Down Expand Up @@ -1428,6 +1426,8 @@ def __init__(monad, type, nullable=True):
monad.mixin_init()
def mixin_init(monad):
pass
def to_single_cell_value(monad):
return monad
def cmp(monad, op, monad2):
return CmpMonad(op, monad, monad2)
def contains(monad, item, not_in=False): throw(TypeError)
Expand Down Expand Up @@ -2569,9 +2569,9 @@ def __call__(monad, *args, **kwargs):
root_translator.vartypes.update(func_vartypes)
root_translator.vars.update(func_vars)

func_ast = copy_ast(func_ast)
stack = translator.namespace_stack
stack.append(name_mapping)
func_ast = copy_ast(func_ast)
try:
prev_code_key = translator.code_key
translator.code_key = func_id
Expand All @@ -2584,7 +2584,8 @@ def __call__(monad, *args, **kwargs):
msg = e.args[0] + ' (inside %s)' % (monad.func_name)
e.args = (msg,)
raise
stack.pop()
finally:
stack.pop()
return func_ast.monad

class HybridMethodMonad(HybridFuncMonad):
Expand Down Expand Up @@ -2840,11 +2841,16 @@ class FuncCoalesceMonad(FuncMonad):
func = coalesce
def call(monad, *args):
if len(args) < 2: throw(TranslationError, 'coalesce() function requires at least two arguments')
arg = args[0]
arg = args[0].to_single_cell_value()
t = arg.type
result = [ [ sql ] for sql in arg.getsql() ]
for arg in args[1:]:
if arg.type is not t: throw(TypeError, 'All arguments of coalesce() function should have the same type')
arg = arg.to_single_cell_value()
if arg.type is not t:
t2 = coerce_types(t, arg.type)
if t2 is None:
throw(TypeError, 'All arguments of coalesce() function should have the same type')
t = t2
for i, sql in enumerate(arg.getsql()):
result[i].append(sql)
sql = [ [ 'COALESCE' ] + coalesce_args for coalesce_args in result ]
Expand Down Expand Up @@ -3354,6 +3360,8 @@ def __init__(monad, subtranslator):
monad.subtranslator = subtranslator
monad.item_type = item_type
monad.limit = monad.offset = None
def to_single_cell_value(monad):
return ExprMonad.new(monad.item_type, monad.getsql()[0])
def requires_distinct(monad, joined=False):
assert False
def call_limit(monad, limit=None, offset=None):
Expand Down Expand Up @@ -3561,7 +3569,7 @@ def call_group_concat(monad, sep=None, distinct=None):
throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % type(sep).__name__)
return monad.aggregate('GROUP_CONCAT', distinct, sep=sep)
def getsql(monad):
return monad.subtranslator.construct_subquery_ast(monad.limit, monad.offset)
return [ monad.subtranslator.construct_subquery_ast(monad.limit, monad.offset) ]

def find_or_create_having_ast(sections):
groupby_offset = None
Expand Down
181 changes: 181 additions & 0 deletions pony/orm/tests/test_prop_sum_orderby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *
from pony.orm.tests import setup_database, teardown_database

db = Database()

db = Database('sqlite', ':memory:')

class Product(db.Entity):
id = PrimaryKey(int)
name = Required(str)
comments = Set('Comment')

@property
def sum_01(self):
return coalesce(select(c.points for c in self.comments).sum(), 0)

@property
def sum_02(self):
return coalesce(select(c.points for c in self.comments).sum(), 0.0)

@property
def sum_03(self):
return coalesce(select(sum(c.points) for c in self.comments), 0)

@property
def sum_04(self):
return coalesce(select(sum(c.points) for c in self.comments), 0.0)

@property
def sum_05(self):
return sum(c.points for c in self.comments)

@property
def sum_06(self):
return coalesce(sum(c.points for c in self.comments), 0)

@property
def sum_07(self):
return coalesce(sum(c.points for c in self.comments), 0.0)

@property
def sum_08(self):
return select(sum(c.points) for c in self.comments)

@property
def sum_09(self):
return select(coalesce(sum(c.points), 0) for c in self.comments)

@property
def sum_10(self):
return select(coalesce(sum(c.points), 0.0) for c in self.comments)

@property
def sum_11(self):
return select(sum(c.points) for c in self.comments)

@property
def sum_12(self):
return sum(self.comments.points)

@property
def sum_13(self):
return coalesce(sum(self.comments.points), 0)

@property
def sum_14(self):
return coalesce(sum(self.comments.points), 0.0)


class Comment(db.Entity):
id = PrimaryKey(int)
points = Required(int)
product = Optional('Product')


class TestQuerySetMonad(unittest.TestCase):
@classmethod
def setUpClass(cls):
setup_database(db)
with db_session:
p1 = Product(id=1, name='P1')
p2 = Product(id=2, name='P1', comments=[
Comment(id=201, points=5)
])
p3 = Product(id=3, name='P1', comments=[
Comment(id=301, points=1), Comment(id=302, points=2)
])
p4 = Product(id=4, name='P1', comments=[
Comment(id=401, points=1), Comment(id=402, points=5), Comment(id=403, points=1)
])

@classmethod
def tearDownClass(cls):
teardown_database(db)

def setUp(self):
rollback()
db_session.__enter__()

def tearDown(self):
rollback()
db_session.__exit__()

def test_sum_01(self):
q = list(Product.select().sort_by(lambda p: p.sum_01))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_02(self):
q = list(Product.select().sort_by(lambda p: p.sum_02))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_03(self):
q = list(Product.select().sort_by(lambda p: p.sum_03))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_04(self):
q = list(Product.select().sort_by(lambda p: p.sum_04))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_05(self):
q = list(Product.select().sort_by(lambda p: p.sum_05))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_06(self):
q = list(Product.select().sort_by(lambda p: p.sum_06))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_07(self):
q = list(Product.select().sort_by(lambda p: p.sum_07))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_08(self):
q = list(Product.select().sort_by(lambda p: p.sum_08))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_09(self):
q = list(Product.select().sort_by(lambda p: p.sum_09))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_10(self):
q = list(Product.select().sort_by(lambda p: p.sum_10))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_11(self):
q = list(Product.select().sort_by(lambda p: p.sum_11))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_12(self):
q = list(Product.select().sort_by(lambda p: p.sum_12))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_13(self):
q = list(Product.select().sort_by(lambda p: p.sum_13))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])

def test_sum_14(self):
q = list(Product.select().sort_by(lambda p: p.sum_14))
result = [p.id for p in q]
self.assertEqual(result, [1, 3, 2, 4])


if __name__ == "__main__":
unittest.main()

0 comments on commit c9ca48a

Please sign in to comment.