diff --git a/aztec_macros/src/transforms/functions.rs b/aztec_macros/src/transforms/functions.rs index 39c0ca344e6..6c8af308a52 100644 --- a/aztec_macros/src/transforms/functions.rs +++ b/aztec_macros/src/transforms/functions.rs @@ -316,14 +316,17 @@ fn create_static_check(fname: &str, is_private: bool) -> Statement { .iter() .fold(variable("context"), |acc, member| member_access(acc, member)) }; - make_statement(StatementKind::Constrain(ConstrainStatement( - make_eq(is_static_call_expr, expression(ExpressionKind::Literal(Literal::Bool(true)))), - Some(expression(ExpressionKind::Literal(Literal::Str(format!( - "Function {} can only be called statically", - fname - ))))), - ConstrainKind::Assert, - ))) + make_statement(StatementKind::Constrain(ConstrainStatement { + kind: ConstrainKind::Assert, + arguments: vec![ + make_eq(is_static_call_expr, expression(ExpressionKind::Literal(Literal::Bool(true)))), + expression(ExpressionKind::Literal(Literal::Str(format!( + "Function {} can only be called statically", + fname + )))), + ], + span: Default::default(), + })) } /// Creates a check for internal functions ensuring that the caller is self. @@ -332,17 +335,20 @@ fn create_static_check(fname: &str, is_private: bool) -> Statement { /// assert(context.msg_sender() == context.this_address(), "Function can only be called internally"); /// ``` fn create_internal_check(fname: &str) -> Statement { - make_statement(StatementKind::Constrain(ConstrainStatement( - make_eq( - method_call(variable("context"), "msg_sender", vec![]), - method_call(variable("context"), "this_address", vec![]), - ), - Some(expression(ExpressionKind::Literal(Literal::Str(format!( - "Function {} can only be called internally", - fname - ))))), - ConstrainKind::Assert, - ))) + make_statement(StatementKind::Constrain(ConstrainStatement { + kind: ConstrainKind::Assert, + arguments: vec![ + make_eq( + method_call(variable("context"), "msg_sender", vec![]), + method_call(variable("context"), "this_address", vec![]), + ), + expression(ExpressionKind::Literal(Literal::Str(format!( + "Function {} can only be called internally", + fname + )))), + ], + span: Default::default(), + })) } /// Creates a call to assert_initialization_matches_address_preimage to be inserted diff --git a/aztec_macros/src/utils/parse_utils.rs b/aztec_macros/src/utils/parse_utils.rs index efa31860b6e..61f54377284 100644 --- a/aztec_macros/src/utils/parse_utils.rs +++ b/aztec_macros/src/utils/parse_utils.rs @@ -243,10 +243,7 @@ fn empty_statement(statement: &mut Statement) { } fn empty_constrain_statement(constrain_statement: &mut ConstrainStatement) { - empty_expression(&mut constrain_statement.0); - if let Some(expression) = &mut constrain_statement.1 { - empty_expression(expression); - } + empty_expressions(&mut constrain_statement.arguments); } fn empty_expressions(expressions: &mut [Expression]) { diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index d67501e932b..49568d42038 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -562,7 +562,27 @@ pub enum LValue { } #[derive(Debug, PartialEq, Eq, Clone)] -pub struct ConstrainStatement(pub Expression, pub Option, pub ConstrainKind); +pub struct ConstrainStatement { + pub kind: ConstrainKind, + pub arguments: Vec, + pub span: Span, +} + +impl Display for ConstrainStatement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.kind { + ConstrainKind::Assert | ConstrainKind::AssertEq => write!( + f, + "{}({})", + self.kind, + vecmap(&self.arguments, |arg| arg.to_string()).join(", ") + ), + ConstrainKind::Constrain => { + write!(f, "constrain {}", &self.arguments[0]) + } + } + } +} #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum ConstrainKind { @@ -571,6 +591,25 @@ pub enum ConstrainKind { Constrain, } +impl ConstrainKind { + pub fn required_arguments_count(&self) -> usize { + match self { + ConstrainKind::Assert | ConstrainKind::Constrain => 1, + ConstrainKind::AssertEq => 2, + } + } +} + +impl Display for ConstrainKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ConstrainKind::Assert => write!(f, "assert"), + ConstrainKind::AssertEq => write!(f, "assert_eq"), + ConstrainKind::Constrain => write!(f, "constrain"), + } + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub enum Pattern { Identifier(Ident), @@ -885,12 +924,6 @@ impl Display for LetStatement { } } -impl Display for ConstrainStatement { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "constrain {}", self.0) - } -} - impl Display for AssignStatement { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{} = {}", self.lvalue, self.expression) diff --git a/compiler/noirc_frontend/src/ast/visitor.rs b/compiler/noirc_frontend/src/ast/visitor.rs index 9a2fb79ca88..fb116d34e8a 100644 --- a/compiler/noirc_frontend/src/ast/visitor.rs +++ b/compiler/noirc_frontend/src/ast/visitor.rs @@ -1117,11 +1117,7 @@ impl ConstrainStatement { } pub fn accept_children(&self, visitor: &mut impl Visitor) { - self.0.accept(visitor); - - if let Some(exp) = &self.1 { - exp.accept(visitor); - } + visit_expressions(&self.arguments, visitor); } } diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index 543cf20b647..2d46c4c6341 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -1,7 +1,10 @@ -use noirc_errors::{Location, Span}; +use noirc_errors::{Location, Span, Spanned}; use crate::{ - ast::{AssignStatement, ConstrainStatement, LValue}, + ast::{ + AssignStatement, BinaryOpKind, ConstrainKind, ConstrainStatement, Expression, + ExpressionKind, InfixExpression, LValue, + }, hir::{ resolution::errors::ResolverError, type_check::{Source, TypeCheckError}, @@ -110,12 +113,51 @@ impl<'context> Elaborator<'context> { (HirStatement::Let(let_), Type::Unit) } - pub(super) fn elaborate_constrain(&mut self, stmt: ConstrainStatement) -> (HirStatement, Type) { - let expr_span = stmt.0.span; - let (expr_id, expr_type) = self.elaborate_expression(stmt.0); + pub(super) fn elaborate_constrain( + &mut self, + mut stmt: ConstrainStatement, + ) -> (HirStatement, Type) { + let span = stmt.span; + let min_args_count = stmt.kind.required_arguments_count(); + let max_args_count = min_args_count + 1; + let actual_args_count = stmt.arguments.len(); + + let (message, expr) = if !(min_args_count..=max_args_count).contains(&actual_args_count) { + self.push_err(TypeCheckError::AssertionParameterCountMismatch { + kind: stmt.kind, + found: actual_args_count, + span, + }); + + // Given that we already produced an error, let's make this an `assert(true)` so + // we don't get further errors. + let message = None; + let kind = ExpressionKind::Literal(crate::ast::Literal::Bool(true)); + let expr = Expression { kind, span }; + (message, expr) + } else { + let message = + (actual_args_count != min_args_count).then(|| stmt.arguments.pop().unwrap()); + let expr = match stmt.kind { + ConstrainKind::Assert | ConstrainKind::Constrain => stmt.arguments.pop().unwrap(), + ConstrainKind::AssertEq => { + let rhs = stmt.arguments.pop().unwrap(); + let lhs = stmt.arguments.pop().unwrap(); + let span = Span::from(lhs.span.start()..rhs.span.end()); + let operator = Spanned::from(span, BinaryOpKind::Equal); + let kind = + ExpressionKind::Infix(Box::new(InfixExpression { lhs, operator, rhs })); + Expression { kind, span } + } + }; + (message, expr) + }; + + let expr_span = expr.span; + let (expr_id, expr_type) = self.elaborate_expression(expr); // Must type check the assertion message expression so that we instantiate bindings - let msg = stmt.1.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0); + let msg = message.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0); self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch { expr_typ: expr_type.to_string(), diff --git a/compiler/noirc_frontend/src/hir/comptime/display.rs b/compiler/noirc_frontend/src/hir/comptime/display.rs index 869e5517d6c..143e0450bac 100644 --- a/compiler/noirc_frontend/src/hir/comptime/display.rs +++ b/compiler/noirc_frontend/src/hir/comptime/display.rs @@ -683,11 +683,12 @@ fn remove_interned_in_statement_kind( r#type: remove_interned_in_unresolved_type(interner, let_statement.r#type), ..let_statement }), - StatementKind::Constrain(constrain) => StatementKind::Constrain(ConstrainStatement( - remove_interned_in_expression(interner, constrain.0), - constrain.1.map(|expr| remove_interned_in_expression(interner, expr)), - constrain.2, - )), + StatementKind::Constrain(constrain) => StatementKind::Constrain(ConstrainStatement { + arguments: vecmap(constrain.arguments, |expr| { + remove_interned_in_expression(interner, expr) + }), + ..constrain + }), StatementKind::Expression(expr) => { StatementKind::Expression(remove_interned_in_expression(interner, expr)) } diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs index 4a159c682b7..972826f5b7c 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs @@ -33,10 +33,17 @@ impl HirStatement { } HirStatement::Constrain(constrain) => { let expr = constrain.0.to_display_ast(interner); - let message = constrain.2.map(|message| message.to_display_ast(interner)); + let mut arguments = vec![expr]; + if let Some(message) = constrain.2 { + arguments.push(message.to_display_ast(interner)); + } // TODO: Find difference in usage between Assert & AssertEq - StatementKind::Constrain(ConstrainStatement(expr, message, ConstrainKind::Assert)) + StatementKind::Constrain(ConstrainStatement { + kind: ConstrainKind::Assert, + arguments, + span, + }) } HirStatement::Assign(assign) => StatementKind::Assign(AssignStatement { lvalue: assign.lvalue.to_display_ast(interner), diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index e9c8fbe99a5..4678d29a452 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -1240,9 +1240,17 @@ fn expr_as_assert( location: Location, ) -> IResult { expr_as(interner, arguments, return_type.clone(), location, |expr| { - if let ExprValue::Statement(StatementKind::Constrain(constrain)) = expr { - if constrain.2 == ConstrainKind::Assert { - let predicate = Value::expression(constrain.0.kind); + if let ExprValue::Statement(StatementKind::Constrain(mut constrain)) = expr { + if constrain.kind == ConstrainKind::Assert + && !constrain.arguments.is_empty() + && constrain.arguments.len() <= 2 + { + let (message, predicate) = if constrain.arguments.len() == 1 { + (None, constrain.arguments.pop().unwrap()) + } else { + (Some(constrain.arguments.pop().unwrap()), constrain.arguments.pop().unwrap()) + }; + let predicate = Value::expression(predicate.kind); let option_type = extract_option_generic_type(return_type); let Type::Tuple(mut tuple_types) = option_type else { @@ -1251,7 +1259,7 @@ fn expr_as_assert( assert_eq!(tuple_types.len(), 2); let option_type = tuple_types.pop().unwrap(); - let message = constrain.1.map(|message| Value::expression(message.kind)); + let message = message.map(|msg| Value::expression(msg.kind)); let message = option(option_type, message).ok()?; Some(Value::Tuple(vec![predicate, message])) @@ -1272,14 +1280,23 @@ fn expr_as_assert_eq( location: Location, ) -> IResult { expr_as(interner, arguments, return_type.clone(), location, |expr| { - if let ExprValue::Statement(StatementKind::Constrain(constrain)) = expr { - if constrain.2 == ConstrainKind::AssertEq { - let ExpressionKind::Infix(infix) = constrain.0.kind else { - panic!("Expected AssertEq constrain statement to have an infix expression"); + if let ExprValue::Statement(StatementKind::Constrain(mut constrain)) = expr { + if constrain.kind == ConstrainKind::AssertEq + && constrain.arguments.len() >= 2 + && constrain.arguments.len() <= 3 + { + let (message, rhs, lhs) = if constrain.arguments.len() == 2 { + (None, constrain.arguments.pop().unwrap(), constrain.arguments.pop().unwrap()) + } else { + ( + Some(constrain.arguments.pop().unwrap()), + constrain.arguments.pop().unwrap(), + constrain.arguments.pop().unwrap(), + ) }; - let lhs = Value::expression(infix.lhs.kind); - let rhs = Value::expression(infix.rhs.kind); + let lhs = Value::expression(lhs.kind); + let rhs = Value::expression(rhs.kind); let option_type = extract_option_generic_type(return_type); let Type::Tuple(mut tuple_types) = option_type else { @@ -1288,7 +1305,7 @@ fn expr_as_assert_eq( assert_eq!(tuple_types.len(), 3); let option_type = tuple_types.pop().unwrap(); - let message = constrain.1.map(|message| Value::expression(message.kind)); + let message = message.map(|message| Value::expression(message.kind)); let message = option(option_type, message).ok()?; Some(Value::Tuple(vec![lhs, rhs, message])) diff --git a/compiler/noirc_frontend/src/hir/type_check/errors.rs b/compiler/noirc_frontend/src/hir/type_check/errors.rs index 155564536d6..00e73e682e8 100644 --- a/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -5,6 +5,7 @@ use noirc_errors::CustomDiagnostic as Diagnostic; use noirc_errors::Span; use thiserror::Error; +use crate::ast::ConstrainKind; use crate::ast::{BinaryOpKind, FunctionReturnType, IntegerBitSize, Signedness}; use crate::hir::resolution::errors::ResolverError; use crate::hir_def::expr::HirBinaryOp; @@ -59,6 +60,8 @@ pub enum TypeCheckError { AccessUnknownMember { lhs_type: Type, field_name: String, span: Span }, #[error("Function expects {expected} parameters but {found} were given")] ParameterCountMismatch { expected: usize, found: usize, span: Span }, + #[error("{} expects {} or {} parameters but {found} were given", kind, kind.required_arguments_count(), kind.required_arguments_count() + 1)] + AssertionParameterCountMismatch { kind: ConstrainKind, found: usize, span: Span }, #[error("{item} expects {expected} generics but {found} were given")] GenericCountMismatch { item: String, expected: usize, found: usize, span: Span }, #[error("{item} has incompatible `unconstrained`")] @@ -260,6 +263,13 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic { let msg = format!("Function expects {expected} parameter{empty_or_s} but {found} {was_or_were} given"); Diagnostic::simple_error(msg, String::new(), *span) } + TypeCheckError::AssertionParameterCountMismatch { kind, found, span } => { + let was_or_were = if *found == 1 { "was" } else { "were" }; + let min = kind.required_arguments_count(); + let max = min + 1; + let msg = format!("{kind} expects {min} or {max} parameters but {found} {was_or_were} given"); + Diagnostic::simple_error(msg, String::new(), *span) + } TypeCheckError::GenericCountMismatch { item, expected, found, span } => { let empty_or_s = if *expected == 1 { "" } else { "s" }; let was_or_were = if *found == 1 { "was" } else { "were" }; diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 337563213e5..b007653062b 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -465,7 +465,6 @@ where choice(( assertion::constrain(expr_parser.clone()), assertion::assertion(expr_parser.clone()), - assertion::assertion_eq(expr_parser.clone()), declaration(expr_parser.clone()), assignment(expr_parser.clone()), if_statement(expr_no_constructors.clone(), statement.clone()), @@ -1629,17 +1628,13 @@ mod test { Case { source: "let", expect: "let $error = Error", errors: 3 }, Case { source: "foo = one two three", expect: "foo = one", errors: 1 }, Case { source: "constrain", expect: "constrain Error", errors: 2 }, - Case { source: "assert", expect: "constrain Error", errors: 1 }, + Case { source: "assert", expect: "assert()", errors: 1 }, Case { source: "constrain x ==", expect: "constrain (x == Error)", errors: 2 }, - Case { source: "assert(x ==)", expect: "constrain (x == Error)", errors: 1 }, - Case { source: "assert(x == x, x)", expect: "constrain (x == x)", errors: 0 }, - Case { source: "assert_eq(x,)", expect: "constrain (Error == Error)", errors: 1 }, - Case { - source: "assert_eq(x, x, x, x)", - expect: "constrain (Error == Error)", - errors: 1, - }, - Case { source: "assert_eq(x, x, x)", expect: "constrain (x == x)", errors: 0 }, + Case { source: "assert(x ==)", expect: "assert((x == Error))", errors: 1 }, + Case { source: "assert(x == x, x)", expect: "assert((x == x), x)", errors: 0 }, + Case { source: "assert_eq(x,)", expect: "assert_eq(x)", errors: 0 }, + Case { source: "assert_eq(x, x, x, x)", expect: "assert_eq(x, x, x, x)", errors: 0 }, + Case { source: "assert_eq(x, x, x)", expect: "assert_eq(x, x, x)", errors: 0 }, ]; check_cases_with_errors(&cases[..], fresh_statement()); diff --git a/compiler/noirc_frontend/src/parser/parser/assertion.rs b/compiler/noirc_frontend/src/parser/parser/assertion.rs index ed08a4c9922..9eb429ef295 100644 --- a/compiler/noirc_frontend/src/parser/parser/assertion.rs +++ b/compiler/noirc_frontend/src/parser/parser/assertion.rs @@ -1,14 +1,11 @@ -use crate::ast::{Expression, ExpressionKind, StatementKind}; -use crate::parser::{ - ignore_then_commit, labels::ParsingRuleLabel, parenthesized, ExprParser, NoirParser, - ParserError, ParserErrorReason, -}; +use crate::ast::StatementKind; +use crate::parser::{ignore_then_commit, then_commit, ParserError, ParserErrorReason}; +use crate::parser::{labels::ParsingRuleLabel, parenthesized, ExprParser, NoirParser}; -use crate::ast::{BinaryOpKind, ConstrainKind, ConstrainStatement, InfixExpression, Recoverable}; +use crate::ast::{ConstrainKind, ConstrainStatement}; use crate::token::{Keyword, Token}; use chumsky::prelude::*; -use noirc_errors::Spanned; use super::keyword; @@ -20,7 +17,13 @@ where keyword(Keyword::Constrain).labelled(ParsingRuleLabel::Statement), expr_parser, ) - .map(|expr| StatementKind::Constrain(ConstrainStatement(expr, None, ConstrainKind::Constrain))) + .map_with_span(|expr, span| { + StatementKind::Constrain(ConstrainStatement { + kind: ConstrainKind::Constrain, + arguments: vec![expr], + span, + }) + }) .validate(|expr, span, emit| { emit(ParserError::with_reason(ParserErrorReason::ConstrainDeprecated, span)); expr @@ -31,42 +34,17 @@ pub(super) fn assertion<'a, P>(expr_parser: P) -> impl NoirParser where P: ExprParser + 'a, { - let argument_parser = - expr_parser.separated_by(just(Token::Comma)).allow_trailing().at_least(1).at_most(2); - - ignore_then_commit(keyword(Keyword::Assert), parenthesized(argument_parser)) - .labelled(ParsingRuleLabel::Statement) - .validate(|expressions, span, _| { - let condition = expressions.first().unwrap_or(&Expression::error(span)).clone(); - let message = expressions.get(1).cloned(); - StatementKind::Constrain(ConstrainStatement(condition, message, ConstrainKind::Assert)) - }) -} + let keyword = choice(( + keyword(Keyword::Assert).map(|_| ConstrainKind::Assert), + keyword(Keyword::AssertEq).map(|_| ConstrainKind::AssertEq), + )); -pub(super) fn assertion_eq<'a, P>(expr_parser: P) -> impl NoirParser + 'a -where - P: ExprParser + 'a, -{ - let argument_parser = - expr_parser.separated_by(just(Token::Comma)).allow_trailing().at_least(2).at_most(3); + let argument_parser = expr_parser.separated_by(just(Token::Comma)).allow_trailing(); - ignore_then_commit(keyword(Keyword::AssertEq), parenthesized(argument_parser)) + then_commit(keyword, parenthesized(argument_parser)) .labelled(ParsingRuleLabel::Statement) - .validate(|exprs: Vec, span, _| { - let predicate = Expression::new( - ExpressionKind::Infix(Box::new(InfixExpression { - lhs: exprs.first().unwrap_or(&Expression::error(span)).clone(), - rhs: exprs.get(1).unwrap_or(&Expression::error(span)).clone(), - operator: Spanned::from(span, BinaryOpKind::Equal), - })), - span, - ); - let message = exprs.get(2).cloned(); - StatementKind::Constrain(ConstrainStatement( - predicate, - message, - ConstrainKind::AssertEq, - )) + .map_with_span(|(kind, arguments), span| { + StatementKind::Constrain(ConstrainStatement { arguments, kind, span }) }) } @@ -74,7 +52,7 @@ where mod test { use super::*; use crate::{ - ast::Literal, + ast::{BinaryOpKind, ExpressionKind, Literal}, parser::parser::{ expression, test_helpers::{parse_all, parse_all_failing, parse_with}, @@ -174,11 +152,11 @@ mod test { match parse_with(assertion(expression()), "assert(x == y, \"assertion message\")").unwrap() { - StatementKind::Constrain(ConstrainStatement(_, message, _)) => { - let message = message.unwrap(); - match message.kind { + StatementKind::Constrain(ConstrainStatement { arguments, .. }) => { + let message = arguments.last().unwrap(); + match &message.kind { ExpressionKind::Literal(Literal::Str(message_string)) => { - assert_eq!(message_string, "assertion message".to_owned()); + assert_eq!(message_string, "assertion message"); } _ => unreachable!(), } @@ -191,7 +169,7 @@ mod test { #[test] fn parse_assert_eq() { parse_all( - assertion_eq(expression()), + assertion(expression()), vec![ "assert_eq(x, y)", "assert_eq(((x + y) == k) + z, y)", @@ -201,14 +179,13 @@ mod test { "assert_eq(x + x ^ x, y | m)", ], ); - match parse_with(assertion_eq(expression()), "assert_eq(x, y, \"assertion message\")") - .unwrap() + match parse_with(assertion(expression()), "assert_eq(x, y, \"assertion message\")").unwrap() { - StatementKind::Constrain(ConstrainStatement(_, message, _)) => { - let message = message.unwrap(); - match message.kind { + StatementKind::Constrain(ConstrainStatement { arguments, .. }) => { + let message = arguments.last().unwrap(); + match &message.kind { ExpressionKind::Literal(Literal::Str(message_string)) => { - assert_eq!(message_string, "assertion message".to_owned()); + assert_eq!(message_string, "assertion message"); } _ => unreachable!(), } diff --git a/tooling/lsp/src/requests/signature_help.rs b/tooling/lsp/src/requests/signature_help.rs index 8f1bbb1570f..b075fea1d1e 100644 --- a/tooling/lsp/src/requests/signature_help.rs +++ b/tooling/lsp/src/requests/signature_help.rs @@ -8,8 +8,8 @@ use lsp_types::{ use noirc_errors::{Location, Span}; use noirc_frontend::{ ast::{ - CallExpression, ConstrainKind, ConstrainStatement, Expression, ExpressionKind, - FunctionReturnType, MethodCallExpression, Statement, Visitor, + CallExpression, ConstrainKind, ConstrainStatement, Expression, FunctionReturnType, + MethodCallExpression, Statement, Visitor, }, hir_def::{function::FuncMeta, stmt::HirPattern}, macros_api::NodeInterner, @@ -375,39 +375,24 @@ impl<'a> Visitor for SignatureFinder<'a> { return false; } - let arguments_span = if let Some(expr) = &constrain_statement.1 { - Span::from(constrain_statement.0.span.start()..expr.span.end()) - } else { - constrain_statement.0.span - }; + let kind_len = constrain_statement.kind.to_string().len() as u32; + let span = constrain_statement.span; + let arguments_span = Span::from(span.start() + kind_len + 1..span.end() - 1); if !self.includes_span(arguments_span) { return false; } - match constrain_statement.2 { - ConstrainKind::Assert => { - let mut arguments = vec![constrain_statement.0.clone()]; - if let Some(expr) = &constrain_statement.1 { - arguments.push(expr.clone()); - } + let active_parameter = self.compute_active_parameter(&constrain_statement.arguments); - let active_parameter = self.compute_active_parameter(&arguments); + match constrain_statement.kind { + ConstrainKind::Assert => { let signature_information = self.assert_signature_information(active_parameter); self.set_signature_help(signature_information); } ConstrainKind::AssertEq => { - if let ExpressionKind::Infix(infix) = &constrain_statement.0.kind { - let mut arguments = vec![infix.lhs.clone(), infix.rhs.clone()]; - if let Some(expr) = &constrain_statement.1 { - arguments.push(expr.clone()); - } - - let active_parameter = self.compute_active_parameter(&arguments); - let signature_information = - self.assert_eq_signature_information(active_parameter); - self.set_signature_help(signature_information); - } + let signature_information = self.assert_eq_signature_information(active_parameter); + self.set_signature_help(signature_information); } ConstrainKind::Constrain => (), } diff --git a/tooling/nargo_fmt/src/visitor/stmt.rs b/tooling/nargo_fmt/src/visitor/stmt.rs index b5ac14a33b3..8908aabd87c 100644 --- a/tooling/nargo_fmt/src/visitor/stmt.rs +++ b/tooling/nargo_fmt/src/visitor/stmt.rs @@ -2,9 +2,7 @@ use std::iter::zip; use noirc_frontend::macros_api::Span; -use noirc_frontend::ast::{ - ConstrainKind, ConstrainStatement, ExpressionKind, ForRange, Statement, StatementKind, -}; +use noirc_frontend::ast::{ConstrainKind, ConstrainStatement, ForRange, Statement, StatementKind}; use crate::{rewrite, visitor::expr::wrap_exprs}; @@ -38,37 +36,21 @@ impl super::FmtVisitor<'_> { self.push_rewrite(format!("{let_str} {expr_str};"), span); } - StatementKind::Constrain(ConstrainStatement(expr, message, kind)) => { + StatementKind::Constrain(ConstrainStatement { kind, arguments, span: _ }) => { let mut nested_shape = self.shape(); let shape = nested_shape; nested_shape.indent.block_indent(self.config); - let message = message.map_or(String::new(), |message| { - let message = rewrite::sub_expr(self, nested_shape, message); - format!(", {message}") - }); - - let (callee, args) = match kind { - ConstrainKind::Assert | ConstrainKind::Constrain => { - let assertion = rewrite::sub_expr(self, nested_shape, expr); - let args = format!("{assertion}{message}"); - - ("assert", args) - } - ConstrainKind::AssertEq => { - if let ExpressionKind::Infix(infix) = expr.kind { - let lhs = rewrite::sub_expr(self, nested_shape, infix.lhs); - let rhs = rewrite::sub_expr(self, nested_shape, infix.rhs); - - let args = format!("{lhs}, {rhs}{message}"); - - ("assert_eq", args) - } else { - unreachable!() - } - } + let callee = match kind { + ConstrainKind::Assert | ConstrainKind::Constrain => "assert", + ConstrainKind::AssertEq => "assert_eq", }; + let args = arguments + .into_iter() + .map(|arg| rewrite::sub_expr(self, nested_shape, arg)) + .collect::>() + .join(", "); let args = wrap_exprs( "(",