Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental: Guard against Python multiple inheritance divergence #30114

Closed
19 changes: 14 additions & 5 deletions include/pybind11/detail/smart_holder_type_casters.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,16 +282,25 @@ class modified_type_caster_generic_load_impl {
// if we can find an exact match (or, for a simple C++ type, an inherited match); if
// so, we can safely reinterpret_cast to the relevant pointer.
if (bases.size() > 1) {
std::vector<type_info *> matching_bases;
for (auto *base : bases) {
if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type)
: base->type == typeinfo->type) {
this_.load_value_and_holder(
reinterpret_cast<instance *>(src.ptr())->get_value_and_holder(base));
loaded_v_h_cpptype = base->cpptype;
reinterpret_cast_deemed_ok = true;
return true;
matching_bases.push_back(base);
}
}
if (!matching_bases.empty()) {
if (matching_bases.size() > 1) {
matching_bases.push_back(const_cast<type_info *>(typeinfo));
all_type_info_check_for_divergence(matching_bases);
}
this_.load_value_and_holder(
reinterpret_cast<instance *>(src.ptr())->get_value_and_holder(
matching_bases[0]));
loaded_v_h_cpptype = matching_bases[0]->cpptype;
reinterpret_cast_deemed_ok = true;
return true;
}
}

// Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type
Expand Down
49 changes: 46 additions & 3 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,40 @@ inline void all_type_info_add_base_most_derived_first(std::vector<type_info *> &
bases.push_back(addl_base);
}

inline void all_type_info_check_for_divergence(const std::vector<type_info *> &bases) {
using sz_t = std::size_t;
sz_t n = bases.size();
if (n < 3) {
return;
}
std::vector<sz_t> cluster_ids;
cluster_ids.reserve(n);
for (sz_t ci = 0; ci < n; ci++) {
cluster_ids.push_back(ci);
}
for (sz_t i = 0; i < n - 1; i++) {
if (cluster_ids[i] != i) {
continue;
}
for (sz_t j = i + 1; j < n; j++) {
if (PyType_IsSubtype(bases[i]->type, bases[j]->type) != 0) {
sz_t k = cluster_ids[j];
if (k == j) {
cluster_ids[j] = i;
} else {
PyErr_Format(
PyExc_TypeError,
"bases include diverging derived types: base=%s, derived1=%s, derived2=%s",
bases[j]->type->tp_name,
bases[k]->type->tp_name,
bases[i]->type->tp_name);
throw error_already_set();
}
}
}
}
}

// Populates a just-created cache entry.
PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_info *> &bases) {
assert(bases.empty());
Expand Down Expand Up @@ -168,6 +202,7 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
}
}
}
all_type_info_check_for_divergence(bases);
}

