diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index aef0771c486..d4da297167b 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -4,7 +4,14 @@ use std::{ }; use crate::{ - ast::ItemVisibility, hir_def::traits::ResolvedTraitBound, StructField, StructType, TypeBindings, + ast::ItemVisibility, + hir_def::{ + expr::{HirBlockExpression, HirExpression}, + stmt::HirStatement, + traits::ResolvedTraitBound, + }, + node_interner::DefinitionId, + StructField, StructType, TypeBindings, }; use crate::{ ast::{ @@ -470,6 +477,16 @@ impl<'context> Elaborator<'context> { self.check_for_unused_variables_in_scope_tree(func_scope_tree); } + // Check that the body can return without calling the function. + if let FunctionKind::Normal | FunctionKind::Recursive = kind { + self.check_for_unbounded_recursion( + id, + self.interner.definition_name(func_meta.name.id).to_string(), + func_meta.name.location.span, + hir_func.as_expr(), + ); + } + let meta = self .interner .func_meta @@ -1692,4 +1709,78 @@ impl<'context> Elaborator<'context> { _ => true, }) } + + /// Check that a recursive function *can* return without endlessly calling itself. + fn check_for_unbounded_recursion( + &mut self, + func_id: FuncId, + func_name: String, + func_span: Span, + body_id: ExprId, + ) { + if !self.can_return_without_recursing(func_id, body_id) { + self.push_err(CompilationError::ResolverError(ResolverError::UnconditionalRecursion { + name: func_name, + span: func_span, + })); + } + } + + /// Check if an expression will end up calling a specific function. + fn can_return_without_recursing(&self, func_id: FuncId, expr_id: ExprId) -> bool { + let check = |e| self.can_return_without_recursing(func_id, e); + + let check_block = |block: HirBlockExpression| { + block.statements.iter().all(|stmt_id| match self.interner.statement(stmt_id) { + HirStatement::Let(s) => check(s.expression), + HirStatement::Assign(s) => check(s.expression), + HirStatement::Expression(e) => check(e), + HirStatement::Semi(e) => check(e), + // Rust doesn't seem to check the for loop body. + HirStatement::For(e) => check(e.start_range) && check(e.end_range), + HirStatement::Constrain(_) + | HirStatement::Comptime(_) + | HirStatement::Break + | HirStatement::Continue + | HirStatement::Error => true, + }) + }; + + match self.interner.expression(&expr_id) { + HirExpression::Ident(ident, _) => { + if ident.id == DefinitionId::dummy_id() { + return true; + } + let definition = self.interner.definition(ident.id); + if let DefinitionKind::Function(id) = definition.kind { + func_id != id + } else { + true + } + } + HirExpression::Block(b) => check_block(b), + HirExpression::Prefix(e) => check(e.rhs), + HirExpression::Infix(e) => check(e.lhs) && check(e.rhs), + HirExpression::Index(e) => check(e.collection) && check(e.index), + HirExpression::MemberAccess(e) => check(e.lhs), + HirExpression::Call(e) => check(e.func) && e.arguments.iter().cloned().all(check), + HirExpression::MethodCall(e) => { + check(e.object) && e.arguments.iter().cloned().all(check) + } + HirExpression::Cast(e) => check(e.lhs), + HirExpression::If(e) => { + check(e.condition) + && (check(e.consequence) || e.alternative.map(check).unwrap_or(true)) + } + HirExpression::Tuple(e) => e.iter().cloned().all(check), + HirExpression::Lambda(e) => check(e.body), + HirExpression::Unsafe(b) => check_block(b), + HirExpression::Literal(_) + | HirExpression::Constructor(_) + | HirExpression::Quote(_) + | HirExpression::Unquote(_) + | HirExpression::Comptime(_) + | HirExpression::Error => true, + } + } } diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index 4f9907d6a16..3c4022b58bb 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -29,6 +29,8 @@ pub enum ResolverError { UnusedVariable { ident: Ident }, #[error("Unused {}", item.item_type())] UnusedItem { ident: Ident, item: UnusedItem }, + #[error("Unconditional recursion")] + UnconditionalRecursion { name: String, span: Span }, #[error("Could not find variable in this scope")] VariableNotDeclared { name: String, span: Span }, #[error("path is not an identifier")] @@ -213,6 +215,15 @@ impl<'a> From<&'a ResolverError> for Diagnostic { diagnostic.unnecessary = true; diagnostic } + ResolverError::UnconditionalRecursion { name, span} => { + let mut diagnostic = Diagnostic::simple_warning( + format!("function `{name}` cannot return without recursing"), + "function cannot return without recursing".to_string(), + *span, + ); + diagnostic.unnecessary = true; + diagnostic + } ResolverError::VariableNotDeclared { name, span } => Diagnostic::simple_error( format!("cannot find `{name}` in this scope "), "not found in this scope".to_string(), diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index b2800717d90..0d00caf220c 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3389,3 +3389,126 @@ fn arithmetic_generics_rounding_fail() { let errors = get_program_errors(src); assert_eq!(errors.len(), 1); } + +#[test] +fn unconditional_recursion_fail() { + let srcs = vec![ + r#" + fn main() { + main() + } + "#, + r#" + fn main() -> pub bool { + if main() { true } else { false } + } + "#, + r#" + fn main() -> pub bool { + if true { main() } else { main() } + } + "#, + r#" + fn main() -> pub u64 { + main() + main() + } + "#, + r#" + fn main() -> pub u64 { + 1 + main() + } + "#, + r#" + fn main() -> pub bool { + let _ = main(); + true + } + "#, + r#" + fn main(a: u64, b: u64) -> pub u64 { + main(a + b, main(a, b)) + } + "#, + r#" + fn main() -> pub u64 { + foo(1, main()) + } + fn foo(a: u64, b: u64) -> u64 { + a + b + } + "#, + r#" + fn main() -> pub u64 { + let (a, b) = (main(), main()); + a + b + } + "#, + r#" + fn main() -> pub u64 { + let mut sum = 0; + for i in 0 .. main() { + sum += i; + } + sum + } + "#, + ]; + + for src in srcs { + let errors = get_program_errors(src); + assert!( + !errors.is_empty(), + "expected 'unconditional recursion' error, got nothing; src = {src}" + ); + + for (error, _) in errors { + let CompilationError::ResolverError(ResolverError::UnconditionalRecursion { .. }) = + error + else { + panic!("Expected an 'unconditional recursion' error, got {:?}; src = {src}", error); + }; + } + } +} + +#[test] +fn unconditional_recursion_pass() { + let srcs = vec![ + r#" + fn main() { + if false { main(); } + } + "#, + r#" + fn main(i: u64) -> pub u64 { + if i == 0 { 0 } else { i + main(i-1) } + } + "#, + // Only immediate self-recursion is detected. + r#" + fn main() { + foo(); + } + fn foo() { + bar(); + } + fn bar() { + foo(); + } + "#, + // For loop bodies are not checked. + r#" + fn main() -> pub u64 { + let mut sum = 0; + for _ in 0 .. 10 { + sum += main(); + } + sum + } + "#, + ]; + + for src in srcs { + assert_no_errors(src); + } +} diff --git a/noir_stdlib/src/meta/mod.nr b/noir_stdlib/src/meta/mod.nr index f756be364b1..d299fd7d9c0 100644 --- a/noir_stdlib/src/meta/mod.nr +++ b/noir_stdlib/src/meta/mod.nr @@ -256,6 +256,8 @@ mod tests { let _: MyOtherStruct = MyOtherStruct { my_other_field: 2 }; let _ = derive_do_nothing(crate::panic::panic(f"")); let _ = derive_do_nothing_alt(crate::panic::panic(f"")); - remove_unused_warnings(); + if false { + remove_unused_warnings(); + } } }