Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: represent assertions more similarly to function calls #6103

Merged
merged 12 commits into from
Sep 20, 2024
44 changes: 25 additions & 19 deletions aztec_macros/src/transforms/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,14 +316,17 @@ fn create_static_check(fname: &str, is_private: bool) -> Statement {
.iter()
.fold(variable("context"), |acc, member| member_access(acc, member))
};
make_statement(StatementKind::Constrain(ConstrainStatement(
make_eq(is_static_call_expr, expression(ExpressionKind::Literal(Literal::Bool(true)))),
Some(expression(ExpressionKind::Literal(Literal::Str(format!(
"Function {} can only be called statically",
fname
))))),
ConstrainKind::Assert,
)))
make_statement(StatementKind::Constrain(ConstrainStatement {
kind: ConstrainKind::Assert,
arguments: vec![
make_eq(is_static_call_expr, expression(ExpressionKind::Literal(Literal::Bool(true)))),
expression(ExpressionKind::Literal(Literal::Str(format!(
"Function {} can only be called statically",
fname
)))),
],
span: Default::default(),
}))
}

/// Creates a check for internal functions ensuring that the caller is self.
Expand All @@ -332,17 +335,20 @@ fn create_static_check(fname: &str, is_private: bool) -> Statement {
/// assert(context.msg_sender() == context.this_address(), "Function can only be called internally");
/// ```
fn create_internal_check(fname: &str) -> Statement {
make_statement(StatementKind::Constrain(ConstrainStatement(
make_eq(
method_call(variable("context"), "msg_sender", vec![]),
method_call(variable("context"), "this_address", vec![]),
),
Some(expression(ExpressionKind::Literal(Literal::Str(format!(
"Function {} can only be called internally",
fname
))))),
ConstrainKind::Assert,
)))
make_statement(StatementKind::Constrain(ConstrainStatement {
kind: ConstrainKind::Assert,
arguments: vec![
make_eq(
method_call(variable("context"), "msg_sender", vec![]),
method_call(variable("context"), "this_address", vec![]),
),
expression(ExpressionKind::Literal(Literal::Str(format!(
"Function {} can only be called internally",
fname
)))),
],
span: Default::default(),
}))
}

/// Creates a call to assert_initialization_matches_address_preimage to be inserted
Expand Down
5 changes: 1 addition & 4 deletions aztec_macros/src/utils/parse_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,7 @@ fn empty_statement(statement: &mut Statement) {
}

fn empty_constrain_statement(constrain_statement: &mut ConstrainStatement) {
empty_expression(&mut constrain_statement.0);
if let Some(expression) = &mut constrain_statement.1 {
empty_expression(expression);
}
empty_expressions(&mut constrain_statement.arguments);
}

fn empty_expressions(expressions: &mut [Expression]) {
Expand Down
38 changes: 31 additions & 7 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,27 @@ pub enum LValue {
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ConstrainStatement(pub Expression, pub Option<Expression>, pub ConstrainKind);
pub struct ConstrainStatement {
pub kind: ConstrainKind,
pub arguments: Vec<Expression>,
pub span: Span,
}

impl Display for ConstrainStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.kind {
ConstrainKind::Assert | ConstrainKind::AssertEq => write!(
f,
"{}({})",
self.kind,
vecmap(&self.arguments, |arg| arg.to_string()).join(", ")
),
ConstrainKind::Constrain => {
write!(f, "constrain {}", &self.arguments[0])
}
}
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ConstrainKind {
Expand All @@ -571,6 +591,16 @@ pub enum ConstrainKind {
Constrain,
}

impl Display for ConstrainKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConstrainKind::Assert => write!(f, "assert"),
ConstrainKind::AssertEq => write!(f, "assert_eq"),
ConstrainKind::Constrain => write!(f, "constrain"),
}
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Pattern {
Identifier(Ident),
Expand Down Expand Up @@ -885,12 +915,6 @@ impl Display for LetStatement {
}
}

impl Display for ConstrainStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "constrain {}", self.0)
}
}

impl Display for AssignStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{} = {}", self.lvalue, self.expression)
Expand Down
6 changes: 1 addition & 5 deletions compiler/noirc_frontend/src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1117,11 +1117,7 @@
}

