Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of guards while flattening. #28392

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions compiler/passes/src/flattening/flatten_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::Flattener;
use crate::{Flattener, ReturnGuard};

use leo_ast::{Function, ProgramReconstructor, ProgramScope, Statement, StatementReconstructor};
use leo_ast::{
Expression,
Function,
ProgramReconstructor,
ProgramScope,
ReturnStatement,
Statement,
StatementReconstructor,
};

impl ProgramReconstructor for Flattener<'_> {
/// Flattens a program scope.
Expand Down Expand Up @@ -47,11 +55,19 @@ impl ProgramReconstructor for Flattener<'_> {
// Flatten the function body.
let mut block = self.reconstruct_block(function.block).0;

// Get all of the guards and return expression.
let returns = self.clear_early_returns();

// Fold the return statements into the block.
self.fold_returns(&mut block, returns);
let returns = std::mem::take(&mut self.returns);
let expression_returns: Vec<(Option<Expression>, ReturnStatement)> = returns
.into_iter()
.map(|(guard, statement)| match guard {
ReturnGuard::None => (None, statement),
ReturnGuard::Unconstructed(plain) | ReturnGuard::Constructed { plain, .. } => {
(Some(Expression::Identifier(plain)), statement)
}
})
.collect();

self.fold_returns(&mut block, expression_returns);

Function {
annotations: function.annotations,
Expand Down
204 changes: 126 additions & 78 deletions compiler/passes/src/flattening/flatten_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::Flattener;
use crate::{Flattener, Guard, ReturnGuard};

use leo_ast::{
AssertStatement,
Expand All @@ -28,6 +28,7 @@ use leo_ast::{
DefinitionStatement,
Expression,
ExpressionReconstructor,
Identifier,
IterationStatement,
Node,
ReturnStatement,
Expand Down Expand Up @@ -93,77 +94,98 @@ impl StatementReconstructor for Flattener<'_> {
},
};

// Add the appropriate guards.
match self.construct_guard() {
// If the condition stack is empty, we can return the flattened assert statement.
None => (Statement::Assert(assert), statements),
// Otherwise, we need to join the guard with the expression in the flattened assert statement.
// Note given the guard and the expression, we construct the logical formula `guard => expression`,
// which is equivalent to `!guard || expression`.
Some(guard) => (
Statement::Assert(AssertStatement {
span: input.span,
id: input.id,
variant: AssertVariant::Assert(Expression::Binary(BinaryExpression {
op: BinaryOperation::Or,
span: Default::default(),
id: {
// Create a new node ID for the binary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the binary expression.
self.type_table.insert(id, Type::Boolean);
id
},
// Take the logical negation of the guard.
left: Box::new(Expression::Unary(UnaryExpression {
op: UnaryOperation::Not,
receiver: Box::new(guard),
span: Default::default(),
id: {
// Create a new node ID for the unary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the unary expression.
self.type_table.insert(id, Type::Boolean);
id
},
})),
right: Box::new(match assert.variant {
// If the assert statement is an `assert`, use the expression as is.
AssertVariant::Assert(expression) => expression,
// If the assert statement is an `assert_eq`, construct a new equality expression.
AssertVariant::AssertEq(left, right) => Expression::Binary(BinaryExpression {
left: Box::new(left),
op: BinaryOperation::Eq,
right: Box::new(right),
span: Default::default(),
id: {
// Create a new node ID for the unary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the unary expression.
self.type_table.insert(id, Type::Boolean);
id
},
}),
// If the assert statement is an `assert_ne`, construct a new inequality expression.
AssertVariant::AssertNeq(left, right) => Expression::Binary(BinaryExpression {
left: Box::new(left),
op: BinaryOperation::Neq,
right: Box::new(right),
span: Default::default(),
id: {
// Create a new node ID for the unary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the unary expression.
self.type_table.insert(id, Type::Boolean);
id
},
}),
}),
})),
}),
statements,
),
let mut guards: Vec<Expression> = Vec::new();

if let Some((guard, guard_statements)) = self.construct_guard() {
statements.extend(guard_statements);

// The not_guard is true if we didn't follow the condition chain
// that led to this assertion.
let not_guard = Expression::Unary(UnaryExpression {
op: UnaryOperation::Not,
receiver: Box::new(Expression::Identifier(guard)),
span: Default::default(),
id: {
// Create a new node ID for the unary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the unary expression.
self.type_table.insert(id, Type::Boolean);
id
},
});
let (identifier, statement) = self.unique_simple_assign_statement(not_guard);
statements.push(statement);
guards.push(Expression::Identifier(identifier));
}

// We also need to guard against early returns.
if let Some((guard, guard_statements)) = self.construct_early_return_guard() {
guards.push(Expression::Identifier(guard));
statements.extend(guard_statements);
}

if guards.is_empty() {
return (Statement::Assert(assert), statements);
}

let is_eq = matches!(assert.variant, AssertVariant::AssertEq(..));

// We need to `or` the asserted expression with the guards,
// so extract an appropriate expression.
let mut expression = match assert.variant {
// If the assert statement is an `assert`, use the expression as is.
AssertVariant::Assert(expression) => expression,

// For `assert_eq` or `assert_neq`, construct a new expression.
AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
let binary = Expression::Binary(BinaryExpression {
left: Box::new(left),
op: if is_eq { BinaryOperation::Eq } else { BinaryOperation::Neq },
right: Box::new(right),
span: Default::default(),
id: {
// Create a new node ID.
let id = self.node_builder.next_id();
// Update the type table.
self.type_table.insert(id, Type::Boolean);
id
},
});
let (identifier, statement) = self.unique_simple_assign_statement(binary);
statements.push(statement);
Expression::Identifier(identifier)
}
};

// The assertion will be that the original assert statement is true or one of the guards is true
// (ie, we either didn't follow the condition chain that led to this assert, or else we took an
// early return).
for guard in guards.into_iter() {
let binary = Expression::Binary(BinaryExpression {
op: BinaryOperation::Or,
span: Default::default(),
id: {
// Create a new node ID.
let id = self.node_builder.next_id();
// Update the type table.
self.type_table.insert(id, Type::Boolean);
id
},
left: Box::new(expression),
right: Box::new(guard),
});
let (identifier, statement) = self.unique_simple_assign_statement(binary);
statements.push(statement);
expression = Expression::Identifier(identifier);
}

d0cd marked this conversation as resolved.
Show resolved Hide resolved
let assert_statement = Statement::Assert(AssertStatement {
span: input.span,
id: input.id,
variant: AssertVariant::Assert(expression),
});

(assert_statement, statements)
}

