diff --git a/pony/orm/core.py b/pony/orm/core.py index b514277b2..e89f94cd7 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -298,11 +298,11 @@ def copy(self): result.entities_to_prefetch = self.entities_to_prefetch.copy() return result def __enter__(self): - assert local.prefetch_context is None - local.prefetch_context = self + local.prefetch_context_stack.append(self) def __exit__(self, exc_type, exc_val, exc_tb): - assert local.prefetch_context is self - local.prefetch_context = None + stack = local.prefetch_context_stack + assert stack and stack[-1] is self + stack.pop() def get_frozen_attrs_to_prefetch(self, entity): attrs_to_prefetch = self.attrs_to_prefetch_dict.get(entity, ()) if type(attrs_to_prefetch) is set: @@ -329,11 +329,16 @@ def __init__(local): local.db2cache = {} local.db_context_counter = 0 local.db_session = None - local.prefetch_context = None + local.prefetch_context_stack = [] local.current_user = None local.perms_context = None local.user_groups_cache = {} local.user_roles_cache = defaultdict(dict) + @property + def prefetch_context(local): + if local.prefetch_context_stack: + return local.prefetch_context_stack[-1] + return None def push_debug_state(local, debug, show_values): local.debug_stack.append((local.debug, local.show_values)) if not suppress_debug_change: diff --git a/pony/orm/tests/test_hooks.py b/pony/orm/tests/test_hooks.py index af3c2002f..33f55e675 100644 --- a/pony/orm/tests/test_hooks.py +++ b/pony/orm/tests/test_hooks.py @@ -57,6 +57,11 @@ def set_hooks_to_do_nothing(): set_hooks_to_do_nothing() +def flush_for(*objects): + for obj in objects: + obj.flush() + + class TestHooks(unittest.TestCase): @classmethod def setUpClass(cls): @@ -79,7 +84,7 @@ def tearDown(self): pass @db_session - def test_1(self): + def test_1a(self): p4 = Person(id=4, name='Bob', age=16) p5 = Person(id=5, name='Lucy', age=23) self.assertEqual(logged_events, []) @@ -87,7 +92,15 @@ def test_1(self): self.assertEqual(logged_events, ['BI_Bob', 'BI_Lucy', 'AI_Bob', 'AI_Lucy']) @db_session - def test_2(self): + def test_1b(self): + p4 = Person(id=4, name='Bob', age=16) + p5 = Person(id=5, name='Lucy', age=23) + self.assertEqual(logged_events, []) + flush_for(p4, p5) + self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob', 'BI_Lucy', 'AI_Lucy']) + + @db_session + def test_2a(self): p4 = Person(id=4, name='Bob', age=16) p1 = Person[1] # auto-flush here p2 = Person[2] @@ -98,50 +111,7 @@ def test_2(self): self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob', 'BU_Mary', 'BI_Lucy', 'AU_Mary', 'AI_Lucy']) @db_session - def test_3(self): - global do_before_insert - def do_before_insert(person): - some_person = Person.select().first() # should not cause infinite recursion - p4 = Person(id=4, name='Bob', age=16) - db.flush() - - -def flush_for(*objects): - for obj in objects: - obj.flush() - - -class ObjectFlushTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - setup_database(db) - - @classmethod - def tearDownClass(cls): - teardown_database(db) - - def setUp(self): - set_hooks_to_do_nothing() - with db_session: - db.execute('delete from Person') - p1 = Person(id=1, name='John', age=22) - p2 = Person(id=2, name='Mary', age=18) - p3 = Person(id=3, name='Mike', age=25) - logged_events[:] = [] - - def tearDown(self): - pass - - @db_session - def test_1(self): - p4 = Person(id=4, name='Bob', age=16) - p5 = Person(id=5, name='Lucy', age=23) - self.assertEqual(logged_events, []) - flush_for(p4, p5) - self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob', 'BI_Lucy', 'AI_Lucy']) - - @db_session - def test_2(self): + def test_2b(self): p4 = Person(id=4, name='Bob', age=16) p1 = Person[1] # auto-flush here p2 = Person[2] @@ -157,7 +127,15 @@ def test_3(self): def do_before_insert(person): some_person = Person.select().first() # should not cause infinite recursion p4 = Person(id=4, name='Bob', age=16) - p4.flush() + db.flush() + + @db_session + def test_4(self): + global do_before_insert + def do_before_insert(person): + some_person = Person.select().first() # creates nested prefetch_context + p4 = Person(id=4, name='Bob', age=16) + Person.select().first() if __name__ == '__main__':