Skip to content

Commit

Permalink
feat: Allow associated types to be ellided from trait constraints (#7026
Browse files Browse the repository at this point in the history
)

Co-authored-by: Ary Borenszweig <[email protected]>
  • Loading branch information
jfecher and asterite authored Jan 14, 2025
1 parent b6097a0 commit aa7b91c
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 50 deletions.
147 changes: 124 additions & 23 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ impl<'context> Elaborator<'context> {

self.define_function_metas(&mut items.functions, &mut items.impls, &mut items.trait_impls);

self.collect_traits(&items.traits);
self.collect_traits(&mut items.traits);

// Before we resolve any function symbols we must go through our impls and
// re-collect the methods within into their proper module. This cannot be
Expand Down Expand Up @@ -354,6 +354,7 @@ impl<'context> Elaborator<'context> {
self.current_trait = Some(trait_id);
self.elaborate_functions(unresolved_trait.fns_with_default_impl);
}

self.current_trait = None;

for impls in items.impls.into_values() {
Expand Down Expand Up @@ -475,7 +476,7 @@ impl<'context> Elaborator<'context> {
self.add_existing_variable_to_scope(name, parameter.clone(), warn_if_unused);
}

self.add_trait_constraints_to_scope(&func_meta);
self.add_trait_constraints_to_scope(&func_meta.trait_constraints, func_meta.location.span);

let (hir_func, body_type) = match kind {
FunctionKind::Builtin
Expand All @@ -501,7 +502,7 @@ impl<'context> Elaborator<'context> {
// when multiple impls are available. Instead we default first to choose the Field or u64 impl.
self.check_and_pop_function_context();

self.remove_trait_constraints_from_scope(&func_meta);
self.remove_trait_constraints_from_scope(&func_meta.trait_constraints);

let func_scope_tree = self.scopes.end_function();

Expand Down Expand Up @@ -733,8 +734,13 @@ impl<'context> Elaborator<'context> {
None
}

/// TODO: This is currently only respected for generic free functions
/// there's a bunch of other places where trait constraints can pop up
/// Resolve the given trait constraints and add them to scope as we go.
/// This second step is necessary to resolve subsequent constraints such
/// as `<T as Foo>::Bar: Eq` which may lookup an impl which was assumed
/// by a previous constraint.
///
/// If these constraints are unwanted afterward they should be manually
/// removed from the interner.
fn resolve_trait_constraints(
&mut self,
where_clause: &[UnresolvedTraitConstraint],
Expand All @@ -745,12 +751,92 @@ impl<'context> Elaborator<'context> {
.collect()
}

pub fn resolve_trait_constraint(
/// Expands any traits in a where clause to mention all associated types if they were
/// elided by the user. See `add_missing_named_generics` for more detail.
///
/// Returns all newly created generics to be added to this function/trait/impl.
fn desugar_trait_constraints(
&mut self,
where_clause: &mut [UnresolvedTraitConstraint],
) -> Vec<ResolvedGeneric> {
where_clause
.iter_mut()
.flat_map(|constraint| self.add_missing_named_generics(&mut constraint.trait_bound))
.collect()
}

/// For each associated type that isn't mentioned in a trait bound, this adds
/// the type as an implicit generic to the where clause and returns the newly
/// created generics in a vector to add to the function/trait/impl later.
/// For example, this will turn a function using a trait with 2 associated types:
///
/// `fn foo<T>() where T: Foo { ... }`
///
/// into:
/// `fn foo<T>() where T: Foo<Bar = A, Baz = B> { ... }`
///
/// with a vector of `<A, B>` returned so that the caller can then modify the function to:
/// `fn foo<T, A, B>() where T: Foo<Bar = A, Baz = B> { ... }`
fn add_missing_named_generics(&mut self, bound: &mut TraitBound) -> Vec<ResolvedGeneric> {
let mut added_generics = Vec::new();

let Ok(item) = self.resolve_path_or_error(bound.trait_path.clone()) else {
return Vec::new();
};

let PathResolutionItem::Trait(trait_id) = item else {
return Vec::new();
};

let the_trait = self.get_trait_mut(trait_id);

if the_trait.associated_types.len() > bound.trait_generics.named_args.len() {
for associated_type in &the_trait.associated_types.clone() {
if !bound
.trait_generics
.named_args
.iter()
.any(|(name, _)| name.0.contents == *associated_type.name.as_ref())
{
// This generic isn't contained in the bound's named arguments,
// so add it by creating a fresh type variable.
let new_generic_id = self.interner.next_type_variable_id();
let type_var = TypeVariable::unbound(new_generic_id, Kind::Normal);

let span = bound.trait_path.span;
let name = associated_type.name.clone();
let typ = Type::NamedGeneric(type_var.clone(), name.clone());
let typ = self.interner.push_quoted_type(typ);
let typ = UnresolvedTypeData::Resolved(typ).with_span(span);
let ident = Ident::new(associated_type.name.as_ref().clone(), span);

bound.trait_generics.named_args.push((ident, typ));
added_generics.push(ResolvedGeneric { name, span, type_var });
}
}
}

added_generics
}

/// Resolves a trait constraint and adds it to scope as an assumed impl.
/// This second step is necessary to resolve subsequent constraints such
/// as `<T as Foo>::Bar: Eq` which may lookup an impl which was assumed
/// by a previous constraint.
fn resolve_trait_constraint(
&mut self,
constraint: &UnresolvedTraitConstraint,
) -> Option<TraitConstraint> {
let typ = self.resolve_type(constraint.typ.clone());
let trait_bound = self.resolve_trait_bound(&constraint.trait_bound)?;

self.add_trait_bound_to_scope(
constraint.trait_bound.trait_path.span,
&typ,
&trait_bound,
trait_bound.trait_id,
);

Some(TraitConstraint { typ, trait_bound })
}

Expand Down Expand Up @@ -800,10 +886,13 @@ impl<'context> Elaborator<'context> {
let has_inline_attribute = has_no_predicates_attribute || should_fold;
let is_pub_allowed = self.pub_allowed(func, in_contract);
self.add_generics(&func.def.generics);
let mut generics = vecmap(&self.generics, |generic| generic.type_var.clone());

let new_generics = self.desugar_trait_constraints(&mut func.def.where_clause);
generics.extend(new_generics.into_iter().map(|generic| generic.type_var));

let mut trait_constraints = self.resolve_trait_constraints(&func.def.where_clause);

let mut generics = vecmap(&self.generics, |generic| generic.type_var.clone());
let mut parameters = Vec::new();
let mut parameter_types = Vec::new();
let mut parameter_idents = Vec::new();
Expand Down Expand Up @@ -874,6 +963,9 @@ impl<'context> Elaborator<'context> {
None
};

// Remove the traits assumed by `resolve_trait_constraints` from scope
self.remove_trait_constraints_from_scope(&trait_constraints);

let meta = FuncMeta {
name: name_ident,
kind: func.kind,
Expand Down Expand Up @@ -1013,10 +1105,10 @@ impl<'context> Elaborator<'context> {
}
}

fn add_trait_constraints_to_scope(&mut self, func_meta: &FuncMeta) {
for constraint in &func_meta.trait_constraints {
fn add_trait_constraints_to_scope(&mut self, constraints: &[TraitConstraint], span: Span) {
for constraint in constraints {
self.add_trait_bound_to_scope(
func_meta,
span,
&constraint.typ,
&constraint.trait_bound,
constraint.trait_bound.trait_id,
Expand All @@ -1030,16 +1122,16 @@ impl<'context> Elaborator<'context> {
let self_type =
self.self_type.clone().expect("Expected a self type if there's a current trait");
self.add_trait_bound_to_scope(
func_meta,
span,
&self_type,
&constraint.trait_bound,
constraint.trait_bound.trait_id,
);
}
}

fn remove_trait_constraints_from_scope(&mut self, func_meta: &FuncMeta) {
for constraint in &func_meta.trait_constraints {
fn remove_trait_constraints_from_scope(&mut self, constraints: &[TraitConstraint]) {
for constraint in constraints {
self.interner
.remove_assumed_trait_implementations_for_trait(constraint.trait_bound.trait_id);
}
Expand All @@ -1052,7 +1144,7 @@ impl<'context> Elaborator<'context> {

fn add_trait_bound_to_scope(
&mut self,
func_meta: &FuncMeta,
span: Span,
object: &Type,
trait_bound: &ResolvedTraitBound,
starting_trait_id: TraitId,
Expand All @@ -1064,7 +1156,6 @@ impl<'context> Elaborator<'context> {
if let Some(the_trait) = self.interner.try_get_trait(trait_id) {
let trait_name = the_trait.name.to_string();
let typ = object.clone();
let span = func_meta.location.span;
self.push_err(TypeCheckError::UnneededTraitConstraint { trait_name, typ, span });
}
}
Expand All @@ -1081,12 +1172,7 @@ impl<'context> Elaborator<'context> {

let parent_trait_bound =
self.instantiate_parent_trait_bound(trait_bound, &parent_trait_bound);
self.add_trait_bound_to_scope(
func_meta,
object,
&parent_trait_bound,
starting_trait_id,
);
self.add_trait_bound_to_scope(span, object, &parent_trait_bound, starting_trait_id);
}
}
}
Expand Down Expand Up @@ -1316,6 +1402,7 @@ impl<'context> Elaborator<'context> {
self.generics = trait_impl.resolved_generics.clone();

let where_clause = self.resolve_trait_constraints(&trait_impl.where_clause);
self.remove_trait_constraints_from_scope(&where_clause);

self.collect_trait_impl_methods(trait_id, trait_impl, &where_clause);

Expand Down Expand Up @@ -1811,6 +1898,17 @@ impl<'context> Elaborator<'context> {
self.add_generics(&trait_impl.generics);
trait_impl.resolved_generics = self.generics.clone();

let new_generics = self.desugar_trait_constraints(&mut trait_impl.where_clause);
for new_generic in new_generics {
trait_impl.resolved_generics.push(new_generic.clone());
self.generics.push(new_generic);
}

// We need to resolve the where clause before any associated types to be
// able to resolve trait as type syntax, eg. `<T as Foo>` in case there
// is a where constraint for `T: Foo`.
let constraints = self.resolve_trait_constraints(&trait_impl.where_clause);

for (_, _, method) in trait_impl.methods.functions.iter_mut() {
// Attach any trait constraints on the impl to the function
method.def.where_clause.append(&mut trait_impl.where_clause.clone());
Expand All @@ -1823,17 +1921,20 @@ impl<'context> Elaborator<'context> {
let impl_id = self.interner.next_trait_impl_id();
self.current_trait_impl = Some(impl_id);

// Fetch trait constraints here
let path_span = trait_impl.trait_path.span;
let (ordered_generics, named_generics) = trait_impl
.trait_id
.map(|trait_id| {
self.resolve_type_args(trait_generics, trait_id, trait_impl.trait_path.span)
// Check for missing generics & associated types for the trait being implemented
self.resolve_trait_args_from_trait_impl(trait_generics, trait_id, path_span)
})
.unwrap_or_default();

trait_impl.resolved_trait_generics = ordered_generics;
self.interner.set_associated_types_for_impl(impl_id, named_generics);

self.remove_trait_constraints_from_scope(&constraints);

let self_type = self.resolve_type(unresolved_type);
self.self_type = Some(self_type.clone());
trait_impl.methods.self_type = Some(self_type);
Expand Down
7 changes: 6 additions & 1 deletion compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
use super::Elaborator;

impl<'context> Elaborator<'context> {
pub fn collect_traits(&mut self, traits: &BTreeMap<TraitId, UnresolvedTrait>) {
pub fn collect_traits(&mut self, traits: &mut BTreeMap<TraitId, UnresolvedTrait>) {
for (trait_id, unresolved_trait) in traits {
self.local_module = unresolved_trait.module_id;

Expand All @@ -39,8 +39,13 @@ impl<'context> Elaborator<'context> {
&resolved_generics,
);

let new_generics =
this.desugar_trait_constraints(&mut unresolved_trait.trait_def.where_clause);
this.generics.extend(new_generics);

let where_clause =
this.resolve_trait_constraints(&unresolved_trait.trait_def.where_clause);
this.remove_trait_constraints_from_scope(&where_clause);

// Each associated type in this trait is also an implicit generic
for associated_type in &this.interner.get_trait(*trait_id).associated_types {
Expand Down
Loading

0 comments on commit aa7b91c

Please sign in to comment.