Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coalesce all compatible no-op libfuncs into a single function. #919

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ case $(uname) in
MLIR_SYS_190_PREFIX="$(brew --prefix llvm@19)"
LLVM_SYS_191_PREFIX="$(brew --prefix llvm@19)"
TABLEGEN_190_PREFIX="$(brew --prefix llvm@19)"
CAIRO_NATIVE_RUNTIME_LIBRARY="$(pwd)/target/debug/libcairo_native_runtime.a"
CAIRO_NATIVE_RUNTIME_LIBRARY="$(pwd)/target/release/libcairo_native_runtime.a"

export LIBRARY_PATH
export MLIR_SYS_190_PREFIX
Expand All @@ -24,7 +24,7 @@ case $(uname) in
MLIR_SYS_190_PREFIX=/usr/lib/llvm-19
LLVM_SYS_191_PREFIX=/usr/lib/llvm-19
TABLEGEN_190_PREFIX=/usr/lib/llvm-19
CAIRO_NATIVE_RUNTIME_LIBRARY="$(pwd)/target/debug/libcairo_native_runtime.a"
CAIRO_NATIVE_RUNTIME_LIBRARY="$(pwd)/target/release/libcairo_native_runtime.a"

export MLIR_SYS_190_PREFIX
export LLVM_SYS_191_PREFIX
Expand Down
8 changes: 2 additions & 6 deletions src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,9 +614,8 @@ fn compile_func(
invocation
.branches
.iter()
.zip(helper.results())
.zip(helper.results()?)
.map(|(branch_info, result_values)| {
let result_values = result_values?;
assert_eq!(
branch_info.results.len(),
result_values.len(),
Expand All @@ -625,10 +624,7 @@ fn compile_func(

Ok(edit_state::put_results(
state.clone(),
branch_info
.results
.iter()
.zip(result_values.iter().copied()),
branch_info.results.iter().zip(result_values.into_iter()),
)?)
})
.collect::<Result<_, Error>>()?,
Expand Down
140 changes: 93 additions & 47 deletions src/libfuncs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
//! Contains libfunc generation stuff (aka. the actual instructions).

use crate::{
error::{panic::ToNativeAssertError, Error as CoreLibfuncBuilderError, Result as NativeResult},
error::{panic::ToNativeAssertError, Error as CoreLibfuncBuilderError, Result},
metadata::MetadataStorage,
types::TypeBuilder,
utils::BlockExt,
};
use bumpalo::Bump;
use cairo_lang_sierra::{
extensions::core::{CoreConcreteLibfunc, CoreLibfunc, CoreType},
extensions::{
core::{CoreConcreteLibfunc, CoreLibfunc, CoreType, CoreTypeConcrete},
lib_func::ParamSignature,
starknet::StarkNetTypeConcrete,
ConcreteLibfunc,
},
ids::FunctionId,
program_registry::ProgramRegistry,
};
Expand All @@ -21,13 +27,11 @@ use melior::{
use num_bigint::BigInt;
use std::{cell::Cell, error::Error, ops::Deref};

mod ap_tracking;
mod array;
mod bitwise;
mod r#bool;
mod bounded_int;
mod r#box;
mod branch_align;
mod bytes31;
mod cast;
mod circuit;
Expand All @@ -52,7 +56,6 @@ mod sint16;
mod sint32;
mod sint64;
mod sint8;
mod snapshot_take;
mod starknet;
mod r#struct;
mod uint128;
Expand All @@ -62,8 +65,6 @@ mod uint32;
mod uint512;
mod uint64;
mod uint8;
mod unconditional_jump;
mod unwrap_non_zero;

/// Generation of MLIR operations from their Sierra counterparts.
///
Expand All @@ -82,7 +83,7 @@ pub trait LibfuncBuilder {
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
metadata: &mut MetadataStorage,
) -> Result<(), Self::Error>;
) -> Result<()>;

/// Return the target function if the statement is a function call.
///
Expand All @@ -102,20 +103,28 @@ impl LibfuncBuilder for CoreConcreteLibfunc {
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
metadata: &mut MetadataStorage,
) -> Result<(), Self::Error> {
) -> Result<()> {
match self {
Self::ApTracking(selector) => self::ap_tracking::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::ApTracking(_) | Self::BranchAlign(_) | Self::UnconditionalJump(_) => {
build_noop::<0, true>(
context,
registry,
entry,
location,
helper,
metadata,
self.param_signatures(),
)
}
Self::Array(selector) => self::array::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::BranchAlign(info) => self::branch_align::build(
context, registry, entry, location, helper, metadata, info,
),
Self::Bool(selector) => self::r#bool::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::BoundedInt(info) => {
self::bounded_int::build(context, registry, entry, location, helper, metadata, info)
}
Self::Box(selector) => self::r#box::build(
context, registry, entry, location, helper, metadata, selector,
),
Expand All @@ -125,16 +134,25 @@ impl LibfuncBuilder for CoreConcreteLibfunc {
Self::Cast(selector) => self::cast::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Circuit(info) => {
self::circuit::build(context, registry, entry, location, helper, metadata, info)
}
Self::Const(selector) => self::r#const::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Coupon(info) => {
self::coupon::build(context, registry, entry, location, helper, metadata, info)
}
Self::CouponCall(info) => self::function_call::build(
context, registry, entry, location, helper, metadata, info,
),
Self::Debug(selector) => self::debug::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Drop(info) => {
self::drop::build(context, registry, entry, location, helper, metadata, info)
}
Self::Dup(info) => {
Self::Dup(info) | Self::SnapshotTake(info) => {
self::dup::build(context, registry, entry, location, helper, metadata, info)
}
Self::Ec(selector) => self::ec::build(
Expand Down Expand Up @@ -185,9 +203,6 @@ impl LibfuncBuilder for CoreConcreteLibfunc {
Self::Sint128(info) => {
self::sint128::build(context, registry, entry, location, helper, metadata, info)
}
Self::SnapshotTake(info) => self::snapshot_take::build(
context, registry, entry, location, helper, metadata, info,
),
Self::StarkNet(selector) => self::starknet::build(
context, registry, entry, location, helper, metadata, selector,
),
Expand Down Expand Up @@ -215,24 +230,15 @@ impl LibfuncBuilder for CoreConcreteLibfunc {
Self::Uint512(selector) => self::uint512::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::UnconditionalJump(info) => self::unconditional_jump::build(
context, registry, entry, location, helper, metadata, info,
),
Self::UnwrapNonZero(info) => self::unwrap_non_zero::build(
context, registry, entry, location, helper, metadata, info,
),
Self::Coupon(info) => {
self::coupon::build(context, registry, entry, location, helper, metadata, info)
}
Self::CouponCall(info) => self::function_call::build(
context, registry, entry, location, helper, metadata, info,
Self::UnwrapNonZero(info) => build_noop::<1, true>(
context,
registry,
entry,
location,
helper,
metadata,
&info.signature.param_signatures,
),
Self::Circuit(info) => {
self::circuit::build(context, registry, entry, location, helper, metadata, info)
}
Self::BoundedInt(info) => {
self::bounded_int::build(context, registry, entry, location, helper, metadata, info)
}
}
}

Expand Down Expand Up @@ -274,17 +280,21 @@ where
'this: 'ctx,
{
#[doc(hidden)]
pub(crate) fn results(self) -> impl Iterator<Item = NativeResult<Vec<Value<'ctx, 'this>>>> {
self.results.into_iter().enumerate().map(|(branch_idx, x)| {
x.into_iter()
.enumerate()
.map(|(arg_idx, x)| {
x.into_inner().to_native_assert_error(&format!(
"Argument #{arg_idx} of branch {branch_idx} doesn't have a value."
))
})
.collect()
})
pub(crate) fn results(self) -> Result<Vec<Vec<Value<'ctx, 'this>>>> {
self.results
.into_iter()
.enumerate()
.map(|(branch_idx, x)| {
x.into_iter()
.enumerate()
.map(|(arg_idx, x)| {
x.into_inner().to_native_assert_error(&format!(
"Argument #{arg_idx} of branch {branch_idx} doesn't have a value."
))
})
.collect()
})
.collect()
}

/// Return the initialization block.
Expand Down Expand Up @@ -439,3 +449,39 @@ fn increment_builtin_counter_by<'ctx: 'a, 'a>(
location,
))
}

fn build_noop<'ctx, 'this, const N: usize, const PROCESS_BUILTINS: bool>(
context: &'ctx Context,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
entry: &'this Block<'ctx>,
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
_metadata: &mut MetadataStorage,
param_signatures: &[ParamSignature],
) -> Result<()> {
let mut params = Vec::with_capacity(N);

#[allow(clippy::needless_range_loop)]
for i in 0..N {
let param_ty = registry.get_type(&param_signatures[i].ty)?;
let mut param_val = entry.argument(i)?.into();

if PROCESS_BUILTINS
JulianGCalderon marked this conversation as resolved.
Show resolved Hide resolved
&& param_ty.is_builtin()
&& !matches!(
param_ty,
CoreTypeConcrete::BuiltinCosts(_)
| CoreTypeConcrete::Coupon(_)
| CoreTypeConcrete::GasBuiltin(_)
| CoreTypeConcrete::StarkNet(StarkNetTypeConcrete::System(_))
)
{
param_val = increment_builtin_counter(context, entry, location, param_val)?;
}

params.push(param_val);
}

entry.append_operation(helper.br(0, &params, location));
Ok(())
}
87 changes: 0 additions & 87 deletions src/libfuncs/ap_tracking.rs

This file was deleted.

Loading
Loading