Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer ParamSpec constraint from arguments #15896

Merged
merged 7 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,7 +1988,7 @@ def infer_function_type_arguments(
)

arg_pass_nums = self.get_arg_infer_passes(
callee_type.arg_types, formal_to_actual, len(args)
callee_type, args, arg_types, formal_to_actual, len(args)
)

pass1_args: list[Type | None] = []
Expand All @@ -2002,6 +2002,7 @@ def infer_function_type_arguments(
callee_type,
pass1_args,
arg_kinds,
arg_names,
formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function(),
Expand Down Expand Up @@ -2062,6 +2063,7 @@ def infer_function_type_arguments(
callee_type,
arg_types,
arg_kinds,
arg_names,
formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function(),
Expand Down Expand Up @@ -2141,6 +2143,7 @@ def infer_function_type_arguments_pass2(
callee_type,
arg_types,
arg_kinds,
arg_names,
formal_to_actual,
context=self.argument_infer_context(),
)
Expand All @@ -2153,7 +2156,12 @@ def argument_infer_context(self) -> ArgumentInferContext:
)

def get_arg_infer_passes(
self, arg_types: list[Type], formal_to_actual: list[list[int]], num_actuals: int
self,
callee: CallableType,
args: list[Expression],
arg_types: list[Type],
formal_to_actual: list[list[int]],
num_actuals: int,
) -> list[int]:
"""Return pass numbers for args for two-pass argument type inference.

Expand All @@ -2164,8 +2172,28 @@ def get_arg_infer_passes(
lambdas more effectively.
"""
res = [1] * num_actuals
for i, arg in enumerate(arg_types):
if arg.accept(ArgInferSecondPassQuery()):
for i, arg in enumerate(callee.arg_types):
skip_param_spec = False
p_formal = get_proper_type(callee.arg_types[i])
if isinstance(p_formal, CallableType) and p_formal.param_spec():
for j in formal_to_actual[i]:
p_actual = get_proper_type(arg_types[j])
# This is an exception from the usual logic where we put generic Callable
# arguments in the second pass. If we have a non-generic actual, it is
# likely to infer good constraints, for example if we have:
# def run(Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ...
# def test(x: int, y: int) -> int: ...
# run(test, 1, 2)
# we will use `test` for inference, since it will allow to infer also
# argument *names* for P <: [x: int, y: int].
if (
isinstance(p_actual, CallableType)
and not p_actual.variables
and not isinstance(args[j], LambdaExpr)
):
skip_param_spec = True
break
if not skip_param_spec and arg.accept(ArgInferSecondPassQuery()):
for j in formal_to_actual[i]:
res[j] = 2
return res
Expand Down Expand Up @@ -4897,7 +4925,9 @@ def infer_lambda_type_using_context(
self.chk.fail(message_registry.CANNOT_INFER_LAMBDA_TYPE, e)
return None, None

return callable_ctx, callable_ctx
# Type of lambda must have correct argument names, to prevent false
# negatives when lambdas appear in `ParamSpec` context.
return callable_ctx.copy_modified(arg_names=e.arg_names), callable_ctx

def visit_super_expr(self, e: SuperExpr) -> Type:
"""Type check a super expression (non-lvalue)."""
Expand Down Expand Up @@ -5915,6 +5945,7 @@ def __init__(self) -> None:
super().__init__(types.ANY_STRATEGY)

def visit_callable_type(self, t: CallableType) -> bool:
# TODO: we need to check only for type variables of original callable.
return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery())


Expand Down
138 changes: 98 additions & 40 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def infer_constraints_for_callable(
callee: CallableType,
arg_types: Sequence[Type | None],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
formal_to_actual: list[list[int]],
context: ArgumentInferContext,
) -> list[Constraint]:
Expand All @@ -117,6 +118,20 @@ def infer_constraints_for_callable(
constraints: list[Constraint] = []
mapper = ArgTypeExpander(context)

param_spec = callee.param_spec()
param_spec_arg_types = []
param_spec_arg_names = []
param_spec_arg_kinds = []

incomplete_star_mapping = False
for i, actuals in enumerate(formal_to_actual):
for actual in actuals:
if actual is None and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2):
# We can't use arguments to infer ParamSpec constraint, if only some
# are present in the current inference pass.
incomplete_star_mapping = True
break

for i, actuals in enumerate(formal_to_actual):
if isinstance(callee.arg_types[i], UnpackType):
unpack_type = callee.arg_types[i]
Expand Down Expand Up @@ -176,11 +191,47 @@ def infer_constraints_for_callable(
actual_type = mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
# TODO: if callee has ParamSpec, we need to collect all actuals that map to star
# args and create single constraint between P and resulting Parameters instead.
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)

if (
param_spec
and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2)
and not incomplete_star_mapping
):
# If actual arguments are mapped to ParamSpec type, we can't infer individual
# constraints, instead store them and infer single constraint at the end.
# It is impossible to map actual kind to formal kind, so use some heuristic.
# This inference is used as a fallback, so relying on heuristic should be OK.
param_spec_arg_types.append(
mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
)
)
actual_kind = arg_kinds[actual]
param_spec_arg_kinds.append(
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
)
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
else:
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)
if (
param_spec
and not any(c.type_var == param_spec.id for c in constraints)
and not incomplete_star_mapping
):
# Use ParamSpec constraint from arguments only if there are no other constraints,
# since as explained above it is quite ad-hoc.
constraints.append(
Constraint(
param_spec,
SUPERTYPE_OF,
Parameters(
arg_types=param_spec_arg_types,
arg_kinds=param_spec_arg_kinds,
arg_names=param_spec_arg_names,
imprecise_arg_kinds=True,
),
)
)
return constraints


