diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 4462cd755ef9f..ab0319c0163ca 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,4 +1,4 @@ -use infer::TypeInferenceBuilder; +use infer::TypeInferenceContext; use ruff_db::files::File; use ruff_python_ast as ast; @@ -482,14 +482,14 @@ impl<'db> Type<'db> { /// /// Returns `None` if `self` is not a callable type. #[must_use] - pub fn call(&self, db: &'db dyn Db) -> Option> { + fn call(&self, db: &'db dyn Db, _context: &mut TypeInferenceContext<'db>) -> Option> { match self { Type::Function(function_type) => Some(function_type.return_type(db)), // TODO annotated return type on `__new__` or metaclass `__call__` Type::Class(class) => Some(Type::Instance(*class)), - // TODO: handle classes which implement the Callable protocol + // TODO: handle classes which implement `__call__` Type::Instance(_instance_ty) => Some(Type::Unknown), // `Any` is callable, and its return type is also `Any`. @@ -497,7 +497,7 @@ impl<'db> Type<'db> { Type::Unknown => Some(Type::Unknown), - // TODO: union and intersection types, if they reduce to `Callable` + // TODO: union and intersection types Type::Union(_) => Some(Type::Unknown), Type::Intersection(_) => Some(Type::Unknown), @@ -513,11 +513,14 @@ impl<'db> Type<'db> { /// for y in x: /// pass /// ``` - fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> { + /// Return None and emit a diagnostic if this type is not iterable. + fn iterate( + &self, + db: &'db dyn Db, + context: &mut TypeInferenceContext<'db>, + ) -> Option> { if let Type::Tuple(tuple_type) = self { - return IterationOutcome::Iterable { - element_ty: UnionType::from_elements(db, &**tuple_type.elements(db)), - }; + return Some(UnionType::from_elements(db, &**tuple_type.elements(db))); } // `self` represents the type of the iterable; @@ -526,19 +529,16 @@ impl<'db> Type<'db> { let dunder_iter_method = iterable_meta_type.member(db, "__iter__"); if !dunder_iter_method.is_unbound() { - let Some(iterator_ty) = dunder_iter_method.call(db) else { - return IterationOutcome::NotIterable { - not_iterable_ty: *self, - }; + let Some(iterator_ty) = dunder_iter_method.call(db, context) else { + context.not_iterable_diagnostic(*self); + return None; }; let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__"); - return dunder_next_method - .call(db) - .map(|element_ty| IterationOutcome::Iterable { element_ty }) - .unwrap_or(IterationOutcome::NotIterable { - not_iterable_ty: *self, - }); + return dunder_next_method.call(db, context).or_else(|| { + context.not_iterable_diagnostic(*self); + None + }); } // Although it's not considered great practice, @@ -549,12 +549,10 @@ impl<'db> Type<'db> { // accepting `int` or `SupportsIndex` let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__"); - dunder_get_item_method - .call(db) - .map(|element_ty| IterationOutcome::Iterable { element_ty }) - .unwrap_or(IterationOutcome::NotIterable { - not_iterable_ty: *self, - }) + dunder_get_item_method.call(db, context).or_else(|| { + context.not_iterable_diagnostic(*self); + None + }) } #[must_use] @@ -619,28 +617,6 @@ impl<'db> From<&Type<'db>> for Type<'db> { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum IterationOutcome<'db> { - Iterable { element_ty: Type<'db> }, - NotIterable { not_iterable_ty: Type<'db> }, -} - -impl<'db> IterationOutcome<'db> { - fn unwrap_with_diagnostic( - self, - iterable_node: ast::AnyNodeRef, - inference_builder: &mut TypeInferenceBuilder<'db>, - ) -> Type<'db> { - match self { - Self::Iterable { element_ty } => element_ty, - Self::NotIterable { not_iterable_ty } => { - inference_builder.not_iterable_diagnostic(iterable_node, not_iterable_ty); - Type::Unknown - } - } - } -} - #[salsa::interned] pub struct FunctionType<'db> { /// name of the function at definition diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index a38e6cc6bd6b2..71b334c972443 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -34,7 +34,7 @@ use salsa::plumbing::AsId; use ruff_db::files::File; use ruff_db::parsed::parsed_module; -use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp}; +use ruff_python_ast::{self as ast, ExprContext, UnaryOp}; use ruff_text_size::Ranged; use crate::module_name::ModuleName; @@ -271,8 +271,7 @@ pub(super) struct TypeInferenceBuilder<'db> { file: File, scope: ScopeId<'db>, - /// The type inference results - types: TypeInference<'db>, + context: TypeInferenceContext<'db>, } impl<'db> TypeInferenceBuilder<'db> { @@ -296,6 +295,14 @@ impl<'db> TypeInferenceBuilder<'db> { InferenceRegion::Scope(scope) => (scope.file(db), scope), }; + let context = TypeInferenceContext { + db, + file, + scope, + node: None, + types: TypeInference::default(), + }; + Self { db, index, @@ -304,20 +311,10 @@ impl<'db> TypeInferenceBuilder<'db> { file, scope, - types: TypeInference::default(), + context, } } - fn extend(&mut self, inference: &TypeInference<'db>) { - self.types.bindings.extend(inference.bindings.iter()); - self.types - .declarations - .extend(inference.declarations.iter()); - self.types.expressions.extend(inference.expressions.iter()); - self.types.diagnostics.extend(&inference.diagnostics); - self.types.has_deferred |= inference.has_deferred; - } - /// Are we currently inferring types in a stub file? fn is_stub(&self) -> bool { self.file.is_stub(self.db.upcast()) @@ -368,21 +365,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - if self.types.has_deferred { - let mut deferred_expression_types: FxHashMap> = - FxHashMap::default(); - // invariant: only annotations and base classes are deferred, and both of these only - // occur within a declaration (annotated assignment, function or class definition) - for definition in self.types.declarations.keys() { - if infer_definition_types(self.db, *definition).has_deferred { - let deferred = infer_deferred_types(self.db, *definition); - deferred_expression_types.extend(deferred.expressions.iter()); - } - } - self.types - .expressions - .extend(deferred_expression_types.iter()); - } + self.context.infer_deferred_types(); } fn infer_region_definition(&mut self, definition: Definition<'db>) { @@ -471,38 +454,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(expression.node_ref(self.db)); } - fn invalid_assignment_diagnostic( - &mut self, - node: AnyNodeRef, - declared_ty: Type<'db>, - assigned_ty: Type<'db>, - ) { - match declared_ty { - Type::Class(class) => { - self.add_diagnostic(node, "invalid-assignment", format_args!( - "Implicit shadowing of class '{}'; annotate to make it explicit if this is intentional.", - class.name(self.db))); - } - Type::Function(function) => { - self.add_diagnostic(node, "invalid-assignment", format_args!( - "Implicit shadowing of function '{}'; annotate to make it explicit if this is intentional.", - function.name(self.db))); - } - _ => { - self.add_diagnostic( - node, - "invalid-assignment", - format_args!( - "Object of type '{}' is not assignable to '{}'.", - assigned_ty.display(self.db), - declared_ty.display(self.db), - ), - ); - } - } - } - - fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) { + fn add_binding(&mut self, binding: Definition<'db>, ty: Type<'db>) { debug_assert!(binding.is_binding(self.db)); let use_def = self.index.use_def_map(binding.file_scope(self.db)); let declarations = use_def.declarations_at_binding(binding); @@ -512,8 +464,7 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO point out the conflicting declarations in the diagnostic? let symbol_table = self.index.symbol_table(binding.file_scope(self.db)); let symbol_name = symbol_table.symbol(binding.symbol(self.db)).name(); - self.add_diagnostic( - node, + self.context.add_diagnostic( "conflicting-declarations", format_args!( "Conflicting declared types for '{symbol_name}': {}.", @@ -523,15 +474,16 @@ impl<'db> TypeInferenceBuilder<'db> { ty }); if !bound_ty.is_assignable_to(self.db, declared_ty) { - self.invalid_assignment_diagnostic(node, declared_ty, bound_ty); + self.context + .invalid_assignment_diagnostic(declared_ty, bound_ty); // allow declarations to override inference in case of invalid assignment bound_ty = declared_ty; }; - self.types.bindings.insert(binding, bound_ty); + self.context.add_binding_ty(binding, bound_ty); } - fn add_declaration(&mut self, node: AnyNodeRef, declaration: Definition<'db>, ty: Type<'db>) { + fn add_declaration(&mut self, declaration: Definition<'db>, ty: Type<'db>) { debug_assert!(declaration.is_declaration(self.db)); let use_def = self.index.use_def_map(declaration.file_scope(self.db)); let prior_bindings = use_def.bindings_at_declaration(declaration); @@ -540,8 +492,7 @@ impl<'db> TypeInferenceBuilder<'db> { let ty = if inferred_ty.is_assignable_to(self.db, ty) { ty } else { - self.add_diagnostic( - node, + self.context.add_diagnostic( "invalid-declaration", format_args!( "Cannot declare type '{}' for inferred type '{}'.", @@ -551,12 +502,12 @@ impl<'db> TypeInferenceBuilder<'db> { ); Type::Unknown }; - self.types.declarations.insert(declaration, ty); + + self.context.add_declaration_ty(declaration, ty); } fn add_declaration_with_binding( &mut self, - node: AnyNodeRef, definition: Definition<'db>, declared_ty: Type<'db>, inferred_ty: Type<'db>, @@ -566,19 +517,21 @@ impl<'db> TypeInferenceBuilder<'db> { let inferred_ty = if inferred_ty.is_assignable_to(self.db, declared_ty) { inferred_ty } else { - self.invalid_assignment_diagnostic(node, declared_ty, inferred_ty); + self.context + .invalid_assignment_diagnostic(declared_ty, inferred_ty); // if the assignment is invalid, fall back to assuming the annotation is correct declared_ty }; - self.types.declarations.insert(definition, declared_ty); - self.types.bindings.insert(definition, inferred_ty); + + self.context.add_declaration_ty(definition, declared_ty); + self.context.add_binding_ty(definition, inferred_ty); } - fn infer_module(&mut self, module: &ast::ModModule) { + fn infer_module(&mut self, module: &'db ast::ModModule) { self.infer_body(&module.body); } - fn infer_class_type_params(&mut self, class: &ast::StmtClassDef) { + fn infer_class_type_params(&mut self, class: &'db ast::StmtClassDef) { let type_params = class .type_params .as_deref() @@ -591,11 +544,11 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_class_body(&mut self, class: &ast::StmtClassDef) { + fn infer_class_body(&mut self, class: &'db ast::StmtClassDef) { self.infer_body(&class.body); } - fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) { + fn infer_function_type_params(&mut self, function: &'db ast::StmtFunctionDef) { let type_params = function .type_params .as_deref() @@ -607,17 +560,17 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_parameters(&function.parameters); } - fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) { + fn infer_function_body(&mut self, function: &'db ast::StmtFunctionDef) { self.infer_body(&function.body); } - fn infer_body(&mut self, suite: &[ast::Stmt]) { + fn infer_body(&mut self, suite: &'db [ast::Stmt]) { for statement in suite { self.infer_statement(statement); } } - fn infer_statement(&mut self, statement: &ast::Stmt) { + fn infer_statement(&mut self, statement: &'db ast::Stmt) { match statement { ast::Stmt::FunctionDef(function) => self.infer_function_definition_statement(function), ast::Stmt::ClassDef(class) => self.infer_class_definition_statement(class), @@ -656,7 +609,7 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_definition(&mut self, node: impl Into) { let definition = self.index.definition(node); let result = infer_definition_types(self.db, definition); - self.extend(result); + self.context.extend(result); } fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) { @@ -665,9 +618,10 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_function_definition( &mut self, - function: &ast::StmtFunctionDef, + function: &'db ast::StmtFunctionDef, definition: Definition<'db>, ) { + self.context.set_node(function); let ast::StmtFunctionDef { range: _, is_async: _, @@ -698,7 +652,7 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO: this should also be applied to parameter annotations. if self.is_stub() { - self.types.has_deferred = true; + self.context.set_has_deferred(); } else { self.infer_optional_annotation_expression(returns.as_deref()); } @@ -711,10 +665,10 @@ impl<'db> TypeInferenceBuilder<'db> { decorator_tys, )); - self.add_declaration_with_binding(function.into(), definition, function_ty, function_ty); + self.add_declaration_with_binding(definition, function_ty, function_ty); } - fn infer_parameters(&mut self, parameters: &ast::Parameters) { + fn infer_parameters(&mut self, parameters: &'db ast::Parameters) { let ast::Parameters { range: _, posonlyargs: _, @@ -735,7 +689,10 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_parameter_with_default(&mut self, parameter_with_default: &ast::ParameterWithDefault) { + fn infer_parameter_with_default( + &mut self, + parameter_with_default: &'db ast::ParameterWithDefault, + ) { let ast::ParameterWithDefault { range: _, parameter, @@ -747,7 +704,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_definition(parameter_with_default); } - fn infer_parameter(&mut self, parameter: &ast::Parameter) { + fn infer_parameter(&mut self, parameter: &'db ast::Parameter) { let ast::Parameter { range: _, name: _, @@ -761,9 +718,10 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_parameter_with_default_definition( &mut self, - parameter_with_default: &ast::ParameterWithDefault, + parameter_with_default: &'db ast::ParameterWithDefault, definition: Definition<'db>, ) { + self.context.set_node(parameter_with_default); // TODO(dhruvmanila): Infer types from annotation or default expression // TODO check that default is assignable to parameter type self.infer_parameter_definition(¶meter_with_default.parameter, definition); @@ -771,21 +729,17 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_parameter_definition( &mut self, - parameter: &ast::Parameter, + parameter: &'db ast::Parameter, definition: Definition<'db>, ) { + self.context.set_node(parameter); // TODO(dhruvmanila): Annotation expression is resolved at the enclosing scope, infer the // parameter type from there let annotated_ty = Type::Unknown; if parameter.annotation.is_some() { - self.add_declaration_with_binding( - parameter.into(), - definition, - annotated_ty, - annotated_ty, - ); + self.add_declaration_with_binding(definition, annotated_ty, annotated_ty); } else { - self.add_binding(parameter.into(), definition, annotated_ty); + self.add_binding(definition, annotated_ty); } } @@ -793,7 +747,12 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_definition(class); } - fn infer_class_definition(&mut self, class: &ast::StmtClassDef, definition: Definition<'db>) { + fn infer_class_definition( + &mut self, + class: &'db ast::StmtClassDef, + definition: Definition<'db>, + ) { + self.context.set_node(class); let ast::StmtClassDef { range: _, name, @@ -819,7 +778,7 @@ impl<'db> TypeInferenceBuilder<'db> { body_scope, )); - self.add_declaration_with_binding(class.into(), definition, class_ty, class_ty); + self.add_declaration_with_binding(definition, class_ty, class_ty); for keyword in class.keywords() { self.infer_expression(&keyword.value); @@ -828,7 +787,7 @@ impl<'db> TypeInferenceBuilder<'db> { // inference of bases deferred in stubs // TODO also defer stringified generic type parameters if self.is_stub() { - self.types.has_deferred = true; + self.context.set_has_deferred(); } else { for base in class.bases() { self.infer_expression(base); @@ -836,13 +795,13 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_function_deferred(&mut self, function: &ast::StmtFunctionDef) { + fn infer_function_deferred(&mut self, function: &'db ast::StmtFunctionDef) { if self.is_stub() { self.infer_optional_annotation_expression(function.returns.as_deref()); } } - fn infer_class_deferred(&mut self, class: &ast::StmtClassDef) { + fn infer_class_deferred(&mut self, class: &'db ast::StmtClassDef) { if self.is_stub() { for base in class.bases() { self.infer_expression(base); @@ -850,7 +809,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) { + fn infer_if_statement(&mut self, if_statement: &'db ast::StmtIf) { let ast::StmtIf { range: _, test, @@ -874,7 +833,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_try_statement(&mut self, try_statement: &ast::StmtTry) { + fn infer_try_statement(&mut self, try_statement: &'db ast::StmtTry) { let ast::StmtTry { range: _, body, @@ -912,7 +871,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_body(finalbody); } - fn infer_with_statement(&mut self, with_statement: &ast::StmtWith) { + fn infer_with_statement(&mut self, with_statement: &'db ast::StmtWith) { let ast::StmtWith { range: _, is_async: _, @@ -937,29 +896,26 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_with_item_definition( &mut self, - target: &ast::ExprName, + target: &'db ast::ExprName, with_item: &ast::WithItem, definition: Definition<'db>, ) { + self.context.set_node(target); let expression = self.index.expression(&with_item.context_expr); let result = infer_expression_types(self.db, expression); - self.extend(result); + self.context.extend(result); // TODO(dhruvmanila): The correct type inference here is the return type of the __enter__ // method of the context manager. - let context_expr_ty = self - .types - .expression_ty(with_item.context_expr.scoped_ast_id(self.db, self.scope)); + let context_expr_ty = self.context.expression_ty(&with_item.context_expr); - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), context_expr_ty); - self.add_binding(target.into(), definition, context_expr_ty); + self.context.add_expression_ty(target, context_expr_ty); + self.add_binding(definition, context_expr_ty); } fn infer_except_handler_definition( &mut self, - except_handler_definition: &ExceptHandlerDefinitionKind, + except_handler_definition: &'db ExceptHandlerDefinitionKind, definition: Definition<'db>, ) { let node_ty = except_handler_definition @@ -984,14 +940,10 @@ impl<'db> TypeInferenceBuilder<'db> { } }; - self.add_binding( - except_handler_definition.node().into(), - definition, - symbol_ty, - ); + self.add_binding(definition, symbol_ty); } - fn infer_match_statement(&mut self, match_statement: &ast::StmtMatch) { + fn infer_match_statement(&mut self, match_statement: &'db ast::StmtMatch) { let ast::StmtMatch { range: _, subject, @@ -1000,7 +952,7 @@ impl<'db> TypeInferenceBuilder<'db> { let expression = self.index.expression(subject.as_ref()); let result = infer_expression_types(self.db, expression); - self.extend(result); + self.context.extend(result); for case in cases { let ast::MatchCase { @@ -1017,18 +969,19 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_match_pattern_definition( &mut self, - pattern: &ast::Pattern, + pattern: &'db ast::Pattern, _index: u32, definition: Definition<'db>, ) { + self.context.set_node(pattern); // TODO(dhruvmanila): The correct way to infer types here is to perform structural matching // against the subject expression type (which we can query via `infer_expression_types`) // and extract the type at the `index` position if the pattern matches. This will be // similar to the logic in `self.infer_assignment_definition`. - self.add_binding(pattern.into(), definition, Type::Unknown); + self.add_binding(definition, Type::Unknown); } - fn infer_match_pattern(&mut self, pattern: &ast::Pattern) { + fn infer_match_pattern(&mut self, pattern: &'db ast::Pattern) { // TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against // the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510 match pattern { @@ -1082,7 +1035,7 @@ impl<'db> TypeInferenceBuilder<'db> { }; } - fn infer_assignment_statement(&mut self, assignment: &ast::StmtAssign) { + fn infer_assignment_statement(&mut self, assignment: &'db ast::StmtAssign) { let ast::StmtAssign { range: _, targets, @@ -1097,7 +1050,8 @@ impl<'db> TypeInferenceBuilder<'db> { // the "get `Expression`, call `infer_expression_types` on it, `self.extend`" dance // will be removed; it'll all happen in `infer_assignment_definition` instead. let expression = self.index.expression(value.as_ref()); - self.extend(infer_expression_types(self.db, expression)); + self.context + .extend(infer_expression_types(self.db, expression)); self.infer_expression(target); } } @@ -1105,23 +1059,20 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_assignment_definition( &mut self, - target: &ast::ExprName, + target: &'db ast::ExprName, assignment: &ast::StmtAssign, definition: Definition<'db>, ) { + self.context.set_node(target); let expression = self.index.expression(assignment.value.as_ref()); let result = infer_expression_types(self.db, expression); - self.extend(result); - let value_ty = self - .types - .expression_ty(assignment.value.scoped_ast_id(self.db, self.scope)); - self.add_binding(assignment.into(), definition, value_ty); - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), value_ty); + self.context.extend(result); + let value_ty = self.context.expression_ty(&assignment.value); + self.add_binding(definition, value_ty); + self.context.add_expression_ty(target, value_ty); } - fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) { + fn infer_annotated_assignment_statement(&mut self, assignment: &'db ast::StmtAnnAssign) { // assignments to non-Names are not Definitions if matches!(*assignment.target, ast::Expr::Name(_)) { self.infer_definition(assignment); @@ -1141,7 +1092,7 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_annotated_assignment_definition( &mut self, - assignment: &ast::StmtAnnAssign, + assignment: &'db ast::StmtAnnAssign, definition: Definition<'db>, ) { let ast::StmtAnnAssign { @@ -1151,24 +1102,20 @@ impl<'db> TypeInferenceBuilder<'db> { value, simple: _, } = assignment; + self.context.set_node(target.as_ref()); let annotation_ty = self.infer_annotation_expression(annotation); if let Some(value) = value { let value_ty = self.infer_expression(value); - self.add_declaration_with_binding( - assignment.into(), - definition, - annotation_ty, - value_ty, - ); + self.add_declaration_with_binding(definition, annotation_ty, value_ty); } else { - self.add_declaration(assignment.into(), definition, annotation_ty); + self.add_declaration(definition, annotation_ty); } self.infer_expression(target); } - fn infer_augmented_assignment_statement(&mut self, assignment: &ast::StmtAugAssign) { + fn infer_augmented_assignment_statement(&mut self, assignment: &'db ast::StmtAugAssign) { if assignment.target.is_name_expr() { self.infer_definition(assignment); } else { @@ -1179,14 +1126,14 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_augment_assignment_definition( &mut self, - assignment: &ast::StmtAugAssign, + assignment: &'db ast::StmtAugAssign, definition: Definition<'db>, ) { let target_ty = self.infer_augment_assignment(assignment); - self.add_binding(assignment.into(), definition, target_ty); + self.add_binding(definition, target_ty); } - fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> { + fn infer_augment_assignment(&mut self, assignment: &'db ast::StmtAugAssign) -> Type<'db> { let ast::StmtAugAssign { range: _, target, @@ -1200,7 +1147,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) { + fn infer_type_alias_statement(&mut self, type_alias_statement: &'db ast::StmtTypeAlias) { let ast::StmtTypeAlias { range: _, name, @@ -1214,7 +1161,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_for_statement(&mut self, for_statement: &ast::StmtFor) { + fn infer_for_statement(&mut self, for_statement: &'db ast::StmtFor) { let ast::StmtFor { range: _, target, @@ -1235,48 +1182,33 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_body(orelse); } - /// Emit a diagnostic declaring that the object represented by `node` is not iterable - pub(super) fn not_iterable_diagnostic(&mut self, node: AnyNodeRef, not_iterable_ty: Type<'db>) { - self.add_diagnostic( - node, - "not-iterable", - format_args!( - "Object of type '{}' is not iterable", - not_iterable_ty.display(self.db) - ), - ); - } - fn infer_for_statement_definition( &mut self, - target: &ast::ExprName, + target: &'db ast::ExprName, iterable: &ast::Expr, is_async: bool, definition: Definition<'db>, ) { + self.context.set_node(target); let expression = self.index.expression(iterable); let result = infer_expression_types(self.db, expression); - self.extend(result); - let iterable_ty = self - .types - .expression_ty(iterable.scoped_ast_id(self.db, self.scope)); + self.context.extend(result); + let iterable_ty = self.context.expression_ty(iterable); let loop_var_value_ty = if is_async { // TODO(Alex): async iterables/iterators! Type::Unknown } else { iterable_ty - .iterate(self.db) - .unwrap_with_diagnostic(iterable.into(), self) + .iterate(self.db, &mut self.context) + .unwrap_or(Type::Unknown) }; - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), loop_var_value_ty); - self.add_binding(target.into(), definition, loop_var_value_ty); + self.context.add_expression_ty(target, loop_var_value_ty); + self.add_binding(definition, loop_var_value_ty); } - fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) { + fn infer_while_statement(&mut self, while_statement: &'db ast::StmtWhile) { let ast::StmtWhile { range: _, test, @@ -1298,6 +1230,7 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_import_definition(&mut self, alias: &'db ast::Alias, definition: Definition<'db>) { + self.context.set_node(alias); let ast::Alias { range: _, name, @@ -1308,7 +1241,7 @@ impl<'db> TypeInferenceBuilder<'db> { if let Some(module) = self.module_ty_from_name(module_name) { module } else { - self.unresolved_module_diagnostic(alias, 0, Some(name)); + self.context.unresolved_module_diagnostic(0, Some(name)); Type::Unknown } } else { @@ -1316,7 +1249,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown }; - self.add_binding(alias.into(), definition, module_ty); + self.add_binding(definition, module_ty); } fn infer_import_from_statement(&mut self, import: &ast::StmtImportFrom) { @@ -1332,7 +1265,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_assert_statement(&mut self, assert: &ast::StmtAssert) { + fn infer_assert_statement(&mut self, assert: &'db ast::StmtAssert) { let ast::StmtAssert { range: _, test, @@ -1343,7 +1276,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_optional_expression(msg.as_deref()); } - fn infer_raise_statement(&mut self, raise: &ast::StmtRaise) { + fn infer_raise_statement(&mut self, raise: &'db ast::StmtRaise) { let ast::StmtRaise { range: _, exc, @@ -1353,23 +1286,6 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_optional_expression(cause.as_deref()); } - fn unresolved_module_diagnostic( - &mut self, - import_node: impl Into>, - level: u32, - module: Option<&str>, - ) { - self.add_diagnostic( - import_node.into(), - "unresolved-import", - format_args!( - "Cannot resolve import '{}{}'.", - ".".repeat(level as usize), - module.unwrap_or_default() - ), - ); - } - /// Given a `from .foo import bar` relative import, resolve the relative module /// we're importing `bar` from into an absolute [`ModuleName`] /// using the name of the module we're currently analyzing. @@ -1407,9 +1323,10 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_import_from_definition( &mut self, import_from: &'db ast::StmtImportFrom, - alias: &ast::Alias, + alias: &'db ast::Alias, definition: Definition<'db>, ) { + self.context.set_node(alias); // TODO: // - Absolute `*` imports (`from collections import *`) // - Relative `*` imports (`from ...foo import *`) @@ -1446,7 +1363,7 @@ impl<'db> TypeInferenceBuilder<'db> { if let Some(ty) = self.module_ty_from_name(name) { ty } else { - self.unresolved_module_diagnostic(import_from, *level, module); + self.context.unresolved_module_diagnostic(*level, module); Type::Unknown } } @@ -1460,7 +1377,7 @@ impl<'db> TypeInferenceBuilder<'db> { "Relative module resolution '{}' failed: too many leading dots", format_import_from_module(*level, module), ); - self.unresolved_module_diagnostic(import_from, *level, module); + self.context.unresolved_module_diagnostic(*level, module); Type::Unknown } Err(ModuleNameResolutionError::UnknownCurrentModule) => { @@ -1469,7 +1386,7 @@ impl<'db> TypeInferenceBuilder<'db> { format_import_from_module(*level, module), self.file.path(self.db) ); - self.unresolved_module_diagnostic(import_from, *level, module); + self.context.unresolved_module_diagnostic(*level, module); Type::Unknown } }; @@ -1484,8 +1401,7 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO: What if it's a union where one of the elements is `Unbound`? if member_ty.is_unbound() { - self.add_diagnostic( - AnyNodeRef::Alias(alias), + self.context.add_diagnostic( "unresolved-import", format_args!( "Module '{}{}' has no member '{name}'", @@ -1501,17 +1417,16 @@ impl<'db> TypeInferenceBuilder<'db> { // as would be the case for a symbol with type `Unbound`), so it's appropriate to // think of the type of the imported symbol as `Unknown` rather than `Unbound` self.add_binding( - alias.into(), definition, member_ty.replace_unbound_with(self.db, Type::Unknown), ); } - fn infer_return_statement(&mut self, ret: &ast::StmtReturn) { + fn infer_return_statement(&mut self, ret: &'db ast::StmtReturn) { self.infer_optional_expression(ret.value.as_deref()); } - fn infer_delete_statement(&mut self, delete: &ast::StmtDelete) { + fn infer_delete_statement(&mut self, delete: &'db ast::StmtDelete) { let ast::StmtDelete { range: _, targets } = delete; for target in targets { self.infer_expression(target); @@ -1522,7 +1437,7 @@ impl<'db> TypeInferenceBuilder<'db> { resolve_module(self.db, module_name).map(|module| Type::Module(module.file())) } - fn infer_decorator(&mut self, decorator: &ast::Decorator) -> Type<'db> { + fn infer_decorator(&mut self, decorator: &'db ast::Decorator) -> Type<'db> { let ast::Decorator { range: _, expression, @@ -1531,7 +1446,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(expression) } - fn infer_arguments(&mut self, arguments: &ast::Arguments) -> Vec> { + fn infer_arguments(&mut self, arguments: &'db ast::Arguments) -> Vec> { let mut types = Vec::with_capacity( arguments .args @@ -1552,18 +1467,22 @@ impl<'db> TypeInferenceBuilder<'db> { types } - fn infer_optional_expression(&mut self, expression: Option<&ast::Expr>) -> Option> { + fn infer_optional_expression( + &mut self, + expression: Option<&'db ast::Expr>, + ) -> Option> { expression.map(|expr| self.infer_expression(expr)) } fn infer_optional_annotation_expression( &mut self, - expr: Option<&ast::Expr>, + expr: Option<&'db ast::Expr>, ) -> Option> { expr.map(|expr| self.infer_annotation_expression(expr)) } - fn infer_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + fn infer_expression(&mut self, expression: &'db ast::Expr) -> Type<'db> { + let prev_node = self.context.set_node(expression); let ty = match expression { ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::None, ast::Expr::NumberLiteral(literal) => self.infer_number_literal_expression(literal), @@ -1600,10 +1519,9 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression), ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), }; + self.context.restore_node(prev_node); - let expr_id = expression.scoped_ast_id(self.db, self.scope); - let previous = self.types.expressions.insert(expr_id, ty); - assert_eq!(previous, None); + self.context.add_expression_ty(expression, ty); ty } @@ -1649,7 +1567,7 @@ impl<'db> TypeInferenceBuilder<'db> { )) } - fn infer_fstring_expression(&mut self, fstring: &ast::ExprFString) -> Type<'db> { + fn infer_fstring_expression(&mut self, fstring: &'db ast::ExprFString) -> Type<'db> { let ast::ExprFString { range: _, value } = fstring; for part in value { @@ -1674,7 +1592,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_fstring_element(&mut self, element: &ast::FStringElement) { + fn infer_fstring_element(&mut self, element: &'db ast::FStringElement) { match element { ast::FStringElement::Literal(_) => { // TODO string literal type @@ -1705,7 +1623,7 @@ impl<'db> TypeInferenceBuilder<'db> { builtins_symbol_ty(self.db, "Ellipsis") } - fn infer_tuple_expression(&mut self, tuple: &ast::ExprTuple) -> Type<'db> { + fn infer_tuple_expression(&mut self, tuple: &'db ast::ExprTuple) -> Type<'db> { let ast::ExprTuple { range: _, elts, @@ -1721,7 +1639,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Tuple(TupleType::new(self.db, element_types.into_boxed_slice())) } - fn infer_list_expression(&mut self, list: &ast::ExprList) -> Type<'db> { + fn infer_list_expression(&mut self, list: &'db ast::ExprList) -> Type<'db> { let ast::ExprList { range: _, elts, @@ -1736,7 +1654,7 @@ impl<'db> TypeInferenceBuilder<'db> { builtins_symbol_ty(self.db, "list").to_instance(self.db) } - fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> { + fn infer_set_expression(&mut self, set: &'db ast::ExprSet) -> Type<'db> { let ast::ExprSet { range: _, elts } = set; for elt in elts { @@ -1747,7 +1665,7 @@ impl<'db> TypeInferenceBuilder<'db> { builtins_symbol_ty(self.db, "set").to_instance(self.db) } - fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> { + fn infer_dict_expression(&mut self, dict: &'db ast::ExprDict) -> Type<'db> { let ast::ExprDict { range: _, items } = dict; for item in items { @@ -1760,7 +1678,7 @@ impl<'db> TypeInferenceBuilder<'db> { } /// Infer the type of the `iter` expression of the first comprehension. - fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) { + fn infer_first_comprehension_iter(&mut self, comprehensions: &'db [ast::Comprehension]) { let mut comprehensions_iter = comprehensions.iter(); let Some(first_comprehension) = comprehensions_iter.next() else { unreachable!("Comprehension must contain at least one generator"); @@ -1768,7 +1686,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(&first_comprehension.iter); } - fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> { + fn infer_generator_expression(&mut self, generator: &'db ast::ExprGenerator) -> Type<'db> { let ast::ExprGenerator { range: _, elt: _, @@ -1782,7 +1700,10 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_list_comprehension_expression(&mut self, listcomp: &ast::ExprListComp) -> Type<'db> { + fn infer_list_comprehension_expression( + &mut self, + listcomp: &'db ast::ExprListComp, + ) -> Type<'db> { let ast::ExprListComp { range: _, elt: _, @@ -1795,7 +1716,10 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_dict_comprehension_expression(&mut self, dictcomp: &ast::ExprDictComp) -> Type<'db> { + fn infer_dict_comprehension_expression( + &mut self, + dictcomp: &'db ast::ExprDictComp, + ) -> Type<'db> { let ast::ExprDictComp { range: _, key: _, @@ -1809,7 +1733,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_set_comprehension_expression(&mut self, setcomp: &ast::ExprSetComp) -> Type<'db> { + fn infer_set_comprehension_expression(&mut self, setcomp: &'db ast::ExprSetComp) -> Type<'db> { let ast::ExprSetComp { range: _, elt: _, @@ -1822,7 +1746,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_generator_expression_scope(&mut self, generator: &ast::ExprGenerator) { + fn infer_generator_expression_scope(&mut self, generator: &'db ast::ExprGenerator) { let ast::ExprGenerator { range: _, elt, @@ -1834,7 +1758,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_comprehensions(generators); } - fn infer_list_comprehension_expression_scope(&mut self, listcomp: &ast::ExprListComp) { + fn infer_list_comprehension_expression_scope(&mut self, listcomp: &'db ast::ExprListComp) { let ast::ExprListComp { range: _, elt, @@ -1845,7 +1769,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_comprehensions(generators); } - fn infer_dict_comprehension_expression_scope(&mut self, dictcomp: &ast::ExprDictComp) { + fn infer_dict_comprehension_expression_scope(&mut self, dictcomp: &'db ast::ExprDictComp) { let ast::ExprDictComp { range: _, key, @@ -1858,7 +1782,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_comprehensions(generators); } - fn infer_set_comprehension_expression_scope(&mut self, setcomp: &ast::ExprSetComp) { + fn infer_set_comprehension_expression_scope(&mut self, setcomp: &'db ast::ExprSetComp) { let ast::ExprSetComp { range: _, elt, @@ -1869,7 +1793,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_comprehensions(generators); } - fn infer_comprehensions(&mut self, comprehensions: &[ast::Comprehension]) { + fn infer_comprehensions(&mut self, comprehensions: &'db [ast::Comprehension]) { let mut comprehensions_iter = comprehensions.iter(); let Some(first_comprehension) = comprehensions_iter.next() else { unreachable!("Comprehension must contain at least one generator"); @@ -1880,7 +1804,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_comprehension(&mut self, comprehension: &ast::Comprehension, is_first: bool) { + fn infer_comprehension(&mut self, comprehension: &'db ast::Comprehension, is_first: bool) { let ast::Comprehension { range: _, target, @@ -1906,11 +1830,12 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_comprehension_definition( &mut self, iterable: &ast::Expr, - target: &ast::ExprName, + target: &'db ast::ExprName, is_first: bool, is_async: bool, definition: Definition<'db>, ) { + self.context.set_node(target); let expression = self.index.expression(iterable); let result = infer_expression_types(self.db, expression); @@ -1928,7 +1853,7 @@ impl<'db> TypeInferenceBuilder<'db> { .to_scope_id(self.db, self.file); result.expression_ty(iterable.scoped_ast_id(self.db, lookup_scope)) } else { - self.extend(result); + self.context.extend(result); result.expression_ty(iterable.scoped_ast_id(self.db, self.scope)) }; @@ -1937,26 +1862,24 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } else { iterable_ty - .iterate(self.db) - .unwrap_with_diagnostic(iterable.into(), self) + .iterate(self.db, &mut self.context) + .unwrap_or(Type::Unknown) }; - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), target_ty); - self.add_binding(target.into(), definition, target_ty); + self.context.add_expression_ty(target, target_ty); + self.add_binding(definition, target_ty); } fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type<'db> { let definition = self.index.definition(named); let result = infer_definition_types(self.db, definition); - self.extend(result); + self.context.extend(result); result.binding_ty(definition) } fn infer_named_expression_definition( &mut self, - named: &ast::ExprNamed, + named: &'db ast::ExprNamed, definition: Definition<'db>, ) -> Type<'db> { let ast::ExprNamed { @@ -1964,16 +1887,17 @@ impl<'db> TypeInferenceBuilder<'db> { target, value, } = named; + self.context.set_node(target.as_ref()); let value_ty = self.infer_expression(value); self.infer_expression(target); - self.add_binding(named.into(), definition, value_ty); + self.add_binding(definition, value_ty); value_ty } - fn infer_if_expression(&mut self, if_expression: &ast::ExprIf) -> Type<'db> { + fn infer_if_expression(&mut self, if_expression: &'db ast::ExprIf) -> Type<'db> { let ast::ExprIf { range: _, test, @@ -1990,11 +1914,11 @@ impl<'db> TypeInferenceBuilder<'db> { UnionType::from_elements(self.db, [body_ty, orelse_ty]) } - fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) { + fn infer_lambda_body(&mut self, lambda_expression: &'db ast::ExprLambda) { self.infer_expression(&lambda_expression.body); } - fn infer_lambda_expression(&mut self, lambda_expression: &ast::ExprLambda) -> Type<'db> { + fn infer_lambda_expression(&mut self, lambda_expression: &'db ast::ExprLambda) -> Type<'db> { let ast::ExprLambda { range: _, parameters, @@ -2016,7 +1940,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_call_expression(&mut self, call_expression: &ast::ExprCall) -> Type<'db> { + fn infer_call_expression(&mut self, call_expression: &'db ast::ExprCall) -> Type<'db> { let ast::ExprCall { range: _, func, @@ -2025,20 +1949,21 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_arguments(arguments); let function_type = self.infer_expression(func); - function_type.call(self.db).unwrap_or_else(|| { - self.add_diagnostic( - func.as_ref().into(), - "call-non-callable", - format_args!( - "Object of type '{}' is not callable", - function_type.display(self.db) - ), - ); - Type::Unknown - }) + function_type + .call(self.db, &mut self.context) + .unwrap_or_else(|| { + self.context.add_diagnostic( + "call-non-callable", + format_args!( + "Object of type '{}' is not callable", + function_type.display(self.db) + ), + ); + Type::Unknown + }) } - fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> { + fn infer_starred_expression(&mut self, starred: &'db ast::ExprStarred) -> Type<'db> { let ast::ExprStarred { range: _, value, @@ -2047,14 +1972,14 @@ impl<'db> TypeInferenceBuilder<'db> { let iterable_ty = self.infer_expression(value); iterable_ty - .iterate(self.db) - .unwrap_with_diagnostic(value.as_ref().into(), self); + .iterate(self.db, &mut self.context) + .unwrap_or(Type::Unknown); // TODO Type::Unknown } - fn infer_yield_expression(&mut self, yield_expression: &ast::ExprYield) -> Type<'db> { + fn infer_yield_expression(&mut self, yield_expression: &'db ast::ExprYield) -> Type<'db> { let ast::ExprYield { range: _, value } = yield_expression; self.infer_optional_expression(value.as_deref()); @@ -2063,19 +1988,19 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_yield_from_expression(&mut self, yield_from: &ast::ExprYieldFrom) -> Type<'db> { + fn infer_yield_from_expression(&mut self, yield_from: &'db ast::ExprYieldFrom) -> Type<'db> { let ast::ExprYieldFrom { range: _, value } = yield_from; let iterable_ty = self.infer_expression(value); iterable_ty - .iterate(self.db) - .unwrap_with_diagnostic(value.as_ref().into(), self); + .iterate(self.db, &mut self.context) + .unwrap_or(Type::Unknown); // TODO get type from `ReturnType` of generator Type::Unknown } - fn infer_await_expression(&mut self, await_expression: &ast::ExprAwait) -> Type<'db> { + fn infer_await_expression(&mut self, await_expression: &'db ast::ExprAwait) -> Type<'db> { let ast::ExprAwait { range: _, value } = await_expression; self.infer_expression(value); @@ -2139,7 +2064,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { + fn infer_name_expression(&mut self, name: &'db ast::ExprName) -> Type<'db> { let ast::ExprName { range: _, id, ctx } = name; let file_scope_id = self.scope.file_scope_id(self.db); @@ -2178,7 +2103,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { + fn infer_attribute_expression(&mut self, attribute: &'db ast::ExprAttribute) -> Type<'db> { let ast::ExprAttribute { value, attr, @@ -2196,7 +2121,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_unary_expression(&mut self, unary: &ast::ExprUnaryOp) -> Type<'db> { + fn infer_unary_expression(&mut self, unary: &'db ast::ExprUnaryOp) -> Type<'db> { let ast::ExprUnaryOp { range: _, op, @@ -2209,7 +2134,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_binary_expression(&mut self, binary: &ast::ExprBinOp) -> Type<'db> { + fn infer_binary_expression(&mut self, binary: &'db ast::ExprBinOp) -> Type<'db> { let ast::ExprBinOp { left, op, @@ -2308,7 +2233,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_boolean_expression(&mut self, bool_op: &ast::ExprBoolOp) -> Type<'db> { + fn infer_boolean_expression(&mut self, bool_op: &'db ast::ExprBoolOp) -> Type<'db> { let ast::ExprBoolOp { range: _, op: _, @@ -2323,7 +2248,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_compare_expression(&mut self, compare: &ast::ExprCompare) -> Type<'db> { + fn infer_compare_expression(&mut self, compare: &'db ast::ExprCompare) -> Type<'db> { let ast::ExprCompare { range: _, left, @@ -2339,7 +2264,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> { + fn infer_subscript_expression(&mut self, subscript: &'db ast::ExprSubscript) -> Type<'db> { let ast::ExprSubscript { range: _, value, @@ -2354,7 +2279,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_slice_expression(&mut self, slice: &ast::ExprSlice) -> Type<'db> { + fn infer_slice_expression(&mut self, slice: &'db ast::ExprSlice) -> Type<'db> { let ast::ExprSlice { range: _, lower, @@ -2370,7 +2295,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } - fn infer_type_parameters(&mut self, type_parameters: &ast::TypeParams) { + fn infer_type_parameters(&mut self, type_parameters: &'db ast::TypeParams) { let ast::TypeParams { range: _, type_params, @@ -2407,40 +2332,19 @@ impl<'db> TypeInferenceBuilder<'db> { } } - /// Adds a new diagnostic. - /// - /// The diagnostic does not get added if the rule isn't enabled for this file. - fn add_diagnostic(&mut self, node: AnyNodeRef, rule: &str, message: std::fmt::Arguments) { - if !self.db.is_file_open(self.file) { - return; - } - - // TODO: Don't emit the diagnostic if: - // * The enclosing node contains any syntax errors - // * The rule is disabled for this file. We probably want to introduce a new query that - // returns a rule selector for a given file that respects the package's settings, - // any global pragma comments in the file, and any per-file-ignores. - - self.types.diagnostics.push(TypeCheckDiagnostic { - file: self.file, - rule: rule.to_string(), - message: message.to_string(), - range: node.range(), - }); - } - pub(super) fn finish(mut self) -> TypeInference<'db> { self.infer_region(); - self.types.shrink_to_fit(); - self.types + self.context.types.shrink_to_fit(); + self.context.types } } /// Annotation expressions. impl<'db> TypeInferenceBuilder<'db> { - fn infer_annotation_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + fn infer_annotation_expression(&mut self, expression: &'db ast::Expr) -> Type<'db> { + let prev_node = self.context.set_node(expression); // https://typing.readthedocs.io/en/latest/spec/annotations.html#grammar-token-expression-grammar-annotation_expression - match expression { + let ty = match expression { // TODO: parse the expression and check whether it is a string annotation, since they // can be annotation expressions distinct from type expressions. // https://typing.readthedocs.io/en/latest/spec/annotations.html#string-annotations @@ -2452,17 +2356,19 @@ impl<'db> TypeInferenceBuilder<'db> { // All other annotation expressions are (possibly) valid type expressions, so handle // them there instead. type_expr => self.infer_type_expression(type_expr), - } + }; + self.context.restore_node(prev_node); + ty } } /// Type expressions impl<'db> TypeInferenceBuilder<'db> { - fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + fn infer_type_expression(&mut self, expression: &'db ast::Expr) -> Type<'db> { // https://typing.readthedocs.io/en/latest/spec/annotations.html#grammar-token-expression-grammar-type_expression // TODO: this does not include any of the special forms, and is only a // stub of the forms other than a standalone name in scope. - + let prev_node = self.context.set_node(expression); let ty = match expression { ast::Expr::Name(name) => { debug_assert!( @@ -2596,11 +2502,162 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), }; - let expr_id = expression.scoped_ast_id(self.db, self.scope); + self.context.restore_node(prev_node); + self.context.add_expression_ty(expression, ty); + + ty + } +} + +/// All the state for [`TypeInferenceBuilder`]. +pub(super) struct TypeInferenceContext<'db> { + db: &'db dyn Db, + + /// The file and scope we are inferring types in. + file: File, + scope: ScopeId<'db>, + + /// The node we are currently visiting. + node: Option>, + + /// The type inference results. + types: TypeInference<'db>, +} + +impl<'db> TypeInferenceContext<'db> { + fn set_node(&mut self, node: impl Into>) -> Option> { + std::mem::replace(&mut self.node, Some(node.into())) + } + + fn restore_node(&mut self, node: Option>) { + self.node = node; + } + + fn extend(&mut self, inference: &TypeInference<'db>) { + self.types.bindings.extend(inference.bindings.iter()); + self.types + .declarations + .extend(inference.declarations.iter()); + self.types.expressions.extend(inference.expressions.iter()); + self.types.diagnostics.extend(&inference.diagnostics); + self.types.has_deferred |= inference.has_deferred; + } + + fn add_expression_ty( + &mut self, + expr: &dyn HasScopedAstId, + ty: Type<'db>, + ) { + let expr_id = expr.scoped_ast_id(self.db, self.scope); let previous = self.types.expressions.insert(expr_id, ty); assert!(previous.is_none()); + } - ty + fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> { + self.types + .expression_ty(expr.scoped_ast_id(self.db, self.scope)) + } + + fn add_binding_ty(&mut self, binding: Definition<'db>, ty: Type<'db>) { + self.types.bindings.insert(binding, ty); + } + + fn add_declaration_ty(&mut self, declaration: Definition<'db>, ty: Type<'db>) { + self.types.declarations.insert(declaration, ty); + } + + fn set_has_deferred(&mut self) { + self.types.has_deferred = true; + } + + fn infer_deferred_types(&mut self) { + if self.types.has_deferred { + let mut deferred_expression_types: FxHashMap> = + FxHashMap::default(); + // invariant: only annotations and base classes are deferred, and both of these only + // occur within a declaration (annotated assignment, function or class definition) + for definition in self.types.declarations.keys() { + if infer_definition_types(self.db, *definition).has_deferred { + let deferred = infer_deferred_types(self.db, *definition); + deferred_expression_types.extend(deferred.expressions.iter()); + } + } + self.types + .expressions + .extend(deferred_expression_types.iter()); + } + } + + fn invalid_assignment_diagnostic(&mut self, declared_ty: Type<'db>, assigned_ty: Type<'db>) { + match declared_ty { + Type::Class(class) => { + self.add_diagnostic("invalid-assignment", format_args!( + "Implicit shadowing of class '{}'; annotate to make it explicit if this is intentional.", + class.name(self.db))); + } + Type::Function(function) => { + self.add_diagnostic("invalid-assignment", format_args!( + "Implicit shadowing of function '{}'; annotate to make it explicit if this is intentional.", + function.name(self.db))); + } + _ => { + self.add_diagnostic( + "invalid-assignment", + format_args!( + "Object of type '{}' is not assignable to '{}'.", + assigned_ty.display(self.db), + declared_ty.display(self.db), + ), + ); + } + } + } + + /// Emit a diagnostic that the object represented by `node` is not iterable. + pub(super) fn not_iterable_diagnostic(&mut self, not_iterable_ty: Type<'db>) { + self.add_diagnostic( + "not-iterable", + format_args!( + "Object of type '{}' is not iterable", + not_iterable_ty.display(self.db) + ), + ); + } + + fn unresolved_module_diagnostic(&mut self, level: u32, module: Option<&str>) { + self.add_diagnostic( + "unresolved-import", + format_args!( + "Cannot resolve import '{}{}'.", + ".".repeat(level as usize), + module.unwrap_or_default() + ), + ); + } + + /// Adds a new diagnostic. + /// + /// The diagnostic does not get added if the rule isn't enabled for this file. + fn add_diagnostic(&mut self, rule: &str, message: std::fmt::Arguments) { + if !self.db.is_file_open(self.file) { + return; + } + + // TODO: Don't emit the diagnostic if: + // * The enclosing node contains any syntax errors + // * The rule is disabled for this file. We probably want to introduce a new query that + // returns a rule selector for a given file that respects the package's settings, + // any global pragma comments in the file, and any per-file-ignores. + + self.types.diagnostics.push(TypeCheckDiagnostic { + file: self.file, + rule: rule.to_string(), + message: message.to_string(), + range: self + .node + .expect("add_diagnostic called with no node set") + .range(), + }); } }