Skip to content

Commit

Permalink
Added return_descr/arg_descr for correct typing in typing::Callable
Browse files Browse the repository at this point in the history
  • Loading branch information
timohl committed Jan 5, 2025
1 parent 9877aca commit 077d0b6
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 17 deletions.
10 changes: 10 additions & 0 deletions include/pybind11/detail/descr.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,15 @@ constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
return const_name("{") + descr + const_name("}");
}

template <size_t N, typename... Ts>
constexpr descr<N + 4, Ts...> arg_descr(const descr<N, Ts...> &descr) {
return const_name("@^") + descr + const_name("@^");
}

template <size_t N, typename... Ts>
constexpr descr<N + 4, Ts...> return_descr(const descr<N, Ts...> &descr) {
return const_name("@$") + descr + const_name("@$");
}

PYBIND11_NAMESPACE_END(detail)
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
21 changes: 18 additions & 3 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,11 @@ class cpp_function : public function {
std::string signature;
size_t type_index = 0, arg_index = 0;
bool is_starred = false;
// `is_return_value` is true if we are currently inside the return type of the signature.
// The same is true for `use_return_value`, except for forced usage of arg/return type
// using @^/@$.
bool is_return_value = false;
bool use_return_value = false;
for (const auto *pc = text; *pc != '\0'; ++pc) {
const auto c = *pc;

Expand Down Expand Up @@ -495,11 +499,21 @@ class cpp_function : public function {
signature += detail::quote_cpp_type_name(detail::clean_type_id(t->name()));
}
} else if (c == '@') {
// `@^ ... @^` and `@$ ... @$` are used to force arg/return value type (see
// typing::Callable/detail::arg_descr/detail::return_descr)
if ((*(pc + 1) == '^' && is_return_value)
|| (*(pc + 1) == '$' && !is_return_value)) {
use_return_value = !use_return_value;
}
if (*(pc + 1) == '^' || *(pc + 1) == '$') {
++pc;
continue;
}
// Handle types that differ depending on whether they appear
// in an argument or a return value position
// For named arguments (py::arg()) with noconvert set, use return value type
// in an argument or a return value position (see io_name<text1, text2>).
// For named arguments (py::arg()) with noconvert set, return value type is used.
++pc;
if (!is_return_value
if (!use_return_value
&& !(arg_index < rec->args.size() && !rec->args[arg_index].convert)) {
while (*pc && *pc != '@')
signature += *pc++;
Expand All @@ -518,6 +532,7 @@ class cpp_function : public function {
} else {
if (c == '-' && *(pc + 1) == '>') {
is_return_value = true;
use_return_value = true;
}
signature += c;
}
Expand Down
11 changes: 7 additions & 4 deletions include/pybind11/typing.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,19 @@ template <typename Return, typename... Args>
struct handle_type_name<typing::Callable<Return(Args...)>> {
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
static constexpr auto name
= const_name("Callable[[") + ::pybind11::detail::concat(make_caster<Args>::name...)
+ const_name("], ") + make_caster<retval_type>::name + const_name("]");
= const_name("Callable[[")
+ ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster<Args>::name)...)
+ const_name("], ") + ::pybind11::detail::return_descr(make_caster<retval_type>::name)
+ const_name("]");
};

template <typename Return>
struct handle_type_name<typing::Callable<Return(ellipsis)>> {
// PEP 484 specifies this syntax for defining only return types of callables
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
static constexpr auto name
= const_name("Callable[..., ") + make_caster<retval_type>::name + const_name("]");
static constexpr auto name = const_name("Callable[..., ")
+ ::pybind11::detail::return_descr(make_caster<retval_type>::name)
+ const_name("]");
};

template <typename T>
Expand Down
6 changes: 6 additions & 0 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,12 @@ TEST_SUBMODULE(pytypes, m) {
m.def("identity_iterable", [](const py::typing::Iterable<RealNumber> &x) { return x; });
// Iterator<T>
m.def("identity_iterator", [](const py::typing::Iterator<RealNumber> &x) { return x; });
// Callable<R(A)> identity
m.def("identity_callable",
[](const py::typing::Callable<RealNumber(const RealNumber &)> &x) { return x; });
// Callable<R(...)> identity
m.def("identity_callable_ellipsis",
[](const py::typing::Callable<RealNumber(py::ellipsis)> &x) { return x; });
// Callable<R(A)>
m.def("apply_callable",
[](const RealNumber &x, const py::typing::Callable<RealNumber(const RealNumber &)> &f) {
Expand Down
28 changes: 18 additions & 10 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,18 +1252,26 @@ def test_arg_return_type_hints(doc):
doc(m.identity_iterator)
== "identity_iterator(arg0: Iterator[Union[float, int]]) -> Iterator[float]"
)
# Callable<R(A)> identity
assert (
doc(m.identity_callable)
== "identity_callable(arg0: Callable[[Union[float, int]], float]) -> Callable[[Union[float, int]], float]"
)
# Callable<R(...)> identity
assert (
doc(m.identity_callable_ellipsis)
== "identity_callable_ellipsis(arg0: Callable[..., float]) -> Callable[..., float]"
)
# Callable<R(A)>
# TODO: Needs support for arg/return environments
# assert (
# doc(m.apply_callable)
# == "apply_callable(arg0: Union[float, int], arg1: Callable[[Union[float, int]], float]) -> float"
# )
assert (
doc(m.apply_callable)
== "apply_callable(arg0: Union[float, int], arg1: Callable[[Union[float, int]], float]) -> float"
)
# Callable<R(...)>
# TODO: Needs support for arg/return environments
# assert (
# doc(m.apply_callable_ellipsis)
# == "apply_callable_ellipsis(arg0: Union[float, int], arg1: Callable[..., float]) -> float"
# )
assert (
doc(m.apply_callable_ellipsis)
== "apply_callable_ellipsis(arg0: Union[float, int], arg1: Callable[..., float]) -> float"
)
# Union<T1, T2>
assert (
doc(m.identity_union)
Expand Down

0 comments on commit 077d0b6

Please sign in to comment.