Skip to content

Commit

Permalink
feat: infer self & Self HIR types (#564)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wodann authored Jun 5, 2024
1 parent 45b92db commit 3884019
Show file tree
Hide file tree
Showing 37 changed files with 399 additions and 76 deletions.
2 changes: 2 additions & 0 deletions crates/mun_codegen/src/ir/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ impl<'db, 'ink, 't> BodyIrGenerator<'db, 'ink, 't> {
.expect("unknown path")
.0
{
ValueNs::ImplSelf(_) => unimplemented!("no support for self types"),
ValueNs::LocalBinding(pat) => {
if let Some(param) = self.pat_to_param.get(&pat) {
*param
Expand Down Expand Up @@ -618,6 +619,7 @@ impl<'db, 'ink, 't> BodyIrGenerator<'db, 'ink, 't> {
.expect("unknown path")
.0
{
ValueNs::ImplSelf(_) => unimplemented!("no support for self types"),
ValueNs::LocalBinding(pat) => *self
.pat_to_local
.get(&pat)
Expand Down
4 changes: 4 additions & 0 deletions crates/mun_compiler_daemon/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ mun_compiler = { version = "0.6.0-dev", path = "../mun_compiler" }
mun_project = { version = "0.6.0-dev", path = "../mun_project" }
mun_hir = { version = "0.6.0-dev", path = "../mun_hir" }
notify = { version = "4.0", default-features = false }

# Enable std feature for winapi through feature unification to ensure notify uses the correct `c_void` type
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3.8", features = ["std"] }
1 change: 1 addition & 0 deletions crates/mun_hir/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ once_cell = { version = "1.19.0", default-features = false }
rustc-hash = { version = "1.1", default-features = false }
salsa = { version = "0.16.1", default-features = false }
smallvec = { version = "1.11.2", features = ["union"], default-features = false }
bitflags = { version = "2.5.0", default-features = false }

[dev-dependencies]
mun_test = { path = "../mun_test" }
Expand Down
2 changes: 1 addition & 1 deletion crates/mun_hir/src/code_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub use self::{
function::{Function, FunctionData},
module::{Module, ModuleDef},
package::Package,
r#impl::{AssocItem, Impl, ImplData},
r#impl::{AssocItem, ImplData},
r#struct::{Field, Struct, StructData, StructKind, StructMemoryKind},
src::HasSource,
type_alias::{TypeAlias, TypeAliasData},
Expand Down
9 changes: 5 additions & 4 deletions crates/mun_hir/src/code_model/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
expr::{validator::ExprValidator, BodySourceMap},
has_module::HasModule,
ids::{FunctionId, Lookup},
item_tree::FunctionFlags,
name_resolution::Namespace,
resolve::HasResolver,
type_ref::{LocalTypeRefId, TypeRefMap, TypeRefSourceMap},
Expand Down Expand Up @@ -34,7 +35,7 @@ pub struct FunctionData {
ret_type: LocalTypeRefId,
type_ref_map: TypeRefMap,
type_ref_source_map: TypeRefSourceMap,
is_extern: bool,
flags: FunctionFlags,
}

impl FunctionData {
Expand Down Expand Up @@ -68,7 +69,7 @@ impl FunctionData {
ret_type,
type_ref_map,
type_ref_source_map,
is_extern: func.is_extern,
flags: func.flags,
visibility: item_tree[func.visibility].clone(),
})
}
Expand Down Expand Up @@ -99,7 +100,7 @@ impl FunctionData {

/// Returns true if this function is an extern function.
pub fn is_extern(&self) -> bool {
self.is_extern
self.flags.is_extern()
}
}

Expand Down Expand Up @@ -168,7 +169,7 @@ impl Function {
}

pub fn is_extern(self, db: &dyn HirDatabase) -> bool {
db.fn_data(self.id).is_extern
db.fn_data(self.id).flags.is_extern()
}

pub(crate) fn body_source_map(self, db: &dyn HirDatabase) -> Arc<BodySourceMap> {
Expand Down
21 changes: 19 additions & 2 deletions crates/mun_hir/src/code_model/module.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Function, Package, Struct, TypeAlias};
use super::{r#impl::Impl, AssocItem, Function, Package, Struct, TypeAlias};
use crate::{
ids::{ItemDefinitionId, ModuleId},
primitive_type::PrimitiveType,
Expand Down Expand Up @@ -81,7 +81,7 @@ impl Module {
let package_defs = db.package_defs(self.id.package);
package_defs.add_diagnostics(db.upcast(), self.id.local_id, sink);

// Add diagnostics from impls
// Add diagnostics from inherent impls
let inherent_impls = db.inherent_impls_in_package(self.id.package);
inherent_impls.add_module_diagnostics(db, self.id.local_id, sink);

Expand All @@ -102,6 +102,15 @@ impl Module {
_ => (),
}
}

// Add diagnostics from impls
for item in self.impls(db) {
for associated_item in item.items(db) {
let AssocItem::Function(fun) = associated_item;

fun.diagnostics(db, sink);
}
}
}

/// Returns all the child modules of this module
Expand Down Expand Up @@ -141,6 +150,14 @@ impl Module {
)
.collect()
}

pub fn impls(self, db: &dyn HirDatabase) -> Vec<Impl> {
let package_defs = db.package_defs(self.id.package);
package_defs.modules[self.id.local_id]
.impls()
.map(Impl::from)
.collect()
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand Down
3 changes: 3 additions & 0 deletions crates/mun_hir/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
#[salsa::invoke(crate::ty::type_for_def)]
fn type_for_def(&self, def: TypableDef, ns: Namespace) -> Ty;

#[salsa::invoke(crate::ty::type_for_impl_self)]
fn type_for_impl_self(&self, def: ImplId) -> Ty;

#[salsa::invoke(InherentImpls::inherent_impls_in_package_query)]
fn inherent_impls_in_package(&self, package: PackageId) -> Arc<InherentImpls>;
}
Expand Down
26 changes: 22 additions & 4 deletions crates/mun_hir/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
diagnostics::DiagnosticSink,
ids::{DefWithBodyId, Lookup},
in_file::InFile,
name::AsName,
name::{name, AsName},
primitive_type::{PrimitiveFloat, PrimitiveInt},
type_ref::{LocalTypeRefId, TypeRef, TypeRefMap, TypeRefMapBuilder, TypeRefSourceMap},
DefDatabase, FileId, HirDatabase, Name, Path,
Expand Down Expand Up @@ -46,6 +46,7 @@ pub struct Body {
///
/// If this `Body` is for the body of a constant, this will just be empty.
params: Vec<(PatId, LocalTypeRefId)>,
self_param: Option<(PatId, LocalTypeRefId)>,
/// The `ExprId` of the actual body expression.
body_expr: ExprId,
ret_type: LocalTypeRefId,
Expand Down Expand Up @@ -82,6 +83,10 @@ impl Body {
&self.params
}

pub fn self_param(&self) -> Option<&(PatId, LocalTypeRefId)> {
self.self_param.as_ref()
}

pub fn body_expr(&self) -> ExprId {
self.body_expr
}
Expand Down Expand Up @@ -147,7 +152,7 @@ impl Index<LocalTypeRefId> for Body {
type ExprPtr = Either<AstPtr<ast::Expr>, AstPtr<ast::RecordField>>;
type ExprSource = InFile<ExprPtr>;

type PatPtr = AstPtr<ast::Pat>; //Either<AstPtr<ast::Pat>, AstPtr<ast::SelfParam>>;
type PatPtr = Either<AstPtr<ast::Pat>, AstPtr<ast::SelfParam>>;
type PatSource = InFile<PatPtr>;

type RecordPtr = AstPtr<ast::RecordField>;
Expand Down Expand Up @@ -189,7 +194,7 @@ impl BodySourceMap {
}

pub(crate) fn node_pat(&self, node: &ast::Pat) -> Option<PatId> {
self.pat_map.get(&AstPtr::new(node)).cloned()
self.pat_map.get(&Either::Left(AstPtr::new(node))).cloned()
}

pub fn type_refs(&self) -> &TypeRefSourceMap {
Expand Down Expand Up @@ -466,6 +471,7 @@ pub(crate) struct ExprCollector<'a> {
pats: Arena<Pat>,
source_map: BodySourceMap,
params: Vec<(PatId, LocalTypeRefId)>,
self_param: Option<(PatId, LocalTypeRefId)>,
body_expr: Option<ExprId>,
ret_type: Option<LocalTypeRefId>,
type_ref_builder: TypeRefMapBuilder,
Expand All @@ -482,6 +488,7 @@ impl<'a> ExprCollector<'a> {
pats: Arena::default(),
source_map: BodySourceMap::default(),
params: Vec::new(),
self_param: None,
body_expr: None,
ret_type: None,
type_ref_builder: TypeRefMap::builder(),
Expand Down Expand Up @@ -525,6 +532,16 @@ impl<'a> ExprCollector<'a> {

fn collect_fn_body(&mut self, node: &ast::FunctionDef) {
if let Some(param_list) = node.param_list() {
if let Some(self_param) = param_list.self_param() {
let self_pat = self.alloc_pat(
Pat::Bind { name: name![self] },
Either::Right(AstPtr::new(&self_param)),
);

let self_type = self.type_ref_builder.alloc_self();
self.self_param = Some((self_pat, self_type));
}

for param in param_list.params() {
let pat = if let Some(pat) = param.pat() {
pat
Expand Down Expand Up @@ -895,7 +912,7 @@ impl<'a> ExprCollector<'a> {
ast::PatKind::PlaceholderPat(_) => Pat::Wild,
};
let ptr = AstPtr::new(&pat);
self.alloc_pat(pattern, ptr)
self.alloc_pat(pattern, Either::Left(ptr))
}

fn collect_return(&mut self, expr: ast::ReturnExpr) -> ExprId {
Expand Down Expand Up @@ -930,6 +947,7 @@ impl<'a> ExprCollector<'a> {
exprs: self.exprs,
pats: self.pats,
params: self.params,
self_param: self.self_param,
body_expr: self.body_expr.expect("A body should have been collected"),
type_refs,
ret_type: self
Expand Down
3 changes: 3 additions & 0 deletions crates/mun_hir/src/expr/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ impl ExprScopes {
scope_by_expr: FxHashMap::default(),
};
let root = scopes.root_scope();
if let Some(self_param) = body.self_param {
scopes.add_bindings(body, root, self_param.0);
}
scopes.add_params_bindings(body, root, body.params().iter().map(|p| &p.0));
compute_expr_scopes(body.body_expr(), body, &mut scopes, root);
scopes
Expand Down
4 changes: 4 additions & 0 deletions crates/mun_hir/src/expr/validator/uninitialized_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ impl<'d> ExprValidator<'d> {

// Add all parameter patterns to the set of initialized patterns (they must have
// been initialized)
if let Some((pat, _)) = self.body.self_param {
initialized_patterns.insert(pat);
}

for (pat, _) in self.body.params.iter() {
initialized_patterns.insert(*pat);
}
Expand Down
30 changes: 29 additions & 1 deletion crates/mun_hir/src/item_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,38 @@ pub struct Import {
pub struct Function {
pub name: Name,
pub visibility: RawVisibilityId,
pub is_extern: bool,
pub types: TypeRefMap,
pub params: IdRange<Param>,
pub ret_type: LocalTypeRefId,
pub ast_id: FileAstId<ast::FunctionDef>,
pub(crate) flags: FunctionFlags,
}

bitflags::bitflags! {
#[doc = "Flags that are used to store additional information about a function"]
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub(crate) struct FunctionFlags: u8 {
const HAS_SELF_PARAM = 1 << 0;
const HAS_BODY = 1 << 1;
const IS_EXTERN = 1 << 2;
}
}

impl FunctionFlags {
/// Whether the function has a self parameter.
pub fn has_self_param(self) -> bool {
self.contains(Self::HAS_SELF_PARAM)
}

/// Whether the function has a body.
pub fn has_body(self) -> bool {
self.contains(Self::HAS_BODY)
}

/// Whether the function is extern.
pub fn is_extern(self) -> bool {
self.contains(Self::IS_EXTERN)
}
}

#[derive(Debug, Clone, Eq, PartialEq)]
Expand All @@ -317,6 +344,7 @@ pub struct Param {
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ParamAstId {
Param(FileAstId<ast::Param>),
SelfParam(FileAstId<ast::SelfParam>),
}

#[derive(Debug, Clone, Eq, PartialEq)]
Expand Down
36 changes: 30 additions & 6 deletions crates/mun_hir/src/item_tree/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ use mun_syntax::ast::{
use smallvec::SmallVec;

use super::{
diagnostics, AssociatedItem, Field, Fields, Function, IdRange, Impl, ItemTree, ItemTreeData,
ItemTreeNode, ItemVisibilities, LocalItemTreeId, ModItem, Param, ParamAstId, RawVisibilityId,
Struct, TypeAlias,
diagnostics, AssociatedItem, Field, Fields, Function, FunctionFlags, IdRange, Impl, ItemTree,
ItemTreeData, ItemTreeNode, ItemVisibilities, LocalItemTreeId, ModItem, Param, ParamAstId,
RawVisibilityId, Struct, TypeAlias,
};
use crate::{
item_tree::Import,
Expand Down Expand Up @@ -158,7 +158,21 @@ impl Context {

// Lower all the params
let start_param_idx = self.next_param_idx();
let mut has_self_param = false;
if let Some(param_list) = func.param_list() {
if let Some(self_param) = param_list.self_param() {
let ast_id = self.source_ast_id_map.ast_id(&self_param);
let type_ref = match self_param.ascribed_type().as_ref() {
Some(type_ref) => types.alloc_from_node(type_ref),
None => types.alloc_self(),
};
self.data.params.alloc(Param {
type_ref,
ast_id: ParamAstId::SelfParam(ast_id),
});
has_self_param = true;
}

for param in param_list.params() {
let ast_id = self.source_ast_id_map.ast_id(&param);
let type_ref = types.alloc_from_node_opt(param.ascribed_type().as_ref());
Expand All @@ -177,18 +191,28 @@ impl Context {
Some(ty) => types.alloc_from_node(&ty),
};

let is_extern = func.is_extern();

let (types, _types_source_map) = types.finish();
let ast_id = self.source_ast_id_map.ast_id(func);

let mut flags = FunctionFlags::default();
if func.is_extern() {
flags |= FunctionFlags::IS_EXTERN;
}
if func.body().is_some() {
flags |= FunctionFlags::HAS_BODY;
}
if has_self_param {
flags |= FunctionFlags::HAS_SELF_PARAM;
}

let res = Function {
name,
visibility,
is_extern,
types,
params,
ret_type,
ast_id,
flags,
};

Some(self.data.functions.alloc(res).into())
Expand Down
Loading

0 comments on commit 3884019

Please sign in to comment.