Skip to content

Commit

Permalink
Merge branch master to orm: Entity.select(**kwargs)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Jun 1, 2020
2 parents c9ca48a + 0910f76 commit ecd598a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
16 changes: 11 additions & 5 deletions pony/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3572,7 +3572,7 @@ def clear(wrapper):
def load(wrapper):
wrapper._attr_.load(wrapper._obj_)
@cut_traceback
def select(wrapper, *args):
def select(wrapper, *args, **kwargs):
obj = wrapper._obj_
if obj._status_ in del_statuses: throw_object_was_deleted(obj)
attr = wrapper._attr_
Expand All @@ -3581,8 +3581,10 @@ def select(wrapper, *args):
s = 'lambda item: JOIN(obj in item.%s)' if reverse.is_collection else 'lambda item: item.%s == obj'
query = query.filter(s % reverse.name, {'obj' : obj, 'JOIN': JOIN})
if args:
func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=cut_traceback_depth+1)
func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=cut_traceback_depth+1)
query = query.filter(func, globals, locals)
if kwargs:
query = query._apply_kwargs(kwargs)
return query
filter = select
def limit(wrapper, limit=None, offset=None):
Expand Down Expand Up @@ -4020,8 +4022,12 @@ def get_by_sql(entity, sql, globals=None, locals=None):
assert len(objects) == 1
return objects[0]
@cut_traceback
def select(entity, *args):
return entity._query_from_args_(args, kwargs=None, frame_depth=cut_traceback_depth+1)
def select(entity, *args, **kwargs):
if args: query = entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1)
else:
query = entity._select_all()
if kwargs: query = query._apply_kwargs(kwargs)
return query
@cut_traceback
def select_by_sql(entity, sql, globals=None, locals=None):
return entity._find_by_sql_(None, sql, globals, locals, frame_depth=cut_traceback_depth+1)
Expand Down Expand Up @@ -4363,7 +4369,7 @@ def _load_many_(entity, objects):
def _select_all(entity):
return Query(entity._default_iter_name_, entity._default_genexpr_, {}, { '.0' : entity })
def _query_from_args_(entity, args, kwargs, frame_depth):
if not args and not kwargs: return entity._select_all()
assert args
func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth+1)

if type(func) is types.FunctionType:
Expand Down
28 changes: 28 additions & 0 deletions pony/orm/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,34 @@ def test22(self):
def test23(self):
r = max(s.dob.year for s in Student)
self.assertEqual(r, 2001)
def test_select_kwargs_1(self):
r = Student.select(scholarship=200)[:]
self.assertEqual(r, [Student[3]])
def test_select_kwargs_1a(self):
g = Group[1]
r = g.students.select(scholarship=200)[:]
self.assertEqual(r, [Student[3]])
def test_select_kwargs_2(self):
r = Student.select(scholarship=1000)[:]
self.assertEqual(r, [])
def test_select_kwargs_2a(self):
g = Group[1]
r = g.students.select(scholarship=1000)[:]
self.assertEqual(r, [])
def test_select_kwargs_3(self):
r = Student.select(group=Group[1])[:]
self.assertEqual(set(r), {Student[1], Student[2], Student[3]})
def test_select_kwargs_3a(self):
g = Group[1]
r = g.students.select(group=g)[:]
self.assertEqual(set(r), {Student[1], Student[2], Student[3]})
def test_select_kwargs_4(self):
r = Student.select(group=Group[1], scholarship=200)[:]
self.assertEqual(r, [Student[3]])
def test_select_kwargs_4a(self):
g = Group[1]
r = g.students.select(group=g, scholarship=200)[:]
self.assertEqual(r, [Student[3]])
def test_first1(self):
q = select(s for s in Student).order_by(Student.gpa)
self.assertEqual(q.first(), Student[1])
Expand Down

0 comments on commit ecd598a

Please sign in to comment.