diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index ba31d1bd7..2ffce69e7 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -14,9 +14,9 @@ use self::{ }; use crate::{ ast::{ - AssignmentKind, BinOp, Bls12_381Point, Curve, DataTypeKey, FunctionAccessKey, Pattern, - Span, TraceLevel, Tracing, TypedArg, TypedClause, TypedDataType, TypedFunction, - TypedPattern, TypedValidator, UnOp, + AssignmentKind, BinOp, Bls12_381Point, Curve, DataTypeKey, FunctionAccessKey, + OnTestFailure, Pattern, Span, TraceLevel, Tracing, TypedArg, TypedClause, TypedDataType, + TypedFunction, TypedPattern, TypedValidator, UnOp, }, builtins::PRELUDE, expr::TypedExpr, @@ -64,7 +64,6 @@ const DELAY_ERROR: fn() -> AirTree = #[derive(Clone)] pub struct CodeGenerator<'a> { - #[allow(dead_code)] plutus_version: PlutusVersion, /// immutable index maps functions: IndexMap<&'a FunctionAccessKey, &'a TypedFunction>, @@ -80,6 +79,8 @@ pub struct CodeGenerator<'a> { code_gen_functions: IndexMap, cyclic_functions: IndexMap<(FunctionAccessKey, Variant), (CycleFunctionNames, usize, FunctionAccessKey)>, + monomorphized_consts: IndexMap<(FunctionAccessKey, String), AirTree>, + uplc_resolved_consts: IndexMap<(FunctionAccessKey, String), Term>, /// mutable and reset as well interner: AirInterner, id_gen: IdGenerator, @@ -111,6 +112,8 @@ impl<'a> CodeGenerator<'a> { special_functions: CodeGenSpecialFuncs::new(), code_gen_functions: IndexMap::new(), cyclic_functions: IndexMap::new(), + monomorphized_consts: IndexMap::new(), + uplc_resolved_consts: IndexMap::new(), interner: AirInterner::new(), id_gen: IdGenerator::new(), } @@ -120,6 +123,8 @@ impl<'a> CodeGenerator<'a> { self.code_gen_functions = IndexMap::new(); self.defined_functions = IndexMap::new(); self.cyclic_functions = IndexMap::new(); + self.monomorphized_consts = IndexMap::new(); + self.uplc_resolved_consts = IndexMap::new(); self.interner = AirInterner::new(); self.id_gen = IdGenerator::new(); if reset_special_functions { @@ -151,6 +156,43 @@ impl<'a> CodeGenerator<'a> { let full_vec = full_tree.to_vec(); + self.uplc_resolved_consts = self + .monomorphized_consts + .clone() + .into_iter() + .map(|item| { + let (key, value) = item; + + let value = self.hoist_functions_to_validator(value); + + let const_term = self + .uplc_code_gen(value.to_vec()) + .constr_fields_exposer() + .constr_index_exposer(); + + let mut program = + self.new_program(self.special_functions.apply_used_functions(const_term)); + + let mut interner = CodeGenInterner::new(); + + interner.program(&mut program); + + let eval_program: Program = + program.remove_no_inlines().try_into().unwrap(); + + let const_body = eval_program + .eval(ExBudget::max()) + .result() + .unwrap_or_else(|e| panic!("Failed to evaluate constant: {e:#?}")) + .try_into() + .unwrap(); + + (key.clone(), const_body) + }) + .collect(); + + println!("TOOOO {:#?}", self.uplc_resolved_consts); + let term = self.uplc_code_gen(full_vec); let term = cast_validator_args(term, &validator.params, &self.interner); @@ -186,6 +228,43 @@ impl<'a> CodeGenerator<'a> { // optimizations on air tree let full_vec = full_tree.to_vec(); + self.uplc_resolved_consts = self + .monomorphized_consts + .clone() + .into_iter() + .map(|item| { + let (key, value) = item; + + let value = self.hoist_functions_to_validator(value); + + let const_term = self + .uplc_code_gen(value.to_vec()) + .constr_fields_exposer() + .constr_index_exposer(); + + let mut program = + self.new_program(self.special_functions.apply_used_functions(const_term)); + + let mut interner = CodeGenInterner::new(); + + interner.program(&mut program); + + let eval_program: Program = + program.remove_no_inlines().try_into().unwrap(); + + let const_body = eval_program + .eval(ExBudget::max()) + .result() + .unwrap_or_else(|e| panic!("Failed to evaluate constant: {e:#?}")) + .try_into() + .unwrap(); + + (key.clone(), const_body) + }) + .collect(); + + println!("TOOOO {:#?}", self.uplc_resolved_consts); + let mut term = self.uplc_code_gen(full_vec); term = if args.is_empty() { @@ -215,6 +294,8 @@ impl<'a> CodeGenerator<'a> { fn finalize(&mut self, mut term: Term) -> Program { term = self.special_functions.apply_used_functions(term); + println!("TERM IS {}", term.to_pretty()); + let program = aiken_optimize_and_intern(self.new_program(term)); // This is very important to call here. @@ -3541,20 +3622,28 @@ impl<'a> CodeGenerator<'a> { .unwrap_or_else(|| panic!("Missing Function Variant Definition")); match function { - HoistableFunction::Function { body, deps, params } => { + HoistableFunction::Function { + body, + deps, + params, + is_constant, + } => { let mut hoist_body = body.clone(); let mut hoist_deps = deps.clone(); let params = params.clone(); let tree_path = tree_path.clone(); - - self.define_dependent_functions( - &mut hoist_body, - &mut functions_to_hoist, - &mut used_functions, - &defined_functions, - &mut hoist_deps, - tree_path, - ); + let is_constant = *is_constant; + + if !is_constant { + self.define_dependent_functions( + &mut hoist_body, + &mut functions_to_hoist, + &mut used_functions, + &defined_functions, + &mut hoist_deps, + tree_path, + ); + } let function_variants = functions_to_hoist .get_mut(&key) @@ -3564,17 +3653,17 @@ impl<'a> CodeGenerator<'a> { .get_mut(&variant_name) .expect("Missing Function Variant Definition"); - if params.is_empty() { - validator_hoistable.push((key, variant_name)); + if is_constant { + assert!(hoist_deps.is_empty()); + } else { + *function = HoistableFunction::Function { + body: hoist_body, + deps: hoist_deps, + params, + is_constant, + }; } - - *function = HoistableFunction::Function { - body: hoist_body, - deps: hoist_deps, - params, - }; } - HoistableFunction::Link(_) => todo!("Deal with Link later"), _ => unreachable!(), } } @@ -3678,7 +3767,9 @@ impl<'a> CodeGenerator<'a> { .expect("Missing Function Variant Definition"); match func { - HoistableFunction::Function { params, body, deps } => { + HoistableFunction::Function { + params, body, deps, .. + } => { cycle_of_functions.push((params.clone(), body.clone())); cycle_deps.push(deps.clone()); } @@ -3878,6 +3969,7 @@ impl<'a> CodeGenerator<'a> { body, deps: func_deps, params, + is_constant, } => { let mut body = body.clone(); @@ -3901,27 +3993,47 @@ impl<'a> CodeGenerator<'a> { let node_to_edit = air_tree.find_air_tree_node(tree_path); - let defined_function = AirTree::define_func( - &key.function_name, - &key.module_name, - variant, - func_params.clone(), - is_recursive, - recursive_nonstatics, - body, - node_to_edit.clone(), - ); + let defined_function = if *is_constant { + self.monomorphized_consts.insert( + ( + FunctionAccessKey { + function_name: key.function_name.clone(), + module_name: key.module_name.clone(), + }, + variant.clone(), + ), + body, + ); - let defined_dependencies = self.hoist_dependent_functions( - deps, - (key, variant), - hoisted_functions, - functions_to_hoist, - defined_function, - ); + AirTree::define_const( + &key.function_name, + &key.module_name, + variant, + node_to_edit.clone(), + ) + } else { + let func = AirTree::define_func( + &key.function_name, + &key.module_name, + variant, + func_params.clone(), + is_recursive, + recursive_nonstatics, + body, + node_to_edit.clone(), + ); + + self.hoist_dependent_functions( + deps, + (key, variant), + hoisted_functions, + functions_to_hoist, + func, + ) + }; // now hoist full function onto validator tree - *node_to_edit = defined_dependencies; + *node_to_edit = defined_function; hoisted_functions.push((key.clone(), variant.clone())); } @@ -4067,6 +4179,7 @@ impl<'a> CodeGenerator<'a> { body: mut dep_air_tree, deps: dependency_deps, params: dependent_params, + is_constant, } => { let is_dependent_recursive = dependency_deps .iter() @@ -4085,16 +4198,36 @@ impl<'a> CodeGenerator<'a> { hoisted_functions.push((dep_key.clone(), dep_variant.clone())); - AirTree::define_func( - &dep_key.function_name, - &dep_key.module_name, - &dep_variant, - dependent_params, - is_dependent_recursive, - recursive_nonstatics, - dep_air_tree, - then, - ) + if is_constant { + self.monomorphized_consts.insert( + ( + FunctionAccessKey { + function_name: dep_key.function_name.clone(), + module_name: dep_key.module_name.clone(), + }, + dep_variant.clone(), + ), + dep_air_tree, + ); + + AirTree::define_const( + &dep_key.function_name, + &dep_key.module_name, + &dep_variant, + then, + ) + } else { + AirTree::define_func( + &dep_key.function_name, + &dep_key.module_name, + &dep_variant, + dependent_params, + is_dependent_recursive, + recursive_nonstatics, + dep_air_tree, + then, + ) + } } HoistableFunction::CyclicFunction { functions, .. } => { let mut functions = functions.clone(); @@ -4185,24 +4318,52 @@ impl<'a> CodeGenerator<'a> { .. } = air_tree { - let ValueConstructorVariant::ModuleFn { - name: func_name, - module, - builtin: None, - .. - } = &constructor.variant - else { - return; + let (func_name, module, is_constant) = match &constructor.variant { + ValueConstructorVariant::ModuleConstant { module, name, .. } => { + (name, module, true) + } + ValueConstructorVariant::ModuleFn { + name, + module, + builtin: None, + .. + } => (name, module, false), + // In other cases simply return early + _ => return, }; - let function_var_tipo = &constructor.tipo; + let mut function_var_tipo = constructor.tipo.clone(); let generic_function_key = FunctionAccessKey { module_name: module.clone(), function_name: func_name.clone(), }; - let function_def = self.functions.get(&generic_function_key); + let const_func = + self.constants + .get(&generic_function_key) + .map(|item| TypedFunction { + arguments: vec![], + body: (*item).clone(), + doc: None, + end_position: 0, + location: Span::empty(), + name: func_name.clone(), + public: true, + return_annotation: None, + return_type: item.tipo(), + on_test_failure: OnTestFailure::FailImmediately, + }); + + if const_func.is_some() { + function_var_tipo = Type::function(vec![], constructor.tipo.clone()); + } + + let function_def = self + .functions + .get(&generic_function_key) + .copied() + .or(const_func.as_ref()); let Some(function_def) = function_def else { let code_gen_func = self @@ -4250,6 +4411,7 @@ impl<'a> CodeGenerator<'a> { body, deps: vec![], params: params.clone(), + is_constant: false, }, ), ); @@ -4312,9 +4474,8 @@ impl<'a> CodeGenerator<'a> { if let Some((path, _)) = func_variants.get_mut(&variant) { *path = path.common_ancestor(tree_path); } else { - let args = function_def.arguments.clone(); - - let params = args + let params = function_def + .arguments .iter() .map(|arg| { arg.arg_name @@ -4337,7 +4498,7 @@ impl<'a> CodeGenerator<'a> { monomorphize(air_tree, &mono_types); }); - args.iter().for_each(|arg| { + function_def.arguments.iter().for_each(|arg| { arg.arg_name.get_variable_name().iter().for_each(|arg| { self.interner.pop_text(arg.to_string()); }) @@ -4351,14 +4512,14 @@ impl<'a> CodeGenerator<'a> { body: function_air_tree_body, deps: vec![], params, + is_constant, }, ), ); } } else { - let args = function_def.arguments.clone(); - - let params = args + let params = function_def + .arguments .iter() .map(|arg| { arg.arg_name @@ -4381,7 +4542,7 @@ impl<'a> CodeGenerator<'a> { let mut function_variant_path = IndexMap::new(); - args.iter().for_each(|arg| { + function_def.arguments.iter().for_each(|arg| { arg.arg_name .get_variable_name() .iter() @@ -4396,6 +4557,7 @@ impl<'a> CodeGenerator<'a> { body: function_air_tree_body, deps: vec![], params, + is_constant, }, ), ); @@ -4443,48 +4605,31 @@ impl<'a> CodeGenerator<'a> { .into(), )), ValueConstructorVariant::ModuleConstant { module, name, .. } => { - let access_key = FunctionAccessKey { - module_name: module.clone(), - function_name: name.clone(), + let uplc_name = if !module.is_empty() { + format!("{module}_{name}{variant_name}") + } else { + format!("{name}{variant_name}") }; - let definition = self - .constants - .get(&access_key) - .unwrap_or_else(|| panic!("unknown constant {module}.{name}")); - - let mut value = - AirTree::no_op(self.build(definition, &access_key.module_name, &[])); - - value.traverse_tree_with(&mut |air_tree, _| { - erase_opaque_type_operations(air_tree, &self.data_types); - }); - - value = self.hoist_functions_to_validator(value); - - let term = self - .uplc_code_gen(value.to_vec()) - .constr_fields_exposer() - .constr_index_exposer(); - - let mut program = - self.new_program(self.special_functions.apply_used_functions(term)); - - let mut interner = CodeGenInterner::new(); - - interner.program(&mut program); + let existing_term = self.uplc_resolved_consts.get(&( + FunctionAccessKey { + module_name: module.clone(), + function_name: name.clone(), + }, + variant_name.clone(), + )); - let eval_program: Program = - program.remove_no_inlines().try_into().unwrap(); + match existing_term { + Some(constant @ Term::Constant(_)) => Some(constant.clone()), - Some( - eval_program - .eval(ExBudget::max()) - .result() - .unwrap_or_else(|e| panic!("Failed to evaluate constant: {e:#?}")) - .try_into() - .unwrap(), - ) + _ => Some(Term::Var( + Name { + text: uplc_name, + unique: 0.into(), + } + .into(), + )), + } } ValueConstructorVariant::ModuleFn { name: func_name, @@ -4988,134 +5133,211 @@ impl<'a> CodeGenerator<'a> { } Air::DefineFunc { func_name, - params, - recursive, - recursive_nonstatic_params, module_name, variant_name, + variant, } => { - let func_name = if module_name.is_empty() { + let func_uplc_name = if module_name.is_empty() { format!("{func_name}{variant_name}") } else { format!("{module_name}_{func_name}{variant_name}") }; - let mut func_body = arg_stack.pop().unwrap(); - let mut term = arg_stack.pop().unwrap(); + match variant { + air::FunctionVariants::Standard(params) => { + let mut func_body = arg_stack.pop().unwrap(); - // Introduce a parameter for each parameter - // NOTE: we use recursive_nonstatic_params here because - // if this is recursive, those are the ones that need to be passed - // each time - for param in recursive_nonstatic_params.iter().rev() { - func_body = func_body.lambda(param.clone()); - } + let term = arg_stack.pop().unwrap(); - if recursive_nonstatic_params.is_empty() || params.is_empty() { - func_body = func_body.delay(); - } + if params.is_empty() { + func_body = func_body.delay(); + } - if !recursive { - term = term.lambda(func_name).apply(func_body.lambda(NO_INLINE)); + let func_body = params + .into_iter() + .rfold(func_body, |term, arg| term.lambda(arg)) + .lambda(NO_INLINE); - Some(term) - } else { - func_body = func_body.lambda(func_name.clone()); + Some(term.lambda(func_uplc_name).apply(func_body)) + } + air::FunctionVariants::Recursive { + params, + recursive_nonstatic_params, + } => { + let mut func_body = arg_stack.pop().unwrap(); - if recursive_nonstatic_params == params { - // If we don't have any recursive-static params, we can just emit the function as is - term = term - .lambda(func_name.clone()) - .apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone()))) - .lambda(func_name) - .apply(func_body.lambda(NO_INLINE)); - } else { - // If we have parameters that remain static in each recursive call, - // we can construct an *outer* function to take those in - // and simplify the recursive part to only accept the non-static arguments - let mut recursive_func_body = - Term::var(&func_name).apply(Term::var(&func_name)); - for param in recursive_nonstatic_params.iter() { - recursive_func_body = recursive_func_body.apply(Term::var(param)); - } + let term = arg_stack.pop().unwrap(); - if recursive_nonstatic_params.is_empty() { - recursive_func_body = recursive_func_body.force(); + let no_statics = recursive_nonstatic_params == params; + + if recursive_nonstatic_params.is_empty() || params.is_empty() { + func_body = func_body.delay(); } - // Then construct an outer function with *all* parameters, not just the nonstatic ones. - let mut outer_func_body = - recursive_func_body.lambda(&func_name).apply(func_body); + let func_body = recursive_nonstatic_params + .iter() + .rfold(func_body, |term, arg| term.lambda(arg)); - // Now, add *all* parameters, so that other call sites don't know the difference - for param in params.iter().rev() { - outer_func_body = outer_func_body.lambda(param); - } + let func_body = func_body.lambda(func_uplc_name.clone()); - // And finally, fold that definition into the rest of our program - term = term - .lambda(&func_name) - .apply(outer_func_body.lambda(NO_INLINE)); + if no_statics { + // If we don't have any recursive-static params, we can just emit the function as is + Some( + term.lambda(func_uplc_name.clone()) + .apply( + Term::var(func_uplc_name.clone()) + .apply(Term::var(func_uplc_name.clone())), + ) + .lambda(func_uplc_name) + .apply(func_body.lambda(NO_INLINE)), + ) + } else { + // If we have parameters that remain static in each recursive call, + // we can construct an *outer* function to take those in + // and simplify the recursive part to only accept the non-static arguments + let mut recursive_func_body = + Term::var(&func_uplc_name).apply(Term::var(&func_uplc_name)); + + if recursive_nonstatic_params.is_empty() { + recursive_func_body = recursive_func_body.force(); + } + + // Introduce a parameter for each parameter + // NOTE: we use recursive_nonstatic_params here because + // if this is recursive, those are the ones that need to be passed + // each time + for param in recursive_nonstatic_params.into_iter() { + recursive_func_body = recursive_func_body.apply(Term::var(param)); + } + + // Then construct an outer function with *all* parameters, not just the nonstatic ones. + let mut outer_func_body = + recursive_func_body.lambda(&func_uplc_name).apply(func_body); + + // Now, add *all* parameters, so that other call sites don't know the difference + outer_func_body = params + .clone() + .into_iter() + .rfold(outer_func_body, |term, arg| term.lambda(arg)); + + // And finally, fold that definition into the rest of our program + Some( + term.lambda(&func_uplc_name) + .apply(outer_func_body.lambda(NO_INLINE)), + ) + } } + air::FunctionVariants::Constant => { + match arg_stack.pop().unwrap() { + Term::Constant(v) if matches!(v.as_ref(), UplcConstant::Unit) => (), + _ => panic!("Constants should have a void body."), + }; - Some(term) - } - } - Air::DefineCyclicFuncs { - func_name, - module_name, - variant_name, - contained_functions, - } => { - let func_name = if module_name.is_empty() { - format!("{func_name}{variant_name}") - } else { - format!("{module_name}_{func_name}{variant_name}") - }; - let mut cyclic_functions = vec![]; + let term = arg_stack.pop().unwrap(); - for params in contained_functions { - let func_body = arg_stack.pop().unwrap(); + let access_key = FunctionAccessKey { + module_name: module_name.clone(), + function_name: func_name.clone(), + }; - cyclic_functions.push((params, func_body)); - } - let mut term = arg_stack.pop().unwrap(); + let existing_term = self + .uplc_resolved_consts + .get(&(access_key.clone(), variant_name.clone())); - let mut cyclic_body = Term::var("__chooser"); + match existing_term { + Some(Term::Constant(_)) => Some(term), - for (params, func_body) in cyclic_functions.into_iter() { - let mut function = func_body; - if params.is_empty() { - function = function.delay(); - } else { - for param in params.iter().rev() { - function = function.lambda(param); + _ => { + let mut value = self + .monomorphized_consts + .get(&(access_key.clone(), variant_name.clone())) + .cloned() + .unwrap(); + + value = self.hoist_functions_to_validator(value); + + let const_term = self + .uplc_code_gen(value.to_vec()) + .constr_fields_exposer() + .constr_index_exposer(); + + let mut program = self.new_program( + self.special_functions.apply_used_functions(const_term), + ); + + let mut interner = CodeGenInterner::new(); + + interner.program(&mut program); + + let eval_program: Program = + program.remove_no_inlines().try_into().unwrap(); + + let const_body: Term = eval_program + .eval(ExBudget::max()) + .result() + .unwrap_or_else(|e| { + panic!("Failed to evaluate constant: {e:#?}") + }) + .try_into() + .unwrap(); + + self.uplc_resolved_consts.insert( + (access_key.clone(), variant_name.clone()), + const_body.clone(), + ); + + Some(term.lambda(func_uplc_name).apply(const_body)) + } } } + air::FunctionVariants::Cyclic(contained_functions) => { + let mut cyclic_functions = vec![]; - // We basically Scott encode our function bodies and use the chooser function - // to determine which function body and params is run - // For example say there is a cycle of 3 function bodies - // Our choose function can look like this: - // \func1 -> \func2 -> \func3 -> func1 - // In this case our chooser is a function that takes in 3 functions - // and returns the first one to run - cyclic_body = cyclic_body.apply(function) - } + for params in contained_functions { + let func_body = arg_stack.pop().unwrap(); - term = term - .lambda(&func_name) - .apply(Term::var(&func_name).apply(Term::var(&func_name))) - .lambda(&func_name) - .apply( - cyclic_body - .lambda("__chooser") - .lambda(func_name) - .lambda(NO_INLINE), - ); + cyclic_functions.push((params, func_body)); + } + let mut term = arg_stack.pop().unwrap(); - Some(term) + let mut cyclic_body = Term::var("__chooser"); + + for (params, func_body) in cyclic_functions.into_iter() { + let mut function = func_body; + if params.is_empty() { + function = function.delay(); + } else { + for param in params.iter().rev() { + function = function.lambda(param); + } + } + + // We basically Scott encode our function bodies and use the chooser function + // to determine which function body and params is run + // For example say there is a cycle of 3 function bodies + // Our choose function can look like this: + // \func1 -> \func2 -> \func3 -> func1 + // In this case our chooser is a function that takes in 3 functions + // and returns the first one to run + cyclic_body = cyclic_body.apply(function) + } + + term = term + .lambda(&func_uplc_name) + .apply(Term::var(&func_uplc_name).apply(Term::var(&func_uplc_name))) + .lambda(&func_uplc_name) + .apply( + cyclic_body + .lambda("__chooser") + .lambda(func_uplc_name) + .lambda(NO_INLINE), + ); + + Some(term) + } + } } + Air::Let { name } => { let arg = arg_stack.pop().unwrap(); diff --git a/crates/aiken-lang/src/gen_uplc/air.rs b/crates/aiken-lang/src/gen_uplc/air.rs index 981ad76ea..6845d8d0d 100644 --- a/crates/aiken-lang/src/gen_uplc/air.rs +++ b/crates/aiken-lang/src/gen_uplc/air.rs @@ -23,6 +23,17 @@ impl From for ExpectLevel { } } +#[derive(Debug, Clone, PartialEq)] +pub enum FunctionVariants { + Standard(Vec), + Recursive { + params: Vec, + recursive_nonstatic_params: Vec, + }, + Cyclic(Vec>), + Constant, +} + #[derive(Debug, Clone, PartialEq)] pub enum Air { // Primitives @@ -65,19 +76,10 @@ pub enum Air { tipo: Rc, }, DefineFunc { - func_name: String, - module_name: String, - params: Vec, - recursive: bool, - recursive_nonstatic_params: Vec, - variant_name: String, - }, - DefineCyclicFuncs { func_name: String, module_name: String, variant_name: String, - // just the params - contained_functions: Vec>, + variant: FunctionVariants, }, Fn { params: Vec, diff --git a/crates/aiken-lang/src/gen_uplc/builder.rs b/crates/aiken-lang/src/gen_uplc/builder.rs index 4d9b9d76c..3148e526c 100644 --- a/crates/aiken-lang/src/gen_uplc/builder.rs +++ b/crates/aiken-lang/src/gen_uplc/builder.rs @@ -56,6 +56,7 @@ pub enum HoistableFunction { body: AirTree, deps: Vec<(FunctionAccessKey, Variant)>, params: Params, + is_constant: bool, }, CyclicFunction { functions: Vec<(Params, AirTree)>, diff --git a/crates/aiken-lang/src/gen_uplc/tree.rs b/crates/aiken-lang/src/gen_uplc/tree.rs index 6e0183e38..5752eccb2 100644 --- a/crates/aiken-lang/src/gen_uplc/tree.rs +++ b/crates/aiken-lang/src/gen_uplc/tree.rs @@ -1,4 +1,4 @@ -use super::air::{Air, ExpectLevel}; +use super::air::{Air, ExpectLevel, FunctionVariants}; use crate::{ ast::{BinOp, Curve, Span, UnOp}, tipo::{Type, ValueConstructor, ValueConstructorVariant}, @@ -21,6 +21,7 @@ pub enum Fields { SixthField, SeventhField, EighthField, + NinthField, ArgsField(usize), } @@ -136,10 +137,12 @@ pub enum AirTree { DefineFunc { func_name: String, module_name: String, + variant_name: String, + //params and other parts of a function params: Vec, recursive: bool, recursive_nonstatic_params: Vec, - variant_name: String, + constant: bool, func_body: Box, then: Box, }, @@ -531,12 +534,33 @@ impl AirTree { params, recursive, recursive_nonstatic_params, + constant: false, variant_name: variant_name.to_string(), func_body: func_body.into(), then: then.into(), } } + #[allow(clippy::too_many_arguments)] + pub fn define_const( + func_name: impl ToString, + module_name: impl ToString, + variant_name: impl ToString, + then: AirTree, + ) -> AirTree { + AirTree::DefineFunc { + func_name: func_name.to_string(), + module_name: module_name.to_string(), + variant_name: variant_name.to_string(), + params: vec![], + recursive: false, + recursive_nonstatic_params: vec![], + constant: true, + func_body: AirTree::void().into(), + then: then.into(), + } + } + pub fn define_cyclic_func( func_name: impl ToString, module_name: impl ToString, @@ -1158,17 +1182,32 @@ impl AirTree { params, recursive, recursive_nonstatic_params, + constant, variant_name, func_body, then, } => { + let variant = if *constant { + assert!(!recursive); + assert!(params.is_empty()); + assert!(recursive_nonstatic_params.is_empty()); + + FunctionVariants::Constant + } else if *recursive { + FunctionVariants::Recursive { + params: params.clone(), + recursive_nonstatic_params: recursive_nonstatic_params.clone(), + } + } else { + assert_eq!(params, recursive_nonstatic_params); + FunctionVariants::Standard(params.clone()) + }; + air_vec.push(Air::DefineFunc { func_name: func_name.clone(), module_name: module_name.clone(), - params: params.clone(), - recursive: *recursive, - recursive_nonstatic_params: recursive_nonstatic_params.clone(), variant_name: variant_name.clone(), + variant, }); func_body.create_air_vec(air_vec); then.create_air_vec(air_vec); @@ -1180,14 +1219,18 @@ impl AirTree { contained_functions, then, } => { - air_vec.push(Air::DefineCyclicFuncs { - func_name: func_name.clone(), - module_name: module_name.clone(), - variant_name: variant_name.clone(), - contained_functions: contained_functions + let variant = FunctionVariants::Cyclic( + contained_functions .iter() .map(|(params, _)| params.clone()) .collect_vec(), + ); + + air_vec.push(Air::DefineFunc { + func_name: func_name.clone(), + module_name: module_name.clone(), + variant_name: variant_name.clone(), + variant, }); for (_, func_body) in contained_functions { @@ -1834,8 +1877,8 @@ impl AirTree { ) { tree_path.push(current_depth, field_index); - // Assignments'/Statements' values get traversed here - // Then the body under these assignments/statements get traversed later on + // TODO: Merge together the 2 match statements + match self { AirTree::Let { name: _, @@ -2104,7 +2147,6 @@ impl AirTree { | AirTree::MultiValidator { .. } => {} } - // Expressions or an assignment that hoist over a expression are traversed here match self { AirTree::NoOp { then } => { then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with); @@ -2389,16 +2431,17 @@ impl AirTree { recursive: _, recursive_nonstatic_params: _, variant_name: _, + constant: _, func_body, then, } => { func_body.do_traverse_tree_with( tree_path, current_depth + 1, - Fields::SeventhField, + Fields::EighthField, with, ); - then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::EighthField, with) + then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::NinthField, with) } AirTree::DefineCyclicFuncs { func_name: _, diff --git a/crates/aiken-project/src/tests/gen_uplc.rs b/crates/aiken-project/src/tests/gen_uplc.rs index 2ea886aeb..0660036fa 100644 --- a/crates/aiken-project/src/tests/gen_uplc.rs +++ b/crates/aiken-project/src/tests/gen_uplc.rs @@ -3032,7 +3032,9 @@ fn acceptance_test_29_union_pair() { inner: Pairs, } - const empty_list: AssocList = AssocList { inner: [] } + const empty_list: AssocList = {fn(k: key, v: value){ + [(k,v)] + }(1, 2)} pub fn from_list(xs: Pairs) -> AssocList { AssocList { inner: do_from_list(xs) } @@ -6051,6 +6053,95 @@ fn bls12_381_elements_from_data_conversion() { ) } +#[test] +fn bls12_381_elements_constant_hoisting() { + let src = r#" + pub const generator_g1: G1Element = + #"97f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb" + + pub const other_generator_g1: G1Element = + #"b28cb29bc282be68df977b35eb9d8e98b3a0a3fc7c372990bddc50419ca86693e491755338fed4fb42231a7c081252ce" + + type Mew{ + One(G1Element) + Two(G1Element) + } + + type Choo{ + Foo(Mew) + Bar + } + + test thing() { + let x = Foo(One(other_generator_g1)) + + when x is { + Foo(y) -> { + when y is { + One(other_g1) -> { + let g1 = generator_g1 + + g1 != other_g1 + } + Two(other_g) -> { + let ga = generator_g1 + + ga == other_g && True + } + } + } + Bar -> False + } + } + "#; + + let bytes = vec![ + 0xb2, 0x8c, 0xb2, 0x9b, 0xc2, 0x82, 0xbe, 0x68, 0xdf, 0x97, 0x7b, 0x35, 0xeb, 0x9d, 0x8e, + 0x98, 0xb3, 0xa0, 0xa3, 0xfc, 0x7c, 0x37, 0x29, 0x90, 0xbd, 0xdc, 0x50, 0x41, 0x9c, 0xa8, + 0x66, 0x93, 0xe4, 0x91, 0x75, 0x53, 0x38, 0xfe, 0xd4, 0xfb, 0x42, 0x23, 0x1a, 0x7c, 0x08, + 0x12, 0x52, 0xce, + ]; + + let g1 = Term::Constant( + Constant::Bls12_381G1Element(blst::blst_p1::uncompress(&bytes).unwrap().into()).into(), + ); + + let constant = Term::Constant( + Constant::Data(Data::constr( + 0, + vec![ + Data::bytestring(bytes), + Data::bytestring(vec![ + 0xb9, 0x21, 0x5e, 0x5b, 0xc4, 0x81, 0xba, 0x65, 0x52, 0x38, 0x4c, 0x89, 0xc2, + 0x3d, 0x45, 0xbd, 0x65, 0x0b, 0x69, 0x46, 0x28, 0x68, 0x24, 0x8b, 0xfb, 0xb8, + 0x3a, 0xee, 0x70, 0x60, 0x57, 0x94, 0x04, 0xdb, 0xa4, 0x1c, 0x78, 0x1d, 0xec, + 0x7c, 0x2b, 0xec, 0x5f, 0xcc, 0xec, 0x06, 0x84, 0x2e, 0x0e, 0x66, 0xad, 0x6d, + 0x86, 0xc7, 0xc7, 0x6c, 0x46, 0x8a, 0x32, 0xc9, 0xc0, 0x08, 0x0e, 0xea, 0x02, + 0x19, 0xd0, 0x95, 0x3b, 0x44, 0xb1, 0xc4, 0xf5, 0x60, 0x5a, 0xfb, 0x1e, 0x5a, + 0x31, 0x93, 0x26, 0x4f, 0xf7, 0x30, 0x22, 0x2e, 0x94, 0xf5, 0x52, 0x07, 0x62, + 0x82, 0x35, 0xf3, 0xb4, 0x23, + ]), + ], + )) + .into(), + ); + + assert_uplc( + src, + Term::bls12_381_g1_equal() + .apply(Term::bls12_381_g1_uncompress().apply( + Term::un_b_data().apply( + Term::head_list().apply( + Term::snd_pair().apply(Term::unconstr_data().apply(constant.clone())), + ), + ), + )) + .apply(g1), + false, + true, + ) +} + #[test] fn qualified_prelude_functions() { let src = r#" diff --git a/crates/uplc/src/optimize.rs b/crates/uplc/src/optimize.rs index ef60c20ad..9bfa122a8 100644 --- a/crates/uplc/src/optimize.rs +++ b/crates/uplc/src/optimize.rs @@ -6,7 +6,7 @@ pub mod shrinker; pub fn aiken_optimize_and_intern(program: Program) -> Program { program .inline_constr_ops() - .bls381_compressor() + .bls381_compressor_hoister() .builtin_force_reducer() .lambda_reducer() .inline_reducer() diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index 9a96283ea..670dc5208 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -1139,6 +1139,26 @@ impl Program { } pub fn bls381_compressor(self) -> Self { + let program = self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| { + if let Term::Constant(con) = term { + match con.as_ref() { + Constant::Bls12_381G1Element(blst_p1) => { + *term = Term::bls12_381_g1_uncompress() + .apply(Term::byte_string(blst_p1.compress())) + } + Constant::Bls12_381G2Element(blst_p2) => { + *term = Term::bls12_381_g2_uncompress() + .apply(Term::byte_string(blst_p2.compress())) + } + _ => (), + } + } + }); + + program + } + + pub fn bls381_compressor_hoister(self) -> Self { let mut blst_p1_list = vec![]; let mut blst_p2_list = vec![];