Skip to content

Commit

Permalink
Merge master to orm: bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Jun 1, 2020
2 parents ecd598a + 4e04978 commit 52720af
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 52 deletions.
15 changes: 10 additions & 5 deletions pony/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,11 @@ def copy(self):
result.entities_to_prefetch = self.entities_to_prefetch.copy()
return result
def __enter__(self):
assert local.prefetch_context is None
local.prefetch_context = self
local.prefetch_context_stack.append(self)
def __exit__(self, exc_type, exc_val, exc_tb):
assert local.prefetch_context is self
local.prefetch_context = None
stack = local.prefetch_context_stack
assert stack and stack[-1] is self
stack.pop()
def get_frozen_attrs_to_prefetch(self, entity):
attrs_to_prefetch = self.attrs_to_prefetch_dict.get(entity, ())
if type(attrs_to_prefetch) is set:
Expand All @@ -329,11 +329,16 @@ def __init__(local):
local.db2cache = {}
local.db_context_counter = 0
local.db_session = None
local.prefetch_context = None
local.prefetch_context_stack = []
local.current_user = None
local.perms_context = None
local.user_groups_cache = {}
local.user_roles_cache = defaultdict(dict)
@property
def prefetch_context(local):
if local.prefetch_context_stack:
return local.prefetch_context_stack[-1]
return None
def push_debug_state(local, debug, show_values):
local.debug_stack.append((local.debug, local.show_values))
if not suppress_debug_change:
Expand Down
72 changes: 25 additions & 47 deletions pony/orm/tests/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def set_hooks_to_do_nothing():
set_hooks_to_do_nothing()


def flush_for(*objects):
for obj in objects:
obj.flush()


class TestHooks(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -79,15 +84,23 @@ def tearDown(self):
pass

@db_session
def test_1(self):
def test_1a(self):
p4 = Person(id=4, name='Bob', age=16)
p5 = Person(id=5, name='Lucy', age=23)
self.assertEqual(logged_events, [])
db.flush()
self.assertEqual(logged_events, ['BI_Bob', 'BI_Lucy', 'AI_Bob', 'AI_Lucy'])

@db_session
def test_2(self):
def test_1b(self):
p4 = Person(id=4, name='Bob', age=16)
p5 = Person(id=5, name='Lucy', age=23)
self.assertEqual(logged_events, [])
flush_for(p4, p5)
self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob', 'BI_Lucy', 'AI_Lucy'])

@db_session
def test_2a(self):
p4 = Person(id=4, name='Bob', age=16)
p1 = Person[1] # auto-flush here
p2 = Person[2]
Expand All @@ -98,50 +111,7 @@ def test_2(self):
self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob', 'BU_Mary', 'BI_Lucy', 'AU_Mary', 'AI_Lucy'])

@db_session
def test_3(self):
global do_before_insert
def do_before_insert(person):
some_person = Person.select().first() # should not cause infinite recursion
p4 = Person(id=4, name='Bob', age=16)
db.flush()


def flush_for(*objects):
for obj in objects:
obj.flush()


class ObjectFlushTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
setup_database(db)

@classmethod
def tearDownClass(cls):
teardown_database(db)

def setUp(self):
set_hooks_to_do_nothing()
with db_session:
db.execute('delete from Person')
p1 = Person(id=1, name='John', age=22)
p2 = Person(id=2, name='Mary', age=18)
p3 = Person(id=3, name='Mike', age=25)
logged_events[:] = []

def tearDown(self):
pass

@db_session
def test_1(self):
p4 = Person(id=4, name='Bob', age=16)
p5 = Person(id=5, name='Lucy', age=23)
self.assertEqual(logged_events, [])
flush_for(p4, p5)
self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob', 'BI_Lucy', 'AI_Lucy'])

@db_session
def test_2(self):
def test_2b(self):
p4 = Person(id=4, name='Bob', age=16)
p1 = Person[1] # auto-flush here
p2 = Person[2]
Expand All @@ -157,7 +127,15 @@ def test_3(self):
def do_before_insert(person):
some_person = Person.select().first() # should not cause infinite recursion
p4 = Person(id=4, name='Bob', age=16)
p4.flush()
db.flush()

@db_session
def test_4(self):
global do_before_insert
def do_before_insert(person):
some_person = Person.select().first() # creates nested prefetch_context
p4 = Person(id=4, name='Bob', age=16)
Person.select().first()


if __name__ == '__main__':
Expand Down

0 comments on commit 52720af

Please sign in to comment.