From 077d0b6507f310bf9ec0c83c8066ef937e157181 Mon Sep 17 00:00:00 2001 From: Tim Ohliger Date: Sun, 5 Jan 2025 23:47:18 +0100 Subject: [PATCH] Added return_descr/arg_descr for correct typing in typing::Callable --- include/pybind11/detail/descr.h | 10 ++++++++++ include/pybind11/pybind11.h | 21 ++++++++++++++++++--- include/pybind11/typing.h | 11 +++++++---- tests/test_pytypes.cpp | 6 ++++++ tests/test_pytypes.py | 28 ++++++++++++++++++---------- 5 files changed, 59 insertions(+), 17 deletions(-) diff --git a/include/pybind11/detail/descr.h b/include/pybind11/detail/descr.h index e8bee4a191..f45822079f 100644 --- a/include/pybind11/detail/descr.h +++ b/include/pybind11/detail/descr.h @@ -174,5 +174,15 @@ constexpr descr type_descr(const descr &descr) { return const_name("{") + descr + const_name("}"); } +template +constexpr descr arg_descr(const descr &descr) { + return const_name("@^") + descr + const_name("@^"); +} + +template +constexpr descr return_descr(const descr &descr) { + return const_name("@$") + descr + const_name("@$"); +} + PYBIND11_NAMESPACE_END(detail) PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index b9ec5698f4..9705cf3c93 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -440,7 +440,11 @@ class cpp_function : public function { std::string signature; size_t type_index = 0, arg_index = 0; bool is_starred = false; + // `is_return_value` is true if we are currently inside the return type of the signature. + // The same is true for `use_return_value`, except for forced usage of arg/return type + // using @^/@$. bool is_return_value = false; + bool use_return_value = false; for (const auto *pc = text; *pc != '\0'; ++pc) { const auto c = *pc; @@ -495,11 +499,21 @@ class cpp_function : public function { signature += detail::quote_cpp_type_name(detail::clean_type_id(t->name())); } } else if (c == '@') { + // `@^ ... @^` and `@$ ... @$` are used to force arg/return value type (see + // typing::Callable/detail::arg_descr/detail::return_descr) + if ((*(pc + 1) == '^' && is_return_value) + || (*(pc + 1) == '$' && !is_return_value)) { + use_return_value = !use_return_value; + } + if (*(pc + 1) == '^' || *(pc + 1) == '$') { + ++pc; + continue; + } // Handle types that differ depending on whether they appear - // in an argument or a return value position - // For named arguments (py::arg()) with noconvert set, use return value type + // in an argument or a return value position (see io_name). + // For named arguments (py::arg()) with noconvert set, return value type is used. ++pc; - if (!is_return_value + if (!use_return_value && !(arg_index < rec->args.size() && !rec->args[arg_index].convert)) { while (*pc && *pc != '@') signature += *pc++; @@ -518,6 +532,7 @@ class cpp_function : public function { } else { if (c == '-' && *(pc + 1) == '>') { is_return_value = true; + use_return_value = true; } signature += c; } diff --git a/include/pybind11/typing.h b/include/pybind11/typing.h index 9693addb33..fac145bfc5 100644 --- a/include/pybind11/typing.h +++ b/include/pybind11/typing.h @@ -188,16 +188,19 @@ template struct handle_type_name> { using retval_type = conditional_t::value, void_type, Return>; static constexpr auto name - = const_name("Callable[[") + ::pybind11::detail::concat(make_caster::name...) - + const_name("], ") + make_caster::name + const_name("]"); + = const_name("Callable[[") + + ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster::name)...) + + const_name("], ") + ::pybind11::detail::return_descr(make_caster::name) + + const_name("]"); }; template struct handle_type_name> { // PEP 484 specifies this syntax for defining only return types of callables using retval_type = conditional_t::value, void_type, Return>; - static constexpr auto name - = const_name("Callable[..., ") + make_caster::name + const_name("]"); + static constexpr auto name = const_name("Callable[..., ") + + ::pybind11::detail::return_descr(make_caster::name) + + const_name("]"); }; template diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index c212ea804c..0221b59878 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -1139,6 +1139,12 @@ TEST_SUBMODULE(pytypes, m) { m.def("identity_iterable", [](const py::typing::Iterable &x) { return x; }); // Iterator m.def("identity_iterator", [](const py::typing::Iterator &x) { return x; }); + // Callable identity + m.def("identity_callable", + [](const py::typing::Callable &x) { return x; }); + // Callable identity + m.def("identity_callable_ellipsis", + [](const py::typing::Callable &x) { return x; }); // Callable m.def("apply_callable", [](const RealNumber &x, const py::typing::Callable &f) { diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 1d999dfdc8..f33b12d688 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -1252,18 +1252,26 @@ def test_arg_return_type_hints(doc): doc(m.identity_iterator) == "identity_iterator(arg0: Iterator[Union[float, int]]) -> Iterator[float]" ) + # Callable identity + assert ( + doc(m.identity_callable) + == "identity_callable(arg0: Callable[[Union[float, int]], float]) -> Callable[[Union[float, int]], float]" + ) + # Callable identity + assert ( + doc(m.identity_callable_ellipsis) + == "identity_callable_ellipsis(arg0: Callable[..., float]) -> Callable[..., float]" + ) # Callable - # TODO: Needs support for arg/return environments - # assert ( - # doc(m.apply_callable) - # == "apply_callable(arg0: Union[float, int], arg1: Callable[[Union[float, int]], float]) -> float" - # ) + assert ( + doc(m.apply_callable) + == "apply_callable(arg0: Union[float, int], arg1: Callable[[Union[float, int]], float]) -> float" + ) # Callable - # TODO: Needs support for arg/return environments - # assert ( - # doc(m.apply_callable_ellipsis) - # == "apply_callable_ellipsis(arg0: Union[float, int], arg1: Callable[..., float]) -> float" - # ) + assert ( + doc(m.apply_callable_ellipsis) + == "apply_callable_ellipsis(arg0: Union[float, int], arg1: Callable[..., float]) -> float" + ) # Union assert ( doc(m.identity_union)