diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 998b0444a02f..b66d96010b56 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -1868,6 +1868,9 @@ def testScalarMapDefaults(self): with self.assertRaises(TypeError): 123 in msg.map_string_string + with self.assertRaises(TypeError): + msg.map_string_string.__contains__(123) + def testScalarMapComparison(self): msg1 = map_unittest_pb2.TestMap() msg2 = map_unittest_pb2.TestMap() @@ -2007,6 +2010,12 @@ def testMessageMap(self): with self.assertRaises(TypeError): msg.map_int32_foreign_message['123'] + with self.assertRaises(TypeError): + '123' in msg.map_int32_foreign_message + + with self.assertRaises(TypeError): + msg.map_int32_foreign_message.__contains__('123') + # Can't assign directly to submessage. with self.assertRaises(ValueError): msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] diff --git a/python/map.c b/python/map.c index bbbd63611852..4b6e97e1932d 100644 --- a/python/map.c +++ b/python/map.c @@ -195,19 +195,19 @@ static PyObject* PyUpb_MapContainer_Subscript(PyObject* _self, PyObject* key) { return PyUpb_UpbToPy(u_val, val_f, self->arena); } -static PyObject* PyUpb_MapContainer_Contains(PyObject* _self, PyObject* key) { +static int PyUpb_MapContainer_Contains(PyObject* _self, PyObject* key) { PyUpb_MapContainer* self = (PyUpb_MapContainer*)_self; upb_Map* map = PyUpb_MapContainer_GetIfReified(self); - if (!map) Py_RETURN_FALSE; + if (!map) return 0; const upb_FieldDef* f = PyUpb_MapContainer_GetField(self); const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f); const upb_FieldDef* key_f = upb_MessageDef_Field(entry_m, 0); upb_MessageValue u_key; - if (!PyUpb_PyToUpb(key, key_f, &u_key, NULL)) return NULL; + if (!PyUpb_PyToUpb(key, key_f, &u_key, NULL)) return -1; if (upb_Map_Get(map, u_key, NULL)) { - Py_RETURN_TRUE; + return 1; } else { - Py_RETURN_FALSE; + return 0; } } @@ -339,8 +339,6 @@ PyObject* PyUpb_MapContainer_GetOrCreateWrapper(upb_Map* map, // ----------------------------------------------------------------------------- static PyMethodDef PyUpb_ScalarMapContainer_Methods[] = { - {"__contains__", PyUpb_MapContainer_Contains, METH_O, - "Tests whether a key is a member of the map."}, {"clear", PyUpb_MapContainer_Clear, METH_NOARGS, "Removes all elements from the map."}, {"get", (PyCFunction)PyUpb_MapContainer_Get, METH_VARARGS | METH_KEYWORDS, @@ -363,6 +361,7 @@ static PyType_Slot PyUpb_ScalarMapContainer_Slots[] = { {Py_mp_length, PyUpb_MapContainer_Length}, {Py_mp_subscript, PyUpb_MapContainer_Subscript}, {Py_mp_ass_subscript, PyUpb_MapContainer_AssignSubscript}, + {Py_sq_contains, PyUpb_MapContainer_Contains}, {Py_tp_methods, PyUpb_ScalarMapContainer_Methods}, {Py_tp_iter, PyUpb_MapIterator_New}, {Py_tp_repr, PyUpb_MapContainer_Repr}, @@ -382,8 +381,6 @@ static PyType_Spec PyUpb_ScalarMapContainer_Spec = { // ----------------------------------------------------------------------------- static PyMethodDef PyUpb_MessageMapContainer_Methods[] = { - {"__contains__", PyUpb_MapContainer_Contains, METH_O, - "Tests whether the map contains this element."}, {"clear", PyUpb_MapContainer_Clear, METH_NOARGS, "Removes all elements from the map."}, {"get", (PyCFunction)PyUpb_MapContainer_Get, METH_VARARGS | METH_KEYWORDS, @@ -408,6 +405,7 @@ static PyType_Slot PyUpb_MessageMapContainer_Slots[] = { {Py_mp_length, PyUpb_MapContainer_Length}, {Py_mp_subscript, PyUpb_MapContainer_Subscript}, {Py_mp_ass_subscript, PyUpb_MapContainer_AssignSubscript}, + {Py_sq_contains, PyUpb_MapContainer_Contains}, {Py_tp_methods, PyUpb_MessageMapContainer_Methods}, {Py_tp_iter, PyUpb_MapIterator_New}, {Py_tp_repr, PyUpb_MapContainer_Repr}, @@ -477,28 +475,38 @@ static PyType_Spec PyUpb_MapIterator_Spec = { static PyObject* GetMutableMappingBase(void) { PyObject* collections = NULL; PyObject* mapping = NULL; - PyObject* bases = NULL; + PyObject* base = NULL; if ((collections = PyImport_ImportModule("collections.abc")) && (mapping = PyObject_GetAttrString(collections, "MutableMapping"))) { - bases = Py_BuildValue("(O)", mapping); + base = Py_BuildValue("O", mapping); } Py_XDECREF(collections); Py_XDECREF(mapping); - return bases; + return base; } bool PyUpb_Map_Init(PyObject* m) { PyUpb_ModuleState* state = PyUpb_ModuleState_GetFromModule(m); - PyObject* bases = GetMutableMappingBase(); - if (!bases) return false; + PyObject* base = GetMutableMappingBase(); + if (!base) return false; + + const char* methods[] = {"keys", "items", "values", "__eq__", "__ne__", + "pop", "popitem", "update", "setdefault", NULL}; - state->message_map_container_type = - PyUpb_AddClassWithBases(m, &PyUpb_MessageMapContainer_Spec, bases); - state->scalar_map_container_type = - PyUpb_AddClassWithBases(m, &PyUpb_ScalarMapContainer_Spec, bases); + state->message_map_container_type = PyUpb_AddClassWithRegister( + m, &PyUpb_MessageMapContainer_Spec, base, methods); + if (!state->message_map_container_type) { + return false; + } + state->scalar_map_container_type = PyUpb_AddClassWithRegister( + m, &PyUpb_ScalarMapContainer_Spec, base, methods); + if (!state->scalar_map_container_type) { + return false; + } state->map_iterator_type = PyUpb_AddClass(m, &PyUpb_MapIterator_Spec); - Py_DECREF(bases); + Py_DECREF(base); + Py_DECREF(methods); return state->message_map_container_type && state->scalar_map_container_type && state->map_iterator_type; diff --git a/python/protobuf.c b/python/protobuf.c index 4f53154dcbd2..316b1f6603e0 100644 --- a/python/protobuf.c +++ b/python/protobuf.c @@ -323,6 +323,31 @@ PyTypeObject* PyUpb_AddClassWithBases(PyObject* m, PyType_Spec* spec, return (PyTypeObject*)type; } +PyTypeObject* PyUpb_AddClassWithRegister(PyObject* m, PyType_Spec* spec, + PyObject* virtual_base, + const char** methods) { + PyObject* type = PyType_FromSpec(spec); + PyObject* ret1 = PyObject_CallMethod(virtual_base, "register", "O", type); + if (!ret1) { + Py_XDECREF(type); + return NULL; + } + for (size_t i = 0; methods[i] != NULL; i++) { + PyObject* method = PyObject_GetAttrString(virtual_base, methods[i]); + if (!method) { + Py_XDECREF(type); + return NULL; + } + int ret2 = PyObject_SetAttrString(type, methods[i], method); + if (ret2 < 0) { + Py_XDECREF(type); + return NULL; + } + } + + return (PyTypeObject*)type; +} + const char* PyUpb_GetStrData(PyObject* obj) { if (PyUnicode_Check(obj)) { return PyUnicode_AsUTF8AndSize(obj, NULL); diff --git a/python/protobuf.h b/python/protobuf.h index 9c5894cbf9a8..21a8d09a86b6 100644 --- a/python/protobuf.h +++ b/python/protobuf.h @@ -180,6 +180,12 @@ PyTypeObject* PyUpb_AddClass(PyObject* m, PyType_Spec* spec); PyTypeObject* PyUpb_AddClassWithBases(PyObject* m, PyType_Spec* spec, PyObject* bases); +// Like PyUpb_AddClass(), but allows you to specify a tuple of base classes in +// `bases` to register as a "virtual subclass" with mixin methods. +PyTypeObject* PyUpb_AddClassWithRegister(PyObject* m, PyType_Spec* spec, + PyObject* virtual_base, + const char** methods); + // A function that implements the tp_new slot for types that we do not allow // users to create directly. This will immediately fail with an error message. PyObject* PyUpb_Forbidden_New(PyObject* cls, PyObject* args, PyObject* kwds);