/// Flattens an assign statement, if necessary.
Expand Down Expand Up @@ -250,8 +272,21 @@ impl StatementReconstructor for Flattener<'_> {
);
}

// Assign the condition to a variable, as it may be used multiple times.
let place = Identifier {
name: self.assigner.unique_symbol("condition", "$"),
span: Default::default(),
id: {
let id = self.node_builder.next_id();
self.type_table.insert(id, Type::Boolean);
id
},
};

statements.push(self.simple_assign_statement(place, conditional.condition.clone()));

// Add condition to the condition stack.
self.condition_stack.push(conditional.condition.clone());
self.condition_stack.push(Guard::Unconstructed(place));

// Reconstruct the then-block and accumulate it constituent statements.
statements.extend(self.reconstruct_block(conditional.then).0.statements);
Expand All @@ -261,13 +296,24 @@ impl StatementReconstructor for Flattener<'_> {

// Consume the otherwise-block and flatten its constituent statements into the current block.
if let Some(statement) = conditional.otherwise {
// Add the negated condition to the condition stack.
self.condition_stack.push(Expression::Unary(UnaryExpression {
// Apply Not to the condition, assign it, and put it on the condition stack.
let not_condition = Expression::Unary(UnaryExpression {
op: UnaryOperation::Not,
receiver: Box::new(conditional.condition.clone()),
span: conditional.condition.span(),
id: conditional.condition.id(),
}));
});
let not_place = Identifier {
name: self.assigner.unique_symbol("condition", "$"),
span: Default::default(),
id: {
let id = self.node_builder.next_id();
self.type_table.insert(id, Type::Boolean);
id
},
};
statements.push(self.simple_assign_statement(not_place, not_condition));
self.condition_stack.push(Guard::Unconstructed(not_place));

// Reconstruct the otherwise-block and accumulate it constituent statements.
match *statement {
Expand Down Expand Up @@ -302,15 +348,17 @@ impl StatementReconstructor for Flattener<'_> {
return (Statement::Return(input), Default::default());
}
// Construct the associated guard.
let guard = self.construct_guard();
let (guard_identifier, statements) = self.construct_guard().unzip();

let return_guard = guard_identifier.map_or(ReturnGuard::None, ReturnGuard::Unconstructed);

match input.expression {
Expression::Unit(_) | Expression::Identifier(_) | Expression::Access(_) => {
self.returns.push((guard, input))
self.returns.push((return_guard, input))
}
_ => unreachable!("SSA guarantees that the expression is always an identifier or unit expression."),
};

(Statement::dummy(Default::default(), self.node_builder.next_id()), Default::default())
(Statement::dummy(Default::default(), self.node_builder.next_id()), statements.unwrap_or_default())
}
}
Loading
Loading