Expand Down Expand Up @@ -923,6 +974,14 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
res: list[Constraint] = []
cactual = self.actual.with_unpacked_kwargs()
param_spec = template.param_spec()

template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
if template.type_guard is not None:
template_ret_type = template.type_guard
if cactual.type_guard is not None:
cactual_ret_type = cactual.type_guard
res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))

if param_spec is None:
# TODO: Erase template variables if it is generic?
if (
Expand Down Expand Up @@ -994,34 +1053,6 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
)
extra_tvars = True

if not cactual_ps:
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
prefix_len = min(prefix_len, max_prefix_len)
res.append(
Constraint(
param_spec,
neg_op(self.direction),
Parameters(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
variables=cactual.variables
if not type_state.infer_polymorphic
else [],
),
)
)
else:
if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types):
cactual_ps = cactual_ps.copy_modified(
prefix=Parameters(
arg_types=cactual_ps.prefix.arg_types[prefix_len:],
arg_kinds=cactual_ps.prefix.arg_kinds[prefix_len:],
arg_names=cactual_ps.prefix.arg_names[prefix_len:],
)
)
res.append(Constraint(param_spec, neg_op(self.direction), cactual_ps))

# Compare prefixes as well
cactual_prefix = cactual.copy_modified(
arg_types=cactual.arg_types[:prefix_len],
Expand All @@ -1034,13 +1065,40 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
continue
res.extend(infer_constraints(t, a, neg_op(self.direction)))

template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
if template.type_guard is not None:
template_ret_type = template.type_guard
if cactual.type_guard is not None:
cactual_ret_type = cactual.type_guard

res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))
param_spec_target: Type | None = None
skip_imprecise = (
any(c.type_var == param_spec.id for c in res) and cactual.imprecise_arg_kinds
)
if not cactual_ps:
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
prefix_len = min(prefix_len, max_prefix_len)
# This logic matches top-level callable constraint exception, if we managed
# to get other constraints for ParamSpec, don't infer one with imprecise kinds
if not skip_imprecise:
param_spec_target = Parameters(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
variables=cactual.variables
if not type_state.infer_polymorphic
else [],
imprecise_arg_kinds=cactual.imprecise_arg_kinds,
)
else:
if (
len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types)
and not skip_imprecise
):
param_spec_target = cactual_ps.copy_modified(
prefix=Parameters(
arg_types=cactual_ps.prefix.arg_types[prefix_len:],
arg_kinds=cactual_ps.prefix.arg_kinds[prefix_len:],
arg_names=cactual_ps.prefix.arg_names[prefix_len:],
imprecise_arg_kinds=cactual_ps.prefix.imprecise_arg_kinds,
)
)
if param_spec_target is not None:
res.append(Constraint(param_spec, neg_op(self.direction), param_spec_target))
if extra_tvars:
for c in res:
c.extra_tvars += cactual.variables
Expand Down
2 changes: 2 additions & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
arg_types=self.expand_types(t.arg_types),
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
)
elif isinstance(repl, ParamSpecType):
# We're substituting one ParamSpec for another; this can mean that the prefix
Expand All @@ -424,6 +425,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:],
ret_type=t.ret_type.accept(self),
from_concatenate=t.from_concatenate or bool(repl.prefix.arg_types),
imprecise_arg_kinds=(t.imprecise_arg_kinds or prefix.imprecise_arg_kinds),
)

var_arg = t.var_arg()
Expand Down
3 changes: 2 additions & 1 deletion mypy/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def infer_function_type_arguments(
callee_type: CallableType,
arg_types: Sequence[Type | None],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
formal_to_actual: list[list[int]],
context: ArgumentInferContext,
strict: bool = True,
Expand All @@ -53,7 +54,7 @@ def infer_function_type_arguments(
"""
# Infer constraints.
constraints = infer_constraints_for_callable(
callee_type, arg_types, arg_kinds, formal_to_actual, context
callee_type, arg_types, arg_kinds, arg_names, formal_to_actual, context
)

# Solve constraints.
Expand Down
Loading