/**
Expand Down Expand Up @@ -755,14 +790,22 @@ class type_caster_generic {
// if we can find an exact match (or, for a simple C++ type, an inherited match); if
// so, we can safely reinterpret_cast to the relevant pointer.
if (bases.size() > 1) {
std::vector<type_info *> matching_bases;
for (auto *base : bases) {
if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type)
: base->type == typeinfo->type) {
this_.load_value(
reinterpret_cast<instance *>(src.ptr())->get_value_and_holder(base));
return true;
matching_bases.push_back(base);
}
}
if (!matching_bases.empty()) {
if (matching_bases.size() > 1) {
matching_bases.push_back(const_cast<type_info *>(typeinfo));
all_type_info_check_for_divergence(matching_bases);
}
this_.load_value(reinterpret_cast<instance *>(src.ptr())->get_value_and_holder(
matching_bases[0]));
return true;
}
}

// Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type
Expand Down
40 changes: 37 additions & 3 deletions tests/test_python_multiple_inheritance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,28 @@ struct CppDrvd : CppBase<SerNo> {
int drvd_value;
};

template <int SerNo>
struct CppDrvdB : CppBase<SerNo> {
explicit CppDrvdB(int value) : CppBase<SerNo>(value), drvdb_value(value * 5) {}
int get_drvdb_value() const { return drvdb_value; }
void reset_drvdb_value(int new_value) { drvdb_value = new_value; }

int get_base_value_from_drvdb() const { return CppBase<SerNo>::get_base_value(); }
void reset_base_value_from_drvdb(int new_value) {
CppBase<SerNo>::reset_base_value(new_value);
}

private:
int drvdb_value;
};

template <int SerNo, typename... Extra>
void wrap_classes(py::module_ &m, const char *name_base, const char *name_drvd, Extra... extra) {
void wrap_classes(py::module_ &m,
const char *name_base,
const char *name_drvd,
const char *name_drvdb,
const char *pass_base,
Extra... extra) {
py::class_<CppBase<SerNo>>(m, name_base, std::forward<Extra>(extra)...)
.def(py::init<int>())
.def("get_base_value", &CppBase<SerNo>::get_base_value)
Expand All @@ -41,14 +61,28 @@ void wrap_classes(py::module_ &m, const char *name_base, const char *name_drvd,
.def("reset_drvd_value", &CppDrvd<SerNo>::reset_drvd_value)
.def("get_base_value_from_drvd", &CppDrvd<SerNo>::get_base_value_from_drvd)
.def("reset_base_value_from_drvd", &CppDrvd<SerNo>::reset_base_value_from_drvd);

py::class_<CppDrvdB<SerNo>, CppBase<SerNo>>(m, name_drvdb, std::forward<Extra>(extra)...)
.def(py::init<int>())
.def("get_drvdb_value", &CppDrvdB<SerNo>::get_drvdb_value)
.def("reset_drvdb_value", &CppDrvdB<SerNo>::reset_drvdb_value)
.def("get_base_value_from_drvdb", &CppDrvdB<SerNo>::get_base_value_from_drvdb)
.def("reset_base_value_from_drvdb", &CppDrvdB<SerNo>::reset_base_value_from_drvdb);

m.def(pass_base, [](const CppBase<SerNo> *) {});
}

} // namespace test_python_multiple_inheritance

TEST_SUBMODULE(python_multiple_inheritance, m) {
using namespace test_python_multiple_inheritance;
wrap_classes<0>(m, "CppBase0", "CppDrvd0");
wrap_classes<1>(m, "CppBase1", "CppDrvd1", py::metaclass((PyObject *) &PyType_Type));
wrap_classes<0>(m, "CppBase0", "CppDrvd0", "CppDrvdB0", "pass_CppBase0");
wrap_classes<1>(m,
"CppBase1",
"CppDrvd1",
"CppDrvdB1",
"pass_CppBase1",
py::metaclass((PyObject *) &PyType_Type));

m.attr("if_defined_PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS") =
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
Expand Down
87 changes: 87 additions & 0 deletions tests/test_python_multiple_inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ def __init__(self, value):
del value


class PPPCCC0(PPCC0, m.CppDrvdB0):
pass


class PC10(m.CppDrvd0):
pass


class PC20(m.CppDrvdB0):
pass


class PCD0(PC10, PC20):
pass


class PCDI0(PC10, PC20):
def __init__(self):
PC10.__init__(self, 11)
PC20.__init__(self, 12)


#
# Using py::metaclass((PyObject *) &PyType_Type) (used with py::class_<> for CppBase1, CppDrvd1):
# COPY-PASTE block from above, replace 0 with 1:
Expand Down Expand Up @@ -62,6 +84,28 @@ def __init__(self, value):
del value


class PPPCCC1(PPCC1, m.CppDrvdB1):
pass


class PC11(m.CppDrvd1):
pass


class PC21(m.CppDrvdB1):
pass


class PCD1(PC11, PC21):
pass


class PCDI1(PC11, PC21):
def __init__(self):
PC11.__init__(self, 11)
PC21.__init__(self, 12)


@pytest.mark.parametrize(("pc_type"), [PC0, PC1])
def test_PC(pc_type):
d = pc_type(11)
Expand Down Expand Up @@ -136,3 +180,46 @@ def __init__(self, value):
for _ in range(100):
assert nested_function(0) == (10, 11)
assert nested_function(3) == (13, 14)


def NOtest_PPPCCC0():
# terminate called after throwing an instance of 'pybind11::error_already_set'
# what(): TypeError: bases include diverging derived types:
# base=pybind11_tests.python_multiple_inheritance.CppBase0,
# derived1=pybind11_tests.python_multiple_inheritance.CppDrvd0,
# derived2=pybind11_tests.python_multiple_inheritance.CppDrvdB0
PPPCCC0(11)


def NOtest_PPPCCC1():
# terminate called after throwing an instance of 'pybind11::error_already_set'
# what(): TypeError: bases include diverging derived types:
# base=pybind11_tests.python_multiple_inheritance.CppBase1,
# derived1=pybind11_tests.python_multiple_inheritance.CppDrvd1,
# derived2=pybind11_tests.python_multiple_inheritance.CppDrvdB1
PPPCCC1(11)


@pytest.mark.parametrize(
("pcd_type", "cppdrvdb"), [(PCD0, "CppDrvdB0"), (PCD1, "CppDrvdB1")]
)
def test_PCD(pcd_type, cppdrvdb):
if m.if_defined_PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS:
pytest.skip(
"PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS is defined"
)
# This escapes all_type_info_check_for_divergence() because CppBase does not appear in bases.
with pytest.raises(
TypeError,
match=cppdrvdb + r"\.__init__\(\) must be called when overriding __init__$",
):
pcd_type(11)


@pytest.mark.parametrize(
("pcdi_type", "pass_fn"), [(PCDI0, m.pass_CppBase0), (PCDI1, m.pass_CppBase1)]
)
def test_PCDI(pcdi_type, pass_fn):
obj = pcdi_type()
with pytest.raises(TypeError, match="^bases include diverging derived types: "):
pass_fn(obj)
Loading