pub fn accept_children(&self, visitor: &mut impl Visitor) {
self.0.accept(visitor);

if let Some(exp) = &self.1 {
exp.accept(visitor);
}
visit_expressions(&self.arguments, visitor);
}
}

Expand Down Expand Up @@ -1295,8 +1291,8 @@
UnresolvedTypeData::Unspecified => visitor.visit_unspecified_type(self.span),
UnresolvedTypeData::Quoted(typ) => visitor.visit_quoted_type(typ, self.span),
UnresolvedTypeData::FieldElement => visitor.visit_field_element_type(self.span),
UnresolvedTypeData::Integer(signdness, size) => {

Check warning on line 1294 in compiler/noirc_frontend/src/ast/visitor.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (signdness)
visitor.visit_integer_type(*signdness, *size, self.span);

Check warning on line 1295 in compiler/noirc_frontend/src/ast/visitor.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (signdness)
}
UnresolvedTypeData::Bool => visitor.visit_bool_type(self.span),
UnresolvedTypeData::Unit => visitor.visit_unit_type(self.span),
Expand Down
63 changes: 57 additions & 6 deletions compiler/noirc_frontend/src/elaborator/statements.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use noirc_errors::{Location, Span};
use noirc_errors::{Location, Span, Spanned};

use crate::{
ast::{AssignStatement, ConstrainStatement, LValue},
ast::{
AssignStatement, BinaryOpKind, ConstrainKind, ConstrainStatement, Expression,
ExpressionKind, InfixExpression, LValue,
},
hir::{
resolution::errors::ResolverError,
type_check::{Source, TypeCheckError},
Expand Down Expand Up @@ -110,12 +113,60 @@ impl<'context> Elaborator<'context> {
(HirStatement::Let(let_), Type::Unit)
}

pub(super) fn elaborate_constrain(&mut self, stmt: ConstrainStatement) -> (HirStatement, Type) {
let expr_span = stmt.0.span;
let (expr_id, expr_type) = self.elaborate_expression(stmt.0);
pub(super) fn elaborate_constrain(
&mut self,
mut stmt: ConstrainStatement,
) -> (HirStatement, Type) {
let span = stmt.span;
let min_args_count = match stmt.kind {
ConstrainKind::Assert | ConstrainKind::Constrain => 1,
ConstrainKind::AssertEq => 2,
};
let max_args_count = min_args_count + 1;
let actual_args_count = stmt.arguments.len();

let (message, expr) = if actual_args_count < min_args_count
|| actual_args_count > max_args_count
{
self.push_err(TypeCheckError::AssertionParameterCountMismatch {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
expected1: min_args_count,
expected2: max_args_count,
found: actual_args_count,
span,
});

jfecher marked this conversation as resolved.
Show resolved Hide resolved
// Given that we already produced an error, let's make this an `assert(true)` so
// we don't get further errors.
let message = None;
let kind = ExpressionKind::Literal(crate::ast::Literal::Bool(true));
let expr = Expression { kind, span };
(message, expr)
} else {
let message = if actual_args_count == min_args_count {
None
} else {
Some(stmt.arguments.pop().unwrap())
};
let expr = match stmt.kind {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
ConstrainKind::Assert | ConstrainKind::Constrain => stmt.arguments.pop().unwrap(),
ConstrainKind::AssertEq => {
let rhs = stmt.arguments.pop().unwrap();
let lhs = stmt.arguments.pop().unwrap();
let span = Span::from(lhs.span.start()..rhs.span.end());
let operator = Spanned::from(span, BinaryOpKind::Equal);
let kind =
ExpressionKind::Infix(Box::new(InfixExpression { lhs, operator, rhs }));
Expression { kind, span }
}
};
(message, expr)
};

let expr_span = expr.span;
let (expr_id, expr_type) = self.elaborate_expression(expr);

// Must type check the assertion message expression so that we instantiate bindings
let msg = stmt.1.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0);
let msg = message.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0);

self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch {
expr_typ: expr_type.to_string(),
Expand Down
11 changes: 9 additions & 2 deletions compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@ impl HirStatement {
}
HirStatement::Constrain(constrain) => {
let expr = constrain.0.to_display_ast(interner);
let message = constrain.2.map(|message| message.to_display_ast(interner));
let mut arguments = vec![expr];
if let Some(message) = constrain.2 {
arguments.push(message.to_display_ast(interner));
}

// TODO: Find difference in usage between Assert & AssertEq
StatementKind::Constrain(ConstrainStatement(expr, message, ConstrainKind::Assert))
StatementKind::Constrain(ConstrainStatement {
kind: ConstrainKind::Assert,
arguments,
span,
})
}
HirStatement::Assign(assign) => StatementKind::Assign(AssignStatement {
lvalue: assign.lvalue.to_display_ast(interner),
Expand Down
39 changes: 28 additions & 11 deletions compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1232,9 +1232,17 @@ fn expr_as_assert(
location: Location,
) -> IResult<Value> {
expr_as(interner, arguments, return_type.clone(), location, |expr| {
if let ExprValue::Statement(StatementKind::Constrain(constrain)) = expr {
if constrain.2 == ConstrainKind::Assert {
let predicate = Value::expression(constrain.0.kind);
if let ExprValue::Statement(StatementKind::Constrain(mut constrain)) = expr {
if constrain.kind == ConstrainKind::Assert
&& !constrain.arguments.is_empty()
&& constrain.arguments.len() <= 2
{
let (message, predicate) = if constrain.arguments.len() == 1 {
(None, constrain.arguments.pop().unwrap())
} else {
(Some(constrain.arguments.pop().unwrap()), constrain.arguments.pop().unwrap())
};
let predicate = Value::expression(predicate.kind);

let option_type = extract_option_generic_type(return_type);
let Type::Tuple(mut tuple_types) = option_type else {
Expand All @@ -1243,7 +1251,7 @@ fn expr_as_assert(
assert_eq!(tuple_types.len(), 2);

let option_type = tuple_types.pop().unwrap();
let message = constrain.1.map(|message| Value::expression(message.kind));
let message = message.map(|msg| Value::expression(msg.kind));
let message = option(option_type, message).ok()?;

Some(Value::Tuple(vec![predicate, message]))
Expand All @@ -1264,14 +1272,23 @@ fn expr_as_assert_eq(
location: Location,
) -> IResult<Value> {
expr_as(interner, arguments, return_type.clone(), location, |expr| {
if let ExprValue::Statement(StatementKind::Constrain(constrain)) = expr {
if constrain.2 == ConstrainKind::AssertEq {
let ExpressionKind::Infix(infix) = constrain.0.kind else {
panic!("Expected AssertEq constrain statement to have an infix expression");
if let ExprValue::Statement(StatementKind::Constrain(mut constrain)) = expr {
if constrain.kind == ConstrainKind::AssertEq
&& constrain.arguments.len() >= 2
&& constrain.arguments.len() <= 3
{
let (message, rhs, lhs) = if constrain.arguments.len() == 2 {
(None, constrain.arguments.pop().unwrap(), constrain.arguments.pop().unwrap())
} else {
(
Some(constrain.arguments.pop().unwrap()),
constrain.arguments.pop().unwrap(),
constrain.arguments.pop().unwrap(),
)
};

let lhs = Value::expression(infix.lhs.kind);
let rhs = Value::expression(infix.rhs.kind);
let lhs = Value::expression(lhs.kind);
let rhs = Value::expression(rhs.kind);

let option_type = extract_option_generic_type(return_type);
let Type::Tuple(mut tuple_types) = option_type else {
Expand All @@ -1280,7 +1297,7 @@ fn expr_as_assert_eq(
assert_eq!(tuple_types.len(), 3);

let option_type = tuple_types.pop().unwrap();
let message = constrain.1.map(|message| Value::expression(message.kind));
let message = message.map(|message| Value::expression(message.kind));
let message = option(option_type, message).ok()?;

Some(Value::Tuple(vec![lhs, rhs, message]))
Expand Down
11 changes: 6 additions & 5 deletions compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@
Value::Expr(ExprValue::Statement(statement))
}

pub(crate) fn lvalue(lvaue: LValue) -> Self {

Check warning on line 104 in compiler/noirc_frontend/src/hir/comptime/value.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (lvaue)
Value::Expr(ExprValue::LValue(lvaue))

Check warning on line 105 in compiler/noirc_frontend/src/hir/comptime/value.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (lvaue)
}

pub(crate) fn pattern(pattern: Pattern) -> Self {
Expand Down Expand Up @@ -952,11 +952,12 @@
r#type: remove_interned_in_unresolved_type(interner, let_statement.r#type),
..let_statement
}),
StatementKind::Constrain(constrain) => StatementKind::Constrain(ConstrainStatement(
remove_interned_in_expression(interner, constrain.0),
constrain.1.map(|expr| remove_interned_in_expression(interner, expr)),
constrain.2,
)),
StatementKind::Constrain(constrain) => StatementKind::Constrain(ConstrainStatement {
arguments: vecmap(constrain.arguments, |expr| {
remove_interned_in_expression(interner, expr)
}),
..constrain
}),
StatementKind::Expression(expr) => {
StatementKind::Expression(remove_interned_in_expression(interner, expr))
}
Expand Down
7 changes: 7 additions & 0 deletions compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ pub enum TypeCheckError {
AccessUnknownMember { lhs_type: Type, field_name: String, span: Span },
#[error("Function expects {expected} parameters but {found} were given")]
ParameterCountMismatch { expected: usize, found: usize, span: Span },
#[error("Function expects {expected1} or {expected2} parameters but {found} were given")]
AssertionParameterCountMismatch { expected1: usize, expected2: usize, found: usize, span: Span },
#[error("{item} expects {expected} generics but {found} were given")]
GenericCountMismatch { item: String, expected: usize, found: usize, span: Span },
#[error("{item} has incompatible `unconstrained`")]
Expand Down Expand Up @@ -260,6 +262,11 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic {
let msg = format!("Function expects {expected} parameter{empty_or_s} but {found} {was_or_were} given");
Diagnostic::simple_error(msg, String::new(), *span)
}
TypeCheckError::AssertionParameterCountMismatch { expected1, expected2, found, span } => {
let was_or_were = if *found == 1 { "was" } else { "were" };
let msg = format!("Function expects {expected1} or {expected2} parameters but {found} {was_or_were} given");
jfecher marked this conversation as resolved.
Show resolved Hide resolved
Diagnostic::simple_error(msg, String::new(), *span)
}
TypeCheckError::GenericCountMismatch { item, expected, found, span } => {
let empty_or_s = if *expected == 1 { "" } else { "s" };
let was_or_were = if *found == 1 { "was" } else { "were" };
Expand Down
17 changes: 6 additions & 11 deletions compiler/noirc_frontend/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,6 @@ where
choice((
assertion::constrain(expr_parser.clone()),
assertion::assertion(expr_parser.clone()),
assertion::assertion_eq(expr_parser.clone()),
declaration(expr_parser.clone()),
assignment(expr_parser.clone()),
if_statement(expr_no_constructors.clone(), statement.clone()),
Expand Down Expand Up @@ -1629,17 +1628,13 @@ mod test {
Case { source: "let", expect: "let $error = Error", errors: 3 },
Case { source: "foo = one two three", expect: "foo = one", errors: 1 },
Case { source: "constrain", expect: "constrain Error", errors: 2 },
Case { source: "assert", expect: "constrain Error", errors: 1 },
Case { source: "assert", expect: "assert()", errors: 1 },
Case { source: "constrain x ==", expect: "constrain (x == Error)", errors: 2 },
Case { source: "assert(x ==)", expect: "constrain (x == Error)", errors: 1 },
Case { source: "assert(x == x, x)", expect: "constrain (x == x)", errors: 0 },
Case { source: "assert_eq(x,)", expect: "constrain (Error == Error)", errors: 1 },
Case {
source: "assert_eq(x, x, x, x)",
expect: "constrain (Error == Error)",
errors: 1,
},
Case { source: "assert_eq(x, x, x)", expect: "constrain (x == x)", errors: 0 },
Case { source: "assert(x ==)", expect: "assert((x == Error))", errors: 1 },
Case { source: "assert(x == x, x)", expect: "assert((x == x), x)", errors: 0 },
Case { source: "assert_eq(x,)", expect: "assert_eq(x)", errors: 0 },
Case { source: "assert_eq(x, x, x, x)", expect: "assert_eq(x, x, x, x)", errors: 0 },
Case { source: "assert_eq(x, x, x)", expect: "assert_eq(x, x, x)", errors: 0 },
];

check_cases_with_errors(&cases[..], fresh_statement());
Expand Down
Loading
Loading