Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constant hoisting and BLS const fix #1025

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 114 additions & 99 deletions crates/aiken-lang/src/gen_uplc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4988,134 +4988,149 @@ impl<'a> CodeGenerator<'a> {
}
Air::DefineFunc {
func_name,
params,
recursive,
recursive_nonstatic_params,
module_name,
variant_name,
variant,
} => {
let func_name = if module_name.is_empty() {
format!("{func_name}{variant_name}")
} else {
format!("{module_name}_{func_name}{variant_name}")
};
let mut func_body = arg_stack.pop().unwrap();

let mut term = arg_stack.pop().unwrap();

// Introduce a parameter for each parameter
// NOTE: we use recursive_nonstatic_params here because
// if this is recursive, those are the ones that need to be passed
// each time
for param in recursive_nonstatic_params.iter().rev() {
func_body = func_body.lambda(param.clone());
}
match variant {
air::FunctionVariants::Standard(params) => {
let mut func_body = arg_stack.pop().unwrap();

if recursive_nonstatic_params.is_empty() || params.is_empty() {
func_body = func_body.delay();
}
let term = arg_stack.pop().unwrap();

if !recursive {
term = term.lambda(func_name).apply(func_body.lambda(NO_INLINE));
if params.is_empty() {
func_body = func_body.delay();
}

Some(term)
} else {
func_body = func_body.lambda(func_name.clone());
let func_body = params
.into_iter()
.rfold(func_body, |term, arg| term.lambda(arg))
.lambda(NO_INLINE);

if recursive_nonstatic_params == params {
// If we don't have any recursive-static params, we can just emit the function as is
term = term
.lambda(func_name.clone())
.apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone())))
.lambda(func_name)
.apply(func_body.lambda(NO_INLINE));
} else {
// If we have parameters that remain static in each recursive call,
// we can construct an *outer* function to take those in
// and simplify the recursive part to only accept the non-static arguments
let mut recursive_func_body =
Term::var(&func_name).apply(Term::var(&func_name));
for param in recursive_nonstatic_params.iter() {
recursive_func_body = recursive_func_body.apply(Term::var(param));
}
Some(term.lambda(func_name).apply(func_body))
}
air::FunctionVariants::Recursive {
params,
recursive_nonstatic_params,
} => {
let mut func_body = arg_stack.pop().unwrap();

if recursive_nonstatic_params.is_empty() {
recursive_func_body = recursive_func_body.force();
}
let term = arg_stack.pop().unwrap();

// Then construct an outer function with *all* parameters, not just the nonstatic ones.
let mut outer_func_body =
recursive_func_body.lambda(&func_name).apply(func_body);
let no_statics = recursive_nonstatic_params == params;

// Now, add *all* parameters, so that other call sites don't know the difference
for param in params.iter().rev() {
outer_func_body = outer_func_body.lambda(param);
if recursive_nonstatic_params.is_empty() || params.is_empty() {
func_body = func_body.delay();
}

// And finally, fold that definition into the rest of our program
term = term
.lambda(&func_name)
.apply(outer_func_body.lambda(NO_INLINE));
}
let func_body = recursive_nonstatic_params
.iter()
.rfold(func_body, |term, arg| term.lambda(arg));

Some(term)
}
}
Air::DefineCyclicFuncs {
func_name,
module_name,
variant_name,
contained_functions,
} => {
let func_name = if module_name.is_empty() {
format!("{func_name}{variant_name}")
} else {
format!("{module_name}_{func_name}{variant_name}")
};
let mut cyclic_functions = vec![];
let func_body = func_body.lambda(func_name.clone());

for params in contained_functions {
let func_body = arg_stack.pop().unwrap();
if no_statics {
// If we don't have any recursive-static params, we can just emit the function as is
Some(
term.lambda(func_name.clone())
.apply(
Term::var(func_name.clone())
.apply(Term::var(func_name.clone())),
)
.lambda(func_name)
.apply(func_body.lambda(NO_INLINE)),
)
} else {
// If we have parameters that remain static in each recursive call,
// we can construct an *outer* function to take those in
// and simplify the recursive part to only accept the non-static arguments
let mut recursive_func_body =
Term::var(&func_name).apply(Term::var(&func_name));

if recursive_nonstatic_params.is_empty() {
recursive_func_body = recursive_func_body.force();
}

cyclic_functions.push((params, func_body));
}
let mut term = arg_stack.pop().unwrap();
// Introduce a parameter for each parameter
// NOTE: we use recursive_nonstatic_params here because
// if this is recursive, those are the ones that need to be passed
// each time
for param in recursive_nonstatic_params.into_iter() {
recursive_func_body = recursive_func_body.apply(Term::var(param));
}

let mut cyclic_body = Term::var("__chooser");
// Then construct an outer function with *all* parameters, not just the nonstatic ones.
let mut outer_func_body =
recursive_func_body.lambda(&func_name).apply(func_body);

for (params, func_body) in cyclic_functions.into_iter() {
let mut function = func_body;
if params.is_empty() {
function = function.delay();
} else {
for param in params.iter().rev() {
function = function.lambda(param);
// Now, add *all* parameters, so that other call sites don't know the difference
outer_func_body = params
.clone()
.into_iter()
.rfold(outer_func_body, |term, arg| term.lambda(arg));

// And finally, fold that definition into the rest of our program
Some(
term.lambda(&func_name)
.apply(outer_func_body.lambda(NO_INLINE)),
)
}
}
air::FunctionVariants::Constant => todo!(),
air::FunctionVariants::Cyclic(contained_functions) => {
let mut cyclic_functions = vec![];

// We basically Scott encode our function bodies and use the chooser function
// to determine which function body and params is run
// For example say there is a cycle of 3 function bodies
// Our choose function can look like this:
// \func1 -> \func2 -> \func3 -> func1
// In this case our chooser is a function that takes in 3 functions
// and returns the first one to run
cyclic_body = cyclic_body.apply(function)
}
for params in contained_functions {
let func_body = arg_stack.pop().unwrap();

term = term
.lambda(&func_name)
.apply(Term::var(&func_name).apply(Term::var(&func_name)))
.lambda(&func_name)
.apply(
cyclic_body
.lambda("__chooser")
.lambda(func_name)
.lambda(NO_INLINE),
);
cyclic_functions.push((params, func_body));
}
let mut term = arg_stack.pop().unwrap();

Some(term)
let mut cyclic_body = Term::var("__chooser");

for (params, func_body) in cyclic_functions.into_iter() {
let mut function = func_body;
if params.is_empty() {
function = function.delay();
} else {
for param in params.iter().rev() {
function = function.lambda(param);
}
}

// We basically Scott encode our function bodies and use the chooser function
// to determine which function body and params is run
// For example say there is a cycle of 3 function bodies
// Our choose function can look like this:
// \func1 -> \func2 -> \func3 -> func1
// In this case our chooser is a function that takes in 3 functions
// and returns the first one to run
cyclic_body = cyclic_body.apply(function)
}

term = term
.lambda(&func_name)
.apply(Term::var(&func_name).apply(Term::var(&func_name)))
.lambda(&func_name)
.apply(
cyclic_body
.lambda("__chooser")
.lambda(func_name)
.lambda(NO_INLINE),
);

Some(term)
}
}
}

Air::Let { name } => {
let arg = arg_stack.pop().unwrap();

Expand Down
22 changes: 12 additions & 10 deletions crates/aiken-lang/src/gen_uplc/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ impl From<bool> for ExpectLevel {
}
}

#[derive(Debug, Clone, PartialEq)]
pub enum FunctionVariants {
Standard(Vec<String>),
Recursive {
params: Vec<String>,
recursive_nonstatic_params: Vec<String>,
},
Cyclic(Vec<Vec<String>>),
Constant,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Air {
// Primitives
Expand Down Expand Up @@ -65,19 +76,10 @@ pub enum Air {
tipo: Rc<Type>,
},
DefineFunc {
func_name: String,
module_name: String,
params: Vec<String>,
recursive: bool,
recursive_nonstatic_params: Vec<String>,
variant_name: String,
},
DefineCyclicFuncs {
func_name: String,
module_name: String,
variant_name: String,
// just the params
contained_functions: Vec<Vec<String>>,
variant: FunctionVariants,
},
Fn {
params: Vec<String>,
Expand Down
Loading
Loading