Skip to content

Commit f15fb97

Browse files
committed
Type hint make_tuple / fix *args/**kwargs return type
Signed-off-by: Michael Carlstrom <[email protected]>
1 parent e6984c8 commit f15fb97

File tree

5 files changed

+45
-30
lines changed

5 files changed

+45
-30
lines changed

include/pybind11/cast.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,21 +1465,24 @@ template <>
14651465
struct handle_type_name<weakref> {
14661466
static constexpr auto name = const_name("weakref.ReferenceType");
14671467
};
1468+
// args/Args/kwargs/KWArgs have name as well as typehint included
14681469
template <>
14691470
struct handle_type_name<args> {
1470-
static constexpr auto name = const_name("*args");
1471+
static constexpr auto name = io_name("*args", "tuple");
14711472
};
14721473
template <typename T>
14731474
struct handle_type_name<Args<T>> {
1474-
static constexpr auto name = const_name("*args: ") + make_caster<T>::name;
1475+
static constexpr auto name
1476+
= io_name("*args: ", "tuple[") + make_caster<T>::name + io_name("", ", ...]");
14751477
};
14761478
template <>
14771479
struct handle_type_name<kwargs> {
1478-
static constexpr auto name = const_name("**kwargs");
1480+
static constexpr auto name = io_name("**kwargs", "dict[str, typing.Any]");
14791481
};
14801482
template <typename T>
14811483
struct handle_type_name<KWArgs<T>> {
1482-
static constexpr auto name = const_name("**kwargs: ") + make_caster<T>::name;
1484+
static constexpr auto name
1485+
= io_name("**kwargs: ", "dict[str, ") + make_caster<T>::name + io_name("", "]");
14831486
};
14841487
template <>
14851488
struct handle_type_name<obj_attr_accessor> {
@@ -1905,13 +1908,20 @@ inline cast_error cast_error_unable_to_convert_call_arg(const std::string &name,
19051908
}
19061909
#endif
19071910

1911+
namespace typing {
1912+
template <typename... Types>
1913+
class Tuple : public tuple {
1914+
using tuple::tuple;
1915+
};
1916+
} // namespace typing
1917+
19081918
template <return_value_policy policy = return_value_policy::automatic_reference>
1909-
tuple make_tuple() {
1919+
typing::Tuple<> make_tuple() {
19101920
return tuple(0);
19111921
}
19121922

19131923
template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
1914-
tuple make_tuple(Args &&...args_) {
1924+
typing::Tuple<Args...> make_tuple(Args &&...args_) {
19151925
constexpr size_t size = sizeof...(Args);
19161926
std::array<object, size> args{{reinterpret_steal<object>(
19171927
detail::make_caster<Args>::cast(std::forward<Args>(args_), policy, nullptr))...}};

include/pybind11/detail/init.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,15 @@ template <typename Get,
501501
typename NewInstance,
502502
typename ArgState>
503503
struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
504-
static_assert(std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
505-
"The type returned by `__getstate__` must be the same "
506-
"as the argument accepted by `__setstate__`");
504+
using Ret = intrinsic_t<RetState>;
505+
using Arg = intrinsic_t<ArgState>;
506+
507+
// Subclasses are now allowed for support between type hint and generic versions of types
508+
// (e.g.) typing::List <--> list
509+
static_assert(std::is_same<Ret, Arg>::value || std::is_base_of<Ret, Arg>::value
510+
|| std::is_base_of<Arg, Ret>::value,
511+
"The type returned by `__getstate__` must be the same or subclass of the "
512+
"argument accepted by `__setstate__`");
507513

508514
remove_reference_t<Get> get;
509515
remove_reference_t<Set> set;

include/pybind11/pybind11.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel
120120
const auto c = *pc;
121121
if (c == '{') {
122122
// Write arg name for everything except *args and **kwargs.
123-
is_starred = *(pc + 1) == '*';
123+
// Detect {@*args...} or {@**kwargs...}
124+
is_starred = *(pc + 1) == '@' && *(pc + 2) == '*';
124125
if (is_starred) {
125126
continue;
126127
}
@@ -155,7 +156,7 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel
155156
} else if (c == '%') {
156157
const std::type_info *t = types[type_index++];
157158
if (!t) {
158-
pybind11_fail("Internal error while parsing type signature (1)");
159+
// pybind11_fail("Internal error while parsing type signature (1)");
159160
}
160161
if (auto *tinfo = detail::get_type_info(*t)) {
161162
handle th((PyObject *) tinfo->type);

include/pybind11/typing.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ PYBIND11_NAMESPACE_BEGIN(typing)
3434
There is no additional enforcement of types at runtime.
3535
*/
3636

