Skip to content

Commit

Permalink
pythongh-125660: Enable setting persistent_id and persistent_load of …
Browse files Browse the repository at this point in the history
…pickler and unpickler

pickle.Pickler and pickle.Unpickler instances have now managed dicts.
Arbitrary instance attributes, including persistent_id and persistent_load,
can now be set.
  • Loading branch information
serhiy-storchaka committed Oct 19, 2024
1 parent 2e950e3 commit 322e2aa
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 8 deletions.
9 changes: 5 additions & 4 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,11 @@ def save(self, obj, save_persistent_id=True):
self.framer.commit_frame()

# Check for persistent id (defined by a subclass)
pid = self.persistent_id(obj)
if pid is not None and save_persistent_id:
self.save_pers(pid)
return
if save_persistent_id:
pid = self.persistent_id(obj)
if pid is not None:
self.save_pers(pid)
return

# Check the memo
x = self.memo.get(id(obj))
Expand Down
32 changes: 30 additions & 2 deletions Lib/test/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,30 @@ def persistent_load(subself, pid):
unpickler = PersUnpickler(io.BytesIO(self.dumps('abc', proto)))
self.assertEqual(unpickler.load(), 'abc')

def test_pickler_instance_attribute(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
f = io.BytesIO()
pickler = self.pickler(f, proto)
called = []
def persistent_id(obj):
called.append(obj)
return obj
pickler.persistent_id = persistent_id
pickler.dump('abc')
self.assertEqual(called, ['abc'])
self.assertEqual(self.loads(f.getvalue()), 'abc')

def test_unpickler_instance_attribute(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
unpickler = self.unpickler(io.BytesIO(self.dumps('abc', proto)))
called = []
def persistent_load(pid):
called.append(pid)
return pid
unpickler.persistent_load = persistent_load
self.assertEqual(unpickler.load(), 'abc')
self.assertEqual(called, ['abc'])

class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):

pickler_class = pickle._Pickler
Expand Down Expand Up @@ -368,17 +392,20 @@ class SizeofTests(unittest.TestCase):

def test_pickler(self):
basesize = support.calcobjsize('6P2n3i2n3i2P')
P = struct.calcsize('P')
p = _pickle.Pickler(io.BytesIO())
self.assertEqual(object.__sizeof__(p), basesize)
MT_size = struct.calcsize('3nP0n')
ME_size = struct.calcsize('Pn0P')
check = self.check_sizeof
check(p, basesize +
2 * P + # Managed dict
MT_size + 8 * ME_size + # Minimal memo table size.
sys.getsizeof(b'x'*4096)) # Minimal write buffer size.
for i in range(6):
p.dump(chr(i))
check(p, basesize +
2 * P + # Managed dict
MT_size + 32 * ME_size + # Size of memo table required to
# save references to 6 objects.
0) # Write buffer is cleared after every dump().
Expand All @@ -395,6 +422,7 @@ def test_unpickler(self):
encoding=encoding, errors=errors)
self.assertEqual(object.__sizeof__(u), basesize)
check(u, basesize +
2 * P + # Managed dict
32 * P + # Minimal memo table size.
len(encoding) + 1 + len(errors) + 1)

Expand All @@ -404,7 +432,7 @@ def check_unpickler(data, memo_size, marks_size):
u = unpickler(io.BytesIO(dump),
encoding='ASCII', errors='strict')
u.load()
check(u, stdsize + memo_size * P + marks_size * n)
check(u, stdsize + 2 * P + memo_size * P + marks_size * n)

check_unpickler(0, 32, 0)
# 20 is minimal non-empty mark stack size.
Expand All @@ -427,7 +455,7 @@ def recurse(deep):
u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
encoding='ASCII', errors='strict')
u.load()
check(u, stdsize + 32 * P + 2 + 1)
check(u, stdsize + 2 * P + 32 * P + 2 + 1)


ALT_IMPORT_MAPPING = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Restore ability to set :attr:`~pickle.Pickler.persistent_id` and
:attr:`~pickle.Unpickler.persistent_load` attributes of instances of the
:class:`!Pickler` and :class:`!Unpickler` classes in the :mod:`pickle`
module.
4 changes: 2 additions & 2 deletions Modules/_pickle.c
Original file line number Diff line number Diff line change
Expand Up @@ -5120,7 +5120,7 @@ static PyType_Spec pickler_type_spec = {
.name = "_pickle.Pickler",
.basicsize = sizeof(PicklerObject),
.flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_IMMUTABLETYPE),
Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_MANAGED_DICT),
.slots = pickler_type_slots,
};

Expand Down Expand Up @@ -7585,7 +7585,7 @@ static PyType_Spec unpickler_type_spec = {
.name = "_pickle.Unpickler",
.basicsize = sizeof(UnpicklerObject),
.flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_IMMUTABLETYPE),
Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_MANAGED_DICT),
.slots = unpickler_type_slots,
};

Expand Down

0 comments on commit 322e2aa

Please sign in to comment.