From 5850c74d01bda706ceca9fcc73b47d9b197eeb8e Mon Sep 17 00:00:00 2001 From: rvcas Date: Thu, 22 Aug 2024 17:03:47 -0400 Subject: [PATCH] feat: enforcement that spend first arg is option --- crates/aiken-lang/src/ast.rs | 4 ++++ crates/aiken-lang/src/tests/check.rs | 30 ++++++++++++++++++---------- crates/aiken-lang/src/tipo/infer.rs | 10 ++++++++++ 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/crates/aiken-lang/src/ast.rs b/crates/aiken-lang/src/ast.rs index 95310c6c7..df6087ad5 100644 --- a/crates/aiken-lang/src/ast.rs +++ b/crates/aiken-lang/src/ast.rs @@ -278,6 +278,10 @@ impl TypedFunction { }) } + pub fn is_spend(&self) -> bool { + self.name == HANDLER_SPEND + } + pub fn has_valid_purpose_name(&self) -> bool { self.name == HANDLER_SPEND || self.name == HANDLER_PUBLISH diff --git a/crates/aiken-lang/src/tests/check.rs b/crates/aiken-lang/src/tests/check.rs index 82c2b10b6..97b81ba75 100644 --- a/crates/aiken-lang/src/tests/check.rs +++ b/crates/aiken-lang/src/tests/check.rs @@ -358,7 +358,7 @@ fn expect_multi_patterns() { fn validator_correct_form() { let source_code = r#" validator foo { - spend(d, r, oref, c) { + spend(d: Option, r, oref, c) { True } } @@ -389,7 +389,7 @@ fn validator_in_lib_warning() { fn multi_validator() { let source_code = r#" validator foo(foo: ByteArray, bar: Int) { - spend(_d, _r, _oref, _c) { + spend(_d: Option, _r, _oref, _c) { foo == #"aabb" } @@ -408,7 +408,7 @@ fn multi_validator() { fn multi_validator_warning() { let source_code = r#" validator foo(foo: ByteArray, bar: Int) { - spend(_d, _r, _oref, _c) { + spend(_d: Option, _r, _oref, _c) { foo == #"aabb" } @@ -458,7 +458,7 @@ fn exhaustiveness_simple() { fn validator_args_no_annotation() { let source_code = r#" validator hello(d) { - spend(a, b, oref, c) { + spend(a: Option, b, oref, c) { True } } @@ -475,9 +475,13 @@ fn validator_args_no_annotation() { assert!(param.tipo.is_data()); }); - validator.handlers[0].arguments.iter().for_each(|arg| { - assert!(arg.tipo.is_data()); - }) + validator.handlers[0] + .arguments + .iter() + .skip(1) + .for_each(|arg| { + assert!(arg.tipo.is_data()); + }) }) } @@ -2451,8 +2455,10 @@ fn validator_private_type_leak() { } validator bar { - spend(datum: Datum, redeemer: Redeemer, _oref, _ctx) { - datum.foo == redeemer.bar + spend(datum: Option, redeemer: Redeemer, _oref, _ctx) { + expect Some(d) = datum + + d.foo == redeemer.bar } } "#; @@ -2475,8 +2481,10 @@ fn validator_public() { } validator bar { - spend(datum: Datum, redeemer: Redeemer, _oref, _ctx) { - datum.foo == redeemer.bar + spend(datum: Option, redeemer: Redeemer, _oref, _ctx) { + expect Some(d) = datum + + d.foo == redeemer.bar } } "#; diff --git a/crates/aiken-lang/src/tipo/infer.rs b/crates/aiken-lang/src/tipo/infer.rs index aa0b70c24..b0b316163 100644 --- a/crates/aiken-lang/src/tipo/infer.rs +++ b/crates/aiken-lang/src/tipo/infer.rs @@ -225,6 +225,16 @@ fn infer_definition( }); } + if typed_fun.is_spend() && !typed_fun.arguments[0].tipo.is_option() { + return Err(Error::CouldNotUnify { + location: typed_fun.arguments[0].location, + expected: Type::option(typed_fun.arguments[0].tipo.clone()), + given: typed_fun.arguments[0].tipo.clone(), + situation: None, + rigid_type_names: Default::default(), + }); + } + for arg in typed_fun.arguments.iter_mut() { if arg.tipo.is_unbound() { arg.tipo = Type::data();