37-
template <typename... Types>
38-
class Tuple : public tuple {
39-
using tuple::tuple;
40-
};
37+
// Tuple type hint defined in cast.h for use in py::make_tuple to avoid circular includes
4138

4239
template <typename K, typename V>
4340
class Dict : public dict {

tests/test_kwargs_and_defaults.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,12 @@ def test_function_signatures(doc):
3434
)
3535
assert doc(m.args_function) == "args_function(*args) -> tuple"
3636
assert (
37-
doc(m.args_kwargs_function) == "args_kwargs_function(*args, **kwargs) -> tuple"
37+
doc(m.args_kwargs_function)
38+
== "args_kwargs_function(*args, **kwargs) -> tuple[tuple, dict[str, typing.Any]]"
3839
)
3940
assert (
4041
doc(m.args_kwargs_subclass_function)
41-
== "args_kwargs_subclass_function(*args: str, **kwargs: str) -> tuple"
42+
== "args_kwargs_subclass_function(*args: str, **kwargs: str) -> tuple[tuple[str, ...], dict[str, str]]"
4243
)
4344
assert (
4445
doc(m.KWClass.foo0)
@@ -138,7 +139,7 @@ def test_mixed_args_and_kwargs(msg):
138139
msg(excinfo.value)
139140
== """
140141
mixed_plus_args(): incompatible function arguments. The following argument types are supported:
141-
1. (arg0: typing.SupportsInt, arg1: typing.SupportsFloat, *args) -> tuple
142+
1. (arg0: typing.SupportsInt, arg1: typing.SupportsFloat, *args) -> tuple[int, float, tuple]
142143
143144
Invoked with: 1
144145
"""
@@ -149,7 +150,7 @@ def test_mixed_args_and_kwargs(msg):
149150
msg(excinfo.value)
150151
== """
151152
mixed_plus_args(): incompatible function arguments. The following argument types are supported:
152-
1. (arg0: typing.SupportsInt, arg1: typing.SupportsFloat, *args) -> tuple
153+
1. (arg0: typing.SupportsInt, arg1: typing.SupportsFloat, *args) -> tuple[int, float, tuple]
153154
154155
Invoked with:
155156
"""
@@ -183,7 +184,7 @@ def test_mixed_args_and_kwargs(msg):
183184
msg(excinfo.value)
184185
== """
185186
mixed_plus_args_kwargs_defaults(): incompatible function arguments. The following argument types are supported:
186-
1. (i: typing.SupportsInt = 1, j: typing.SupportsFloat = 3.14159, *args, **kwargs) -> tuple
187+
1. (i: typing.SupportsInt = 1, j: typing.SupportsFloat = 3.14159, *args, **kwargs) -> tuple[int, float, tuple, dict[str, typing.Any]]
187188
188189
Invoked with: 1; kwargs: i=1
189190
"""
@@ -194,7 +195,7 @@ def test_mixed_args_and_kwargs(msg):
194195
msg(excinfo.value)
195196
== """
196197
mixed_plus_args_kwargs_defaults(): incompatible function arguments. The following argument types are supported:
197-
1. (i: typing.SupportsInt = 1, j: typing.SupportsFloat = 3.14159, *args, **kwargs) -> tuple
198+
1. (i: typing.SupportsInt = 1, j: typing.SupportsFloat = 3.14159, *args, **kwargs) -> tuple[int, float, tuple, dict[str, typing.Any]]
198199
199200
Invoked with: 1, 2; kwargs: j=1
200201
"""
@@ -211,7 +212,7 @@ def test_mixed_args_and_kwargs(msg):
211212
msg(excinfo.value)
212213
== """
213214
args_kwonly(): incompatible function arguments. The following argument types are supported:
214-
1. (i: typing.SupportsInt, j: typing.SupportsFloat, *args, z: typing.SupportsInt) -> tuple
215+
1. (i: typing.SupportsInt, j: typing.SupportsFloat, *args, z: typing.SupportsInt) -> tuple[int, float, tuple, int]
215216
216217
Invoked with: 2, 2.5, 22
217218
"""
@@ -233,12 +234,12 @@ def test_mixed_args_and_kwargs(msg):
233234
)
234235
assert (
235236
m.args_kwonly_kwargs.__doc__
236-
== "args_kwonly_kwargs(i: typing.SupportsInt, j: typing.SupportsFloat, *args, z: typing.SupportsInt, **kwargs) -> tuple\n"
237+
== "args_kwonly_kwargs(i: typing.SupportsInt, j: typing.SupportsFloat, *args, z: typing.SupportsInt, **kwargs) -> tuple[int, float, tuple, int, dict[str, typing.Any]]\n"
237238
)
238239

