From 54b0a97fa5602980ef1dd368336a08ad89da4748 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 17 Apr 2020 19:49:48 +0300 Subject: [PATCH] Add support for volatile collection attributes that don't throw "Phantom object appeared/disappeared" exceptions --- pony/orm/core.py | 13 +++-- pony/orm/tests/test_diagram_keys.py | 7 --- pony/orm/tests/test_volatile.py | 77 ++++++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 13 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index fc54cb498..cbe555567 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2140,7 +2140,7 @@ def _init_(attr, entity, name): if attr.py_type == float: if attr.is_pk: throw(TypeError, 'PrimaryKey attribute %s cannot be of type float' % attr) elif attr.is_unique: throw(TypeError, 'Unique attribute %s cannot be of type float' % attr) - if attr.is_volatile and (attr.is_pk or attr.is_collection): throw(TypeError, + if attr.is_volatile and attr.is_pk: throw(TypeError, '%s attribute %s cannot be volatile' % (attr.__class__.__name__, attr)) if attr.interleave is not None: @@ -2150,6 +2150,8 @@ def _init_(attr, entity, name): '`interleave` option value should be True, False or None. Got: %r' % attr.interleave) def linked(attr): reverse = attr.reverse + if reverse.is_volatile: + attr.is_volatile = True if attr.cascade_delete is None: attr.cascade_delete = attr.is_collection and reverse.is_required elif attr.cascade_delete: @@ -2867,7 +2869,7 @@ def prefetch_load_all(attr, objects): else: phantoms = setdata2 - items if setdata2.added: phantoms -= setdata2.added - if phantoms: throw(UnrepeatableReadError, + if phantoms and not attr.is_volatile: throw(UnrepeatableReadError, 'Phantom object %s disappeared from collection %s.%s' % (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name)) items -= setdata2 @@ -2889,7 +2891,8 @@ def load(attr, obj, items=None): assert obj._status_ not in del_statuses setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() - elif setdata.is_fully_loaded: return setdata + elif setdata.is_fully_loaded and not attr.is_volatile: + return setdata entity = attr.entity reverse = attr.reverse rentity = reverse.entity @@ -2968,7 +2971,7 @@ def load(attr, obj, items=None): else: phantoms = setdata2 - items if setdata2.added: phantoms -= setdata2.added - if phantoms: throw(UnrepeatableReadError, + if phantoms and not attr.is_volatile: throw(UnrepeatableReadError, 'Phantom object %s disappeared from collection %s.%s' % (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name)) items -= setdata2 @@ -3125,7 +3128,7 @@ def db_reverse_add(attr, objects, item): for obj in objects: setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() - elif setdata.is_fully_loaded: throw(UnrepeatableReadError, + elif setdata.is_fully_loaded and not attr.is_volatile: throw(UnrepeatableReadError, 'Phantom object %s appeared in collection %s.%s' % (safe_repr(item), safe_repr(obj), attr.name)) setdata.add(item) def reverse_remove(attr, objects, item, undo_funcs): diff --git a/pony/orm/tests/test_diagram_keys.py b/pony/orm/tests/test_diagram_keys.py index 653656053..a1c7f5baa 100644 --- a/pony/orm/tests/test_diagram_keys.py +++ b/pony/orm/tests/test_diagram_keys.py @@ -159,13 +159,6 @@ def test_volatile_pk(self): class Entity1(db.Entity): a = PrimaryKey(int, volatile=True) - @raises_exception(TypeError, 'Set attribute Entity1.b cannot be volatile') - def test_volatile_set(self): - db = self.db = Database(**db_params) - class Entity1(db.Entity): - a = PrimaryKey(int) - b = Set('Entity2', volatile=True) - @raises_exception(TypeError, 'Volatile attribute Entity1.b cannot be part of primary key') def test_volatile_composite_pk(self): db = self.db = Database(**db_params) diff --git a/pony/orm/tests/test_volatile.py b/pony/orm/tests/test_volatile.py index e534ee01f..c710e7c22 100644 --- a/pony/orm/tests/test_volatile.py +++ b/pony/orm/tests/test_volatile.py @@ -5,7 +5,7 @@ from pony.orm.tests import setup_database, teardown_database -class TestVolatile(unittest.TestCase): +class TestVolatile1(unittest.TestCase): def setUp(self): db = self.db = Database() @@ -48,5 +48,80 @@ def test_2(self): item.flush() self.assertEqual(item.index, 1) + +class TestVolatile2(unittest.TestCase): + def setUp(self): + db = self.db = Database() + + class Group(db.Entity): + number = PrimaryKey(int) + students = Set('Student', volatile=True) + + class Student(db.Entity): + id = PrimaryKey(int) + name = Required(str) + group = Required('Group') + courses = Set('Course') + + class Course(db.Entity): + id = PrimaryKey(int) + name = Required(str) + students = Set('Student', volatile=True) + + setup_database(db) + + with db_session: + g1 = Group(number=123) + s1 = Student(id=1, name='A', group=g1) + s2 = Student(id=2, name='B', group=g1) + c1 = Course(id=1, name='C1', students=[s1, s2]) + c2 = Course(id=2, name='C1', students=[s1]) + + self.Group = Group + self.Student = Student + self.Course = Course + + def tearDown(self): + teardown_database(self.db) + + def test_1(self): + self.assertTrue(self.Group.students.is_volatile) + self.assertTrue(self.Student.group.is_volatile) + self.assertTrue(self.Student.courses.is_volatile) + self.assertTrue(self.Course.students.is_volatile) + + def test_2(self): + with db_session: + g1 = self.Group[123] + students = set(s.id for s in g1.students) + self.assertEqual(students, {1, 2}) + self.db.execute('''insert into student values(3, 'C', 123)''') + g1.students.load() + students = set(s.id for s in g1.students) + self.assertEqual(students, {1, 2, 3}) + + def test_3(self): + with db_session: + g1 = self.Group[123] + students = set(s.id for s in g1.students) + self.assertEqual(students, {1, 2}) + self.db.execute("insert into student values(3, 'C', 123)") + s3 = self.Student[3] + students = set(s.id for s in g1.students) + self.assertEqual(students, {1, 2, 3}) + + def test_4(self): + with db_session: + c1 = self.Course[1] + students = set(s.id for s in c1.students) + self.assertEqual(students, {1, 2}) + self.db.execute("insert into student values(3, 'C', 123)") + attr = self.Student.courses + self.db.execute("insert into %s values(1, 3)" % attr.table) + c1.students.load() + students = set(s.id for s in c1.students) + self.assertEqual(students, {1, 2, 3}) + + if __name__ == '__main__': unittest.main()