From 5f829b89f7d887bfe925fa3c4f835fe920e284af Mon Sep 17 00:00:00 2001 From: Vishal Pankaj Chandratreya <19171016+tfpf@users.noreply.github.com> Date: Thu, 28 Nov 2024 18:28:16 +0530 Subject: [PATCH] Miscellaneous fixes (#22) * `TypeError` for invalid constructor argument * Added stringifiers * Removed PyErr_FormatWrapper and handled all errors * Use `PyUnicode_FromStringAndSize` because size is known * Version bump * Check key types strictly * Wrong type test: use subtypes --- pyproject.toml | 2 +- src/pysorteddict/pysorteddict.cc | 93 +++++++++++++++++++++++------- tests/test_invalid_construction.py | 2 +- tests/utils.py | 6 +- 4 files changed, 76 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 610a586..eb6a721 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pysorteddict" -version = "0.0.7" +version = "0.0.8" authors = [ {name = "Vishal Pankaj Chandratreya"}, ] diff --git a/src/pysorteddict/pysorteddict.cc b/src/pysorteddict/pysorteddict.cc index 05fdc29..94e7fca 100644 --- a/src/pysorteddict/pysorteddict.cc +++ b/src/pysorteddict/pysorteddict.cc @@ -3,6 +3,8 @@ #include #include #include +#include +#include /** * C++-style comparison implementation for Python objects. @@ -11,24 +13,44 @@ struct PyObject_CustomCompare { bool operator()(PyObject* a, PyObject* b) const { - // This assumes that the comparison operation will never error out. I - // think it should be enough to ensure that the two Python objects - // being compared always have the same type. + // There must exist a total order on the set of possible keys. (Else, + // this comparison may error out.) Hence, only instances of the type + // passed to the constructor may be used as keys. (Instances of types + // derived from that type are not allowed, because comparisons between + // them may error out. See the constructor code.) With these + // precautions, this comparison should always work. return PyObject_RichCompareBool(a, b, Py_LT) == 1; } }; /** - * Set an error message containing the string representation of a Python - * object. + * Obtain the Python representation of a Python object. */ -static void PyErr_FormatWrapper(PyObject* exc, char const* fmt, PyObject* ob) +static std::pair repr(PyObject* ob) { - PyObject* repr = PyObject_Repr(ob); // New reference. - // The second argument is no longer a string constant. Is there an elegant - // fix? - PyErr_Format(exc, fmt, PyUnicode_AsUTF8(repr)); - Py_DECREF(repr); + PyObject* ob_repr = PyObject_Repr(ob); // New reference. + if (ob_repr == nullptr) + { + return { "", false }; + } + std::pair result = { PyUnicode_AsUTF8(ob_repr), true }; + Py_DECREF(ob_repr); + return result; +} + +/** + * Obtain a human-readable string representation of a Python object. + */ +static std::pair str(PyObject* ob) +{ + PyObject* ob_str = PyObject_Str(ob); // New reference. + if (ob_str == nullptr) + { + return { "", false }; + } + std::pair result = { PyUnicode_AsUTF8(ob_str), true }; + Py_DECREF(ob_str); + return result; } // clang-format off @@ -91,7 +113,7 @@ static PyObject* sorted_dict_type_new(PyTypeObject* type, PyObject* args, PyObje // Check the type to use for keys. if (PyObject_RichCompareBool(sd->key_type, (PyObject*)&PyLong_Type, Py_EQ) != 1) { - PyErr_SetString(PyExc_ValueError, "constructor argument must be a supported type"); + PyErr_SetString(PyExc_TypeError, "constructor argument must be a supported type"); // I haven't increased its reference count, so the deallocator // shouldn't decrease it. Hence, set it to a null pointer. sd->key_type = nullptr; @@ -119,9 +141,15 @@ static Py_ssize_t sorted_dict_type_len(PyObject* self) static PyObject* sorted_dict_type_getitem(PyObject* self, PyObject* key) { SortedDictType* sd = (SortedDictType*)self; - if (PyObject_IsInstance(key, sd->key_type) != 1) + if (Py_IS_TYPE(key, (PyTypeObject*)sd->key_type) == 0) { - PyErr_FormatWrapper(PyExc_TypeError, "key must be of type %s", sd->key_type); + PyObject* key_type_repr = PyObject_Repr(sd->key_type); // New reference. + if (key_type_repr == nullptr) + { + return nullptr; + } + PyErr_Format(PyExc_TypeError, "key must be of type %s", PyUnicode_AsUTF8(key_type_repr)); + Py_DECREF(key_type_repr); return nullptr; } auto it = sd->map->find(key); @@ -139,9 +167,15 @@ static PyObject* sorted_dict_type_getitem(PyObject* self, PyObject* key) static int sorted_dict_type_setitem(PyObject* self, PyObject* key, PyObject* value) { SortedDictType* sd = (SortedDictType*)self; - if (PyObject_IsInstance(key, sd->key_type) != 1) + if (Py_IS_TYPE(key, (PyTypeObject*)sd->key_type) == 0) { - PyErr_FormatWrapper(PyExc_TypeError, "key must be of type %s", sd->key_type); + PyObject* key_type_repr = PyObject_Repr(sd->key_type); // New reference. + if (key_type_repr == nullptr) + { + return -1; + } + PyErr_Format(PyExc_TypeError, "key must be of type %s", PyUnicode_AsUTF8(key_type_repr)); + Py_DECREF(key_type_repr); return -1; } @@ -198,15 +232,25 @@ static PyObject* sorted_dict_type_str(PyObject* self) oss << '\x7b'; for (auto& item : *sd->map) { - PyObject* key_repr = PyObject_Repr(item.first); // New reference. - PyObject* value_repr = PyObject_Repr(item.second); // New reference. - oss << delimiter << PyUnicode_AsUTF8(key_repr) << ": " << PyUnicode_AsUTF8(value_repr); + PyObject* key_str = PyObject_Str(item.first); // New reference. + if (key_str == nullptr) + { + return nullptr; + } + PyObject* value_str = PyObject_Str(item.second); // New reference. + if (value_str == nullptr) + { + Py_DECREF(key_str); + return nullptr; + } + oss << delimiter << PyUnicode_AsUTF8(key_str) << ": " << PyUnicode_AsUTF8(value_str); delimiter = actual_delimiter; - Py_DECREF(key_repr); - Py_DECREF(value_repr); + Py_DECREF(key_str); + Py_DECREF(value_str); } oss << '\x7d'; - return PyUnicode_FromString(oss.str().data()); // New reference. + std::string oss_str = oss.str(); + return PyUnicode_FromStringAndSize(oss_str.data(), oss_str.size()); // New reference. } /** @@ -224,6 +268,11 @@ static PyObject* sorted_dict_type_items(PyObject* self, PyObject* args) for (auto& item : *sd->map) { PyObject* pyitem = PyTuple_New(2); // New reference. + if (pyitem == nullptr) + { + Py_DECREF(pyitems); + return nullptr; + } PyTuple_SET_ITEM(pyitem, 0, item.first); Py_INCREF(item.first); PyTuple_SET_ITEM(pyitem, 1, item.second); diff --git a/tests/test_invalid_construction.py b/tests/test_invalid_construction.py index 50f3b83..8fe071b 100644 --- a/tests/test_invalid_construction.py +++ b/tests/test_invalid_construction.py @@ -17,7 +17,7 @@ def test_construct_without_argument(self): self.assertEqual(self.missing_argument, ctx.exception.args[0]) def test_construct_with_object_instance(self): - with self.assertRaises(ValueError) as ctx: + with self.assertRaises(TypeError) as ctx: SortedDict(object()) self.assertEqual(self.wrong_argument, ctx.exception.args[0]) diff --git a/tests/utils.py b/tests/utils.py index 7b6c13f..75c370f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,6 +32,7 @@ def setUpClass(cls): def setUp(self, key_type: type): self.key_type = key_type + self.key_subtype = type("sub" + self.key_type.__name__, (self.key_type,), {}) self.rg = random.Random(__name__) self.keys = [self.small_key() for _ in range(1000)] self.values = [self.small_key() for _ in self.keys] @@ -56,7 +57,7 @@ def test_len(self): def test_getitem_wrong_type(self): with self.assertRaises(TypeError) as ctx: - self.sorted_dict[object()] + self.sorted_dict[self.key_subtype()] self.assertEqual(self.wrong_argument, ctx.exception.args[0]) def test_getitem_not_found(self): @@ -78,10 +79,9 @@ def test_getitem(self): self.assertEqual(5, sys.getrefcount(value)) def test_setitem_wrong_type(self): - key = object() value = self.small_key() with self.assertRaises(TypeError) as ctx: - self.sorted_dict[key] = value + self.sorted_dict[self.key_subtype()] = value self.assertEqual(self.wrong_argument, ctx.exception.args[0]) if self.cpython: