Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: trait inheritance #6252

Merged
merged 29 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f2bd51f
Parse trait inheritance
asterite Oct 8, 2024
be1ae07
Make sure nargo_fmt doesn't break with trait inheritance
asterite Oct 8, 2024
19016be
Resolve trait parent bounds
asterite Oct 8, 2024
1bdbcf4
Embed ResolvedTraitBound in TraitConstraint
asterite Oct 8, 2024
3937bd0
Solve trait bounds and lookup methods in parent types
asterite Oct 8, 2024
95fa2c6
Use filter_map
asterite Oct 9, 2024
3b3565e
Merge branch 'master' into ab/trait-inheritance
asterite Oct 9, 2024
c65f865
Add assumed trait implementations for parent traits
asterite Oct 9, 2024
57a538f
Detect trait cycles
asterite Oct 9, 2024
dc50605
Avoid looping forever in case there are cycles
asterite Oct 9, 2024
4b65f50
Add a failing test for trait parents with generics
asterite Oct 9, 2024
436f0be
Add a failing test for when a trait impl is missing a parent implemen…
asterite Oct 9, 2024
38dd4e7
Merge branch 'master' into ab/trait-inheritance
asterite Oct 9, 2024
c32803e
Make it work with generics
asterite Oct 9, 2024
3f8ae2d
Refactor
asterite Oct 9, 2024
9153f27
Check parent traits are implemented
asterite Oct 10, 2024
810b87f
Add a test program
asterite Oct 10, 2024
6701a42
Merge branch 'master' into ab/trait-inheritance
asterite Oct 10, 2024
3278cc6
Correct way to format trait name with generics
asterite Oct 10, 2024
5d79bd5
Correct parent check regarding generics
asterite Oct 10, 2024
7b34a78
Add one more test
asterite Oct 10, 2024
fdce653
Improve missing trait bound error message
asterite Oct 10, 2024
e7ea16f
Add some docs
asterite Oct 10, 2024
8c5f3f4
Simpler way to construct a Diagnostic
asterite Oct 10, 2024
81b7f3b
Apply suggestions from code review
asterite Oct 10, 2024
f69a6ab
Add supertrait and supertraits to cspell
asterite Oct 10, 2024
216fc2e
Extract `bind_ordered_generics`
asterite Oct 10, 2024
566fce7
Remove more duplicated code
asterite Oct 10, 2024
5ca1dee
Add subtrait to cspell
asterite Oct 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion compiler/noirc_frontend/src/ast/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use super::{Documented, GenericTypeArgs, ItemVisibility};
pub struct NoirTrait {
pub name: Ident,
pub generics: UnresolvedGenerics,
pub bounds: Vec<TraitBound>,
pub where_clause: Vec<UnresolvedTraitConstraint>,
pub span: Span,
pub items: Vec<Documented<TraitItem>>,
Expand Down Expand Up @@ -134,7 +135,12 @@ impl Display for NoirTrait {
let generics = vecmap(&self.generics, |generic| generic.to_string());
let generics = if generics.is_empty() { "".into() } else { generics.join(", ") };

writeln!(f, "trait {}{} {{", self.name, generics)?;
write!(f, "trait {}{}", self.name, generics)?;
if !self.bounds.is_empty() {
let bounds = vecmap(&self.bounds, |bound| bound.to_string()).join(" + ");
write!(f, ": {}", bounds)?;
}
writeln!(f, " {{")?;

for item in self.items.iter() {
let item = item.to_string();
Expand Down
10 changes: 6 additions & 4 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::{
HirPrefixExpression,
},
stmt::HirStatement,
traits::TraitConstraint,
traits::{ResolvedTraitBound, TraitConstraint},
},
node_interner::{DefinitionKind, ExprId, FuncId, InternedStatementKind, TraitMethodId},
token::Tokens,
Expand Down Expand Up @@ -743,9 +743,11 @@ impl<'context> Elaborator<'context> {
// that implements the trait.
let constraint = TraitConstraint {
typ: operand_type.clone(),
trait_id: trait_id.trait_id,
trait_generics: TraitGenerics::default(),
span,
trait_bound: ResolvedTraitBound {
trait_id: trait_id.trait_id,
trait_generics: TraitGenerics::default(),
span,
},
};
self.push_trait_constraint(constraint, expr_id);
self.type_check_operator_method(expr_id, trait_id, operand_type, span);
Expand Down
154 changes: 127 additions & 27 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
rc::Rc,
};

use crate::{ast::ItemVisibility, StructField};
use crate::{ast::ItemVisibility, hir_def::traits::ResolvedTraitBound, StructField, TypeBindings};
use crate::{
ast::{
BlockExpression, FunctionKind, GenericTypeArgs, Ident, NoirFunction, NoirStruct, Param,
Expand Down Expand Up @@ -54,6 +54,7 @@ mod unquote;
use fm::FileId;
use iter_extended::vecmap;
use noirc_errors::{Location, Span};
use types::bind_ordered_generics;

use self::traits::check_trait_impl_method_matches_declaration;

Expand Down Expand Up @@ -433,7 +434,8 @@ impl<'context> Elaborator<'context> {

// Now remove all the `where` clause constraints we added
for constraint in &func_meta.trait_constraints {
self.interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id);
self.interner
.remove_assumed_trait_implementations_for_trait(constraint.trait_bound.trait_id);
}

let func_scope_tree = self.scopes.end_function();
Expand Down Expand Up @@ -479,9 +481,9 @@ impl<'context> Elaborator<'context> {

self.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics.ordered,
&constraint.trait_generics.named,
constraint.trait_bound.trait_id,
&constraint.trait_bound.trait_generics.ordered,
&constraint.trait_bound.trait_generics.named,
expr_id,
span,
);
Expand Down Expand Up @@ -510,7 +512,8 @@ impl<'context> Elaborator<'context> {
let generic_type = Type::NamedGeneric(new_generic, Rc::new(name));
let trait_bound = TraitBound { trait_path, trait_id: None, trait_generics };

if let Some(new_constraint) = self.resolve_trait_bound(&trait_bound, generic_type.clone()) {
if let Some(trait_bound) = self.resolve_trait_bound(&trait_bound) {
let new_constraint = TraitConstraint { typ: generic_type.clone(), trait_bound };
trait_constraints.push(new_constraint);
}

Expand Down Expand Up @@ -668,22 +671,19 @@ impl<'context> Elaborator<'context> {
constraint: &UnresolvedTraitConstraint,
) -> Option<TraitConstraint> {
let typ = self.resolve_type(constraint.typ.clone());
self.resolve_trait_bound(&constraint.trait_bound, typ)
let trait_bound = self.resolve_trait_bound(&constraint.trait_bound)?;
Some(TraitConstraint { typ, trait_bound })
}

pub fn resolve_trait_bound(
&mut self,
bound: &TraitBound,
typ: Type,
) -> Option<TraitConstraint> {
pub fn resolve_trait_bound(&mut self, bound: &TraitBound) -> Option<ResolvedTraitBound> {
let the_trait = self.lookup_trait_or_error(bound.trait_path.clone())?;
let trait_id = the_trait.id;
let span = bound.trait_path.span;

let (ordered, named) = self.resolve_type_args(bound.trait_generics.clone(), trait_id, span);

let trait_generics = TraitGenerics { ordered, named };
Some(TraitConstraint { typ, trait_id, trait_generics, span })
Some(ResolvedTraitBound { trait_id, trait_generics, span })
}

/// Extract metadata from a NoirFunction
Expand Down Expand Up @@ -942,21 +942,52 @@ impl<'context> Elaborator<'context> {

fn add_trait_constraints_to_scope(&mut self, func_meta: &FuncMeta) {
for constraint in &func_meta.trait_constraints {
let object = constraint.typ.clone();
let trait_id = constraint.trait_id;
let generics = constraint.trait_generics.clone();

if !self.interner.add_assumed_trait_implementation(object, trait_id, generics) {
if let Some(the_trait) = self.interner.try_get_trait(trait_id) {
let trait_name = the_trait.name.to_string();
let typ = constraint.typ.clone();
let span = func_meta.location.span;
self.push_err(TypeCheckError::UnneededTraitConstraint {
trait_name,
typ,
span,
});
self.add_trait_bound_to_scope(
func_meta,
&constraint.typ,
&constraint.trait_bound,
constraint.trait_bound.trait_id,
);
}
}

fn add_trait_bound_to_scope(
&mut self,
func_meta: &FuncMeta,
object: &Type,
trait_bound: &ResolvedTraitBound,
starting_trait_id: TraitId,
) {
let trait_id = trait_bound.trait_id;
let generics = trait_bound.trait_generics.clone();

if !self.interner.add_assumed_trait_implementation(object.clone(), trait_id, generics) {
if let Some(the_trait) = self.interner.try_get_trait(trait_id) {
let trait_name = the_trait.name.to_string();
let typ = object.clone();
let span = func_meta.location.span;
self.push_err(TypeCheckError::UnneededTraitConstraint { trait_name, typ, span });
}
}

// Also add assumed implementations for the parent traits, if any
if let Some(trait_bounds) =
self.interner.try_get_trait(trait_id).map(|the_trait| the_trait.trait_bounds.clone())
{
for parent_trait_bound in trait_bounds {
// Avoid looping forever in case there are cycles
if parent_trait_bound.trait_id == starting_trait_id {
continue;
}

let parent_trait_bound =
self.instantiate_parent_trait_bound(trait_bound, &parent_trait_bound);
self.add_trait_bound_to_scope(
func_meta,
object,
&parent_trait_bound,
starting_trait_id,
);
}
}
}
Expand All @@ -972,6 +1003,8 @@ impl<'context> Elaborator<'context> {
self.file = trait_impl.file_id;
self.local_module = trait_impl.module_id;

self.check_parent_traits_are_implemented(&trait_impl);

self.generics = trait_impl.resolved_generics;
self.current_trait_impl = trait_impl.impl_id;

Expand All @@ -988,6 +1021,73 @@ impl<'context> Elaborator<'context> {
self.generics.clear();
}

fn check_parent_traits_are_implemented(&mut self, trait_impl: &UnresolvedTraitImpl) {
let Some(trait_id) = trait_impl.trait_id else {
return;
};

let Some(object_type) = &trait_impl.resolved_object_type else {
return;
};

let Some(the_trait) = self.interner.try_get_trait(trait_id) else {
return;
};

if the_trait.trait_bounds.is_empty() {
return;
}

let impl_trait = the_trait.name.to_string();
let the_trait_file = the_trait.location.file;

let mut bindings = TypeBindings::new();
bind_ordered_generics(
&the_trait.generics,
&trait_impl.resolved_trait_generics,
&mut bindings,
);

// Note: we only check if the immediate parents are implemented, we don't check recursively.
// Why? If a parent isn't implemented, we get an error. If a parent is implemented, we'll
// do the same check for the parent, so this trait's parents parents will be checked, so the
// recursion is guaranteed.
for parent_trait_bound in the_trait.trait_bounds.clone() {
let Some(parent_trait) = self.interner.try_get_trait(parent_trait_bound.trait_id)
else {
continue;
};

let parent_trait_bound = ResolvedTraitBound {
trait_generics: parent_trait_bound
.trait_generics
.map(|typ| typ.substitute(&bindings)),
..parent_trait_bound
};

if self
.interner
.try_lookup_trait_implementation(
object_type,
parent_trait_bound.trait_id,
&parent_trait_bound.trait_generics.ordered,
&parent_trait_bound.trait_generics.named,
)
.is_err()
asterite marked this conversation as resolved.
Show resolved Hide resolved
{
let missing_trait =
format!("{}{}", parent_trait.name, parent_trait_bound.trait_generics);
self.push_err(ResolverError::TraitNotImplemented {
impl_trait: impl_trait.clone(),
missing_trait,
type_missing_trait: trait_impl.object_type.to_string(),
span: trait_impl.object_type.span,
missing_trait_location: Location::new(parent_trait_bound.span, the_trait_file),
});
}
}
}

fn collect_impls(
&mut self,
module: LocalModuleId,
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ impl<'context> Elaborator<'context> {
if let ImplKind::TraitMethod(mut method) = ident.impl_kind {
method.constraint.apply_bindings(&bindings);
if method.assumed {
let trait_generics = method.constraint.trait_generics.clone();
let trait_generics = method.constraint.trait_bound.trait_generics.clone();
let object_type = method.constraint.typ;
let trait_impl = TraitImplKind::Assumed { object_type, trait_generics };
self.interner.select_impl_for_expression(expr_id, trait_impl);
Expand Down Expand Up @@ -748,7 +748,7 @@ impl<'context> Elaborator<'context> {
HirMethodReference::TraitMethodId(method_id, generics) => {
let mut constraint =
self.interner.get_trait(method_id.trait_id).as_constraint(span);
constraint.trait_generics = generics;
constraint.trait_bound.trait_generics = generics;
ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed: false })
}
};
Expand Down
22 changes: 13 additions & 9 deletions compiler/noirc_frontend/src/elaborator/trait_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,36 +167,40 @@ impl<'context> Elaborator<'context> {
let mut substituted_method_ids = HashSet::default();
for method_constraint in method.trait_constraints.iter() {
let substituted_constraint_type = method_constraint.typ.substitute(&bindings);
let substituted_trait_generics =
method_constraint.trait_generics.map(|generic| generic.substitute(&bindings));
let substituted_trait_generics = method_constraint
.trait_bound
.trait_generics
.map(|generic| generic.substitute(&bindings));

substituted_method_ids.insert((
substituted_constraint_type,
method_constraint.trait_id,
method_constraint.trait_bound.trait_id,
substituted_trait_generics,
));
}

for override_trait_constraint in override_meta.trait_constraints.clone() {
let override_constraint_is_from_impl =
trait_impl_where_clause.iter().any(|impl_constraint| {
impl_constraint.trait_id == override_trait_constraint.trait_id
impl_constraint.trait_bound.trait_id
== override_trait_constraint.trait_bound.trait_id
});
if override_constraint_is_from_impl {
continue;
}

if !substituted_method_ids.contains(&(
override_trait_constraint.typ.clone(),
override_trait_constraint.trait_id,
override_trait_constraint.trait_generics.clone(),
override_trait_constraint.trait_bound.trait_id,
override_trait_constraint.trait_bound.trait_generics.clone(),
)) {
let the_trait = self.interner.get_trait(override_trait_constraint.trait_id);
let the_trait =
self.interner.get_trait(override_trait_constraint.trait_bound.trait_id);
self.push_err(DefCollectorErrorKind::ImplIsStricterThanTrait {
constraint_typ: override_trait_constraint.typ,
constraint_name: the_trait.name.0.contents.clone(),
constraint_generics: override_trait_constraint.trait_generics,
constraint_span: override_trait_constraint.span,
constraint_generics: override_trait_constraint.trait_bound.trait_generics,
constraint_span: override_trait_constraint.trait_bound.span,
trait_method_name: method.name.0.contents.clone(),
trait_method_span: method.location.span,
});
Expand Down
22 changes: 20 additions & 2 deletions compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ use crate::{
UnresolvedTraitConstraint, UnresolvedType,
},
hir::{def_collector::dc_crate::UnresolvedTrait, type_check::TypeCheckError},
hir_def::{function::Parameters, traits::TraitFunction},
node_interner::{FuncId, NodeInterner, ReferenceId, TraitId},
hir_def::{
function::Parameters,
traits::{ResolvedTraitBound, TraitFunction},
},
node_interner::{DependencyId, FuncId, NodeInterner, ReferenceId, TraitId},
ResolvedGeneric, Type, TypeBindings,
};

Expand All @@ -34,10 +37,17 @@ impl<'context> Elaborator<'context> {
this.generics.push(associated_type.clone());
}

let resolved_trait_bounds = this.resolve_trait_bounds(unresolved_trait);
for bound in &resolved_trait_bounds {
this.interner
.add_trait_dependency(DependencyId::Trait(bound.trait_id), *trait_id);
}

let methods = this.resolve_trait_methods(*trait_id, unresolved_trait);

this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_methods(methods);
trait_def.set_trait_bounds(resolved_trait_bounds);
});
});

Expand All @@ -53,6 +63,14 @@ impl<'context> Elaborator<'context> {
self.current_trait = None;
}

fn resolve_trait_bounds(
&mut self,
unresolved_trait: &UnresolvedTrait,
) -> Vec<ResolvedTraitBound> {
let bounds = &unresolved_trait.trait_def.bounds;
bounds.iter().filter_map(|bound| self.resolve_trait_bound(bound)).collect()
}

fn resolve_trait_methods(
&mut self,
trait_id: TraitId,
Expand Down
Loading
Loading