239240
assert (
240241
m.args_kwonly_kwargs_defaults.__doc__
241-
== "args_kwonly_kwargs_defaults(i: typing.SupportsInt = 1, j: typing.SupportsFloat = 3.14159, *args, z: typing.SupportsInt = 42, **kwargs) -> tuple\n"
242+
== "args_kwonly_kwargs_defaults(i: typing.SupportsInt = 1, j: typing.SupportsFloat = 3.14159, *args, z: typing.SupportsInt = 42, **kwargs) -> tuple[int, float, tuple, int, dict[str, typing.Any]]\n"
242243
)
243244
assert m.args_kwonly_kwargs_defaults() == (1, 3.14159, (), 42, {})
244245
assert m.args_kwonly_kwargs_defaults(2) == (2, 3.14159, (), 42, {})
@@ -344,7 +345,7 @@ def test_positional_only_args():
344345
# Mix it with args and kwargs:
345346
assert (
346347
m.args_kwonly_full_monty.__doc__
347-
== "args_kwonly_full_monty(arg0: typing.SupportsInt = 1, arg1: typing.SupportsInt = 2, /, j: typing.SupportsFloat = 3.14159, *args, z: typing.SupportsInt = 42, **kwargs) -> tuple\n"
348+
== "args_kwonly_full_monty(arg0: typing.SupportsInt = 1, arg1: typing.SupportsInt = 2, /, j: typing.SupportsFloat = 3.14159, *args, z: typing.SupportsInt = 42, **kwargs) -> tuple[int, int, float, tuple, int, dict[str, typing.Any]]\n"
348349
)
349350
assert m.args_kwonly_full_monty() == (1, 2, 3.14159, (), 42, {})
350351
assert m.args_kwonly_full_monty(8) == (8, 2, 3.14159, (), 42, {})
@@ -394,23 +395,23 @@ def test_positional_only_args():
394395
def test_signatures():
395396
assert (
396397
m.kw_only_all.__doc__
397-
== "kw_only_all(*, i: typing.SupportsInt, j: typing.SupportsInt) -> tuple\n"
398+
== "kw_only_all(*, i: typing.SupportsInt, j: typing.SupportsInt) -> tuple[int, int]\n"
398399
)
399400
assert (
400401
m.kw_only_mixed.__doc__
401-
== "kw_only_mixed(i: typing.SupportsInt, *, j: typing.SupportsInt) -> tuple\n"
402+
== "kw_only_mixed(i: typing.SupportsInt, *, j: typing.SupportsInt) -> tuple[int, int]\n"
402403
)
403404
assert (
404405
m.pos_only_all.__doc__
405-
== "pos_only_all(i: typing.SupportsInt, j: typing.SupportsInt, /) -> tuple\n"
406+
== "pos_only_all(i: typing.SupportsInt, j: typing.SupportsInt, /) -> tuple[int, int]\n"
406407
)
407408
assert (
408409
m.pos_only_mix.__doc__
409-
== "pos_only_mix(i: typing.SupportsInt, /, j: typing.SupportsInt) -> tuple\n"
410+
== "pos_only_mix(i: typing.SupportsInt, /, j: typing.SupportsInt) -> tuple[int, int]\n"
410411
)
411412
assert (
412413
m.pos_kw_only_mix.__doc__
413-
== "pos_kw_only_mix(i: typing.SupportsInt, /, j: typing.SupportsInt, *, k: typing.SupportsInt) -> tuple\n"
414+
== "pos_kw_only_mix(i: typing.SupportsInt, /, j: typing.SupportsInt, *, k: typing.SupportsInt) -> tuple[int, int, int]\n"
414415
)
415416

416417

0 commit comments

Comments
 (0)