Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental: Guard against Python multiple inheritance divergence #4928

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,40 @@ inline void all_type_info_add_base_most_derived_first(std::vector<type_info *> &
bases.push_back(addl_base);
}

inline void all_type_info_check_for_divergence(const std::vector<type_info *> &bases) {
using sz_t = std::size_t;
sz_t n = bases.size();
if (n < 3) {
return;
}
std::vector<sz_t> cluster_ids;
cluster_ids.reserve(n);
for (sz_t ci = 0; ci < n; ci++) {
cluster_ids.push_back(ci);
}
for (sz_t i = 0; i < n - 1; i++) {
if (cluster_ids[i] != i) {
continue;
}
for (sz_t j = i + 1; j < n; j++) {
if (PyType_IsSubtype(bases[i]->type, bases[j]->type) != 0) {
sz_t k = cluster_ids[j];
if (k == j) {
cluster_ids[j] = i;
} else {
PyErr_Format(
PyExc_TypeError,
"bases include diverging derived types: base=%s, derived1=%s, derived2=%s",
bases[j]->type->tp_name,
bases[k]->type->tp_name,
bases[i]->type->tp_name);
throw error_already_set();
}
}
}
}
}

// Populates a just-created cache entry.
PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_info *> &bases) {
assert(bases.empty());
Expand Down Expand Up @@ -168,6 +202,7 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
}
}
}
all_type_info_check_for_divergence(bases);
}

/**
Expand Down Expand Up @@ -750,14 +785,22 @@ class type_caster_generic {
// if we can find an exact match (or, for a simple C++ type, an inherited match); if
// so, we can safely reinterpret_cast to the relevant pointer.
if (bases.size() > 1) {
std::vector<type_info *> matching_bases;
for (auto *base : bases) {
if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type)
: base->type == typeinfo->type) {
this_.load_value(
reinterpret_cast<instance *>(src.ptr())->get_value_and_holder(base));
return true;
matching_bases.push_back(base);
}
}
if (!matching_bases.empty()) {
if (matching_bases.size() > 1) {
matching_bases.push_back(const_cast<type_info *>(typeinfo));
all_type_info_check_for_divergence(matching_bases);
}
this_.load_value(reinterpret_cast<instance *>(src.ptr())->get_value_and_holder(
matching_bases[0]));
return true;
}
}

// Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type
Expand Down
21 changes: 21 additions & 0 deletions tests/test_python_multiple_inheritance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ struct CppDrvd : CppBase {
int drvd_value;
};

struct CppDrvd2 : CppBase {
explicit CppDrvd2(int value) : CppBase(value), drvd2_value(value * 5) {}
int get_drvd2_value() const { return drvd2_value; }
void reset_drvd2_value(int new_value) { drvd2_value = new_value; }

int get_base_value_from_drvd2() const { return get_base_value(); }
void reset_base_value_from_drvd2(int new_value) { reset_base_value(new_value); }

private:
int drvd2_value;
};

} // namespace test_python_multiple_inheritance

TEST_SUBMODULE(python_multiple_inheritance, m) {
Expand All @@ -42,4 +54,13 @@ TEST_SUBMODULE(python_multiple_inheritance, m) {
.def("reset_drvd_value", &CppDrvd::reset_drvd_value)
.def("get_base_value_from_drvd", &CppDrvd::get_base_value_from_drvd)
.def("reset_base_value_from_drvd", &CppDrvd::reset_base_value_from_drvd);

py::class_<CppDrvd2, CppBase>(m, "CppDrvd2")
.def(py::init<int>())
.def("get_drvd2_value", &CppDrvd2::get_drvd2_value)
.def("reset_drvd2_value", &CppDrvd2::reset_drvd2_value)
.def("get_base_value_from_drvd2", &CppDrvd2::get_base_value_from_drvd2)
.def("reset_base_value_from_drvd2", &CppDrvd2::reset_base_value_from_drvd2);

m.def("pass_CppBase", [](const CppBase *) {});
}
48 changes: 48 additions & 0 deletions tests/test_python_multiple_inheritance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Adapted from:
# https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python/python_multiple_inheritance_test.py

import pytest

from pybind11_tests import python_multiple_inheritance as m


Expand All @@ -12,6 +14,28 @@ class PPCC(PC, m.CppDrvd):
pass


class PPPCCC(PPCC, m.CppDrvd2):
pass


class PC1(m.CppDrvd):
pass


class PC2(m.CppDrvd2):
pass


class PCD(PC1, PC2):
pass


class PCDI(PC1, PC2):
def __init__(self):
PC1.__init__(self, 11)
PC2.__init__(self, 12)


def test_PC():
d = PC(11)
assert d.get_base_value() == 11
Expand All @@ -33,3 +57,27 @@ def test_PPCC():
d.reset_base_value_from_drvd(30)
assert d.get_base_value() == 30
assert d.get_base_value_from_drvd() == 30


def NOtest_PPPCCC():
# terminate called after throwing an instance of 'pybind11::error_already_set'
# what(): TypeError: bases include diverging derived types:
# base=pybind11_tests.python_multiple_inheritance.CppBase,
# derived1=pybind11_tests.python_multiple_inheritance.CppDrvd,
# derived2=pybind11_tests.python_multiple_inheritance.CppDrvd2
PPPCCC(11)


def test_PCD():
# This escapes all_type_info_check_for_divergence() because CppBase does not appear in bases.
with pytest.raises(
TypeError,
match=r"CppDrvd2\.__init__\(\) must be called when overriding __init__$",
):
PCD(11)


def test_PCDI():
obj = PCDI()
with pytest.raises(TypeError, match="^bases include diverging derived types: "):
m.pass_CppBase(obj)