diff --git a/Cargo.lock b/Cargo.lock index 3957fdc250..1beba1dda2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1460,6 +1460,7 @@ dependencies = [ "ctor", "indoc", "itertools 0.12.1", + "num-bigint", "num-traits 0.2.19", "regex", "serde", @@ -4662,9 +4663,13 @@ dependencies = [ "base16ct", "blockifier", "cairo-lang-casm", + "cairo-lang-diagnostics", + "cairo-lang-filesystem", + "cairo-lang-parser", "cairo-lang-runner", "cairo-lang-sierra", "cairo-lang-sierra-to-casm", + "cairo-lang-syntax", "cairo-lang-utils", "cairo-vm", "camino", @@ -4677,6 +4682,7 @@ dependencies = [ "fs_extra", "indoc", "itertools 0.12.1", + "num-bigint", "num-traits 0.2.19", "primitive-types", "project-root", diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs index f05bdbc8cd..81e4764ab5 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs @@ -37,8 +37,9 @@ use cairo_vm::vm::{ use cairo_vm::Felt252; use conversions::byte_array::ByteArray; use conversions::felt252::TryInferFormat; -use conversions::serde::deserialize::{BufferReader, CairoDeserialize}; +use conversions::serde::deserialize::BufferReader; use conversions::serde::serialize::CairoSerialize; +use conversions::u256::CairoU256; use runtime::{ CheatcodeHandlingResult, EnhancedHintError, ExtendedRuntime, ExtensionLogic, SyscallHandlingResult, @@ -491,30 +492,6 @@ impl<'a> ExtensionLogic for ForgeExtension<'a> { } } -#[derive(CairoDeserialize, CairoSerialize)] -struct CairoU256 { - low: u128, - high: u128, -} - -impl CairoU256 { - fn from_bytes(bytes: &[u8]) -> Self { - Self { - low: u128::from_be_bytes(bytes[16..32].try_into().unwrap()), - high: u128::from_be_bytes(bytes[0..16].try_into().unwrap()), - } - } - - fn to_be_bytes(&self) -> [u8; 32] { - let mut result = [0; 32]; - - result[16..].copy_from_slice(&self.low.to_be_bytes()); - result[..16].copy_from_slice(&self.high.to_be_bytes()); - - result - } -} - #[derive(CairoSerialize)] enum SignError { InvalidSecretKey, diff --git a/crates/conversions/Cargo.toml b/crates/conversions/Cargo.toml index ebd5845265..07b68f326d 100644 --- a/crates/conversions/Cargo.toml +++ b/crates/conversions/Cargo.toml @@ -24,6 +24,7 @@ serde.workspace = true num-traits.workspace = true itertools.workspace = true cairo-serde-macros = { path = "cairo-serde-macros" } +num-bigint.workspace = true [dev-dependencies] ctor.workspace = true diff --git a/crates/conversions/src/lib.rs b/crates/conversions/src/lib.rs index 5209b14002..7cef947986 100644 --- a/crates/conversions/src/lib.rs +++ b/crates/conversions/src/lib.rs @@ -10,6 +10,8 @@ pub mod nonce; pub mod primitive; pub mod serde; pub mod string; +pub mod u256; +pub mod u512; pub trait FromConv: Sized { fn from_(value: T) -> Self; diff --git a/crates/conversions/src/serde/serialize/serialize_impl.rs b/crates/conversions/src/serde/serialize/serialize_impl.rs index b944492e4e..7685e5e3b4 100644 --- a/crates/conversions/src/serde/serialize/serialize_impl.rs +++ b/crates/conversions/src/serde/serialize/serialize_impl.rs @@ -229,6 +229,12 @@ impl_serialize_for_num_type!(u64); impl_serialize_for_num_type!(u128); impl_serialize_for_num_type!(usize); +impl_serialize_for_num_type!(i8); +impl_serialize_for_num_type!(i16); +impl_serialize_for_num_type!(i32); +impl_serialize_for_num_type!(i64); +impl_serialize_for_num_type!(i128); + impl_serialize_for_tuple!(); impl_serialize_for_tuple!(A); impl_serialize_for_tuple!(A, B); diff --git a/crates/conversions/src/u256.rs b/crates/conversions/src/u256.rs new file mode 100644 index 0000000000..11254bd2dd --- /dev/null +++ b/crates/conversions/src/u256.rs @@ -0,0 +1,56 @@ +use crate as conversions; // Must be imported because of derive macros +use cairo_serde_macros::{CairoDeserialize, CairoSerialize}; +use num_bigint::{BigUint, ParseBigIntError}; +use std::str::FromStr; + +#[derive(CairoDeserialize, CairoSerialize, Debug)] +pub struct CairoU256 { + low: u128, + high: u128, +} + +impl CairoU256 { + #[must_use] + pub fn from_bytes(bytes: &[u8]) -> Self { + Self { + low: u128::from_be_bytes(bytes[16..32].try_into().unwrap()), + high: u128::from_be_bytes(bytes[0..16].try_into().unwrap()), + } + } + + #[must_use] + pub fn to_be_bytes(&self) -> [u8; 32] { + let mut result = [0; 32]; + + result[16..].copy_from_slice(&self.low.to_be_bytes()); + result[..16].copy_from_slice(&self.high.to_be_bytes()); + + result + } +} + +#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)] +pub enum ParseCairoU256Error { + #[error(transparent)] + InvalidString(#[from] ParseBigIntError), + #[error("Number is too large to fit in 32 bytes")] + Overflow, +} + +impl FromStr for CairoU256 { + type Err = ParseCairoU256Error; + + fn from_str(input: &str) -> Result { + let bytes = input.parse::()?.to_bytes_be(); + + if bytes.len() > 32 { + return Err(ParseCairoU256Error::Overflow); + } + + let mut result = [0u8; 32]; + let start = 32 - bytes.len(); + result[start..].copy_from_slice(&bytes); + + Ok(CairoU256::from_bytes(&result)) + } +} diff --git a/crates/conversions/src/u512.rs b/crates/conversions/src/u512.rs new file mode 100644 index 0000000000..baf97a332e --- /dev/null +++ b/crates/conversions/src/u512.rs @@ -0,0 +1,62 @@ +use crate as conversions; // Must be imported because of derive macros +use cairo_serde_macros::{CairoDeserialize, CairoSerialize}; +use num_bigint::{BigUint, ParseBigIntError}; +use std::str::FromStr; + +#[derive(CairoDeserialize, CairoSerialize, Debug)] +pub struct CairoU512 { + limb_0: u128, + limb_1: u128, + limb_2: u128, + limb_3: u128, +} + +impl CairoU512 { + #[must_use] + pub fn from_bytes(bytes: &[u8]) -> Self { + Self { + limb_0: u128::from_be_bytes(bytes[48..64].try_into().unwrap()), + limb_1: u128::from_be_bytes(bytes[32..48].try_into().unwrap()), + limb_2: u128::from_be_bytes(bytes[16..32].try_into().unwrap()), + limb_3: u128::from_be_bytes(bytes[00..16].try_into().unwrap()), + } + } + + #[must_use] + pub fn to_be_bytes(&self) -> [u8; 64] { + let mut result = [0; 64]; + + result[48..64].copy_from_slice(&self.limb_0.to_be_bytes()); + result[32..48].copy_from_slice(&self.limb_1.to_be_bytes()); + result[16..32].copy_from_slice(&self.limb_2.to_be_bytes()); + result[00..16].copy_from_slice(&self.limb_3.to_be_bytes()); + + result + } +} + +#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)] +pub enum ParseCairoU512Error { + #[error(transparent)] + InvalidString(#[from] ParseBigIntError), + #[error("Number is too large to fit in 64 bytes")] + Overflow, +} + +impl FromStr for CairoU512 { + type Err = ParseCairoU512Error; + + fn from_str(input: &str) -> Result { + let bytes = input.parse::()?.to_bytes_be(); + + if bytes.len() > 64 { + return Err(ParseCairoU512Error::Overflow); + } + + let mut result = [0u8; 64]; + let start = 64 - bytes.len(); + result[start..].copy_from_slice(&bytes); + + Ok(CairoU512::from_bytes(&result)) + } +} diff --git a/crates/sncast/Cargo.toml b/crates/sncast/Cargo.toml index 550c8cf826..60067a007e 100644 --- a/crates/sncast/Cargo.toml +++ b/crates/sncast/Cargo.toml @@ -35,8 +35,13 @@ cairo-lang-casm.workspace = true cairo-lang-sierra-to-casm.workspace = true cairo-lang-utils.workspace = true cairo-lang-sierra.workspace = true +cairo-lang-parser.workspace = true +cairo-lang-syntax.workspace = true +cairo-lang-diagnostics.workspace = true +cairo-lang-filesystem.workspace = true itertools.workspace = true num-traits.workspace = true +num-bigint.workspace = true starknet-types-core.workspace = true cairo-vm.workspace = true blockifier.workspace = true diff --git a/crates/sncast/src/helpers/data_transformer/calldata_representation.rs b/crates/sncast/src/helpers/data_transformer/calldata_representation.rs new file mode 100644 index 0000000000..7d7b5805fd --- /dev/null +++ b/crates/sncast/src/helpers/data_transformer/calldata_representation.rs @@ -0,0 +1,218 @@ +use anyhow::{bail, Context}; +use conversions::{ + byte_array::ByteArray, + serde::serialize::{BufferWriter, CairoSerialize}, + u256::CairoU256, + u512::CairoU512, +}; +use starknet::core::types::Felt; +use std::{any::type_name, str::FromStr}; + +#[derive(Debug)] +pub(super) struct CalldataStructField(AllowedCalldataArguments); + +impl CalldataStructField { + pub fn new(value: AllowedCalldataArguments) -> Self { + Self(value) + } +} + +#[derive(Debug)] +pub(super) struct CalldataStruct(Vec); + +impl CalldataStruct { + pub fn new(arguments: Vec) -> Self { + Self(arguments) + } +} + +#[derive(Debug)] +pub(super) struct CalldataArrayMacro(Vec); + +impl CalldataArrayMacro { + pub fn new(arguments: Vec) -> Self { + Self(arguments) + } +} + +#[derive(Debug)] +pub(super) struct CalldataEnum { + position: usize, + argument: Option>, +} + +impl CalldataEnum { + pub fn new(position: usize, argument: Option>) -> Self { + Self { position, argument } + } +} + +#[derive(Debug)] +pub(super) enum CalldataSingleArgument { + Bool(bool), + U8(u8), + U16(u16), + U32(u32), + U64(u64), + U128(u128), + U256(CairoU256), + U512(CairoU512), + I8(i8), + I16(i16), + I32(i32), + I64(i64), + I128(i128), + Felt(Felt), + ByteArray(ByteArray), +} + +fn neat_parsing_error_message(value: &str, parsing_type: &str, reason: Option<&str>) -> String { + if let Some(message) = reason { + format!(r#"Failed to parse value "{value}" into type "{parsing_type}": {message}"#) + } else { + format!(r#"Failed to parse value "{value}" into type "{parsing_type}""#) + } +} + +fn parse_with_type(value: &str) -> anyhow::Result +where + ::Err: std::error::Error + Send + Sync + 'static, +{ + value + .parse::() + .context(neat_parsing_error_message(value, type_name::(), None)) +} + +impl CalldataSingleArgument { + pub(super) fn try_new(type_with_path: &str, value: &str) -> anyhow::Result { + // TODO add all corelib types + let type_str = type_with_path + .split("::") + .last() + .context("Couldn't parse parameter type from ABI")?; + + match type_str { + "u8" => Ok(Self::U8(parse_with_type(value)?)), + "u16" => Ok(Self::U16(parse_with_type(value)?)), + "u32" => Ok(Self::U32(parse_with_type(value)?)), + "u64" => Ok(Self::U64(parse_with_type(value)?)), + "u128" => Ok(Self::U128(parse_with_type(value)?)), + "u256" => Ok(Self::U256(parse_with_type(value)?)), + "u512" => Ok(Self::U512(parse_with_type(value)?)), + "i8" => Ok(Self::I8(parse_with_type(value)?)), + "i16" => Ok(Self::I16(parse_with_type(value)?)), + "i32" => Ok(Self::I32(parse_with_type(value)?)), + "i64" => Ok(Self::I64(parse_with_type(value)?)), + "i128" => Ok(Self::I128(parse_with_type(value)?)), + // TODO check if bytes31 is actually a felt + // (e.g. alexandria_data_structures::bit_array::BitArray uses that) + // https://github.com/starkware-libs/cairo/blob/bf48e658b9946c2d5446eeb0c4f84868e0b193b5/corelib/src/bytes_31.cairo#L14 + // There is `bytes31_try_from_felt252`, which means it isn't always a valid felt? + "felt252" | "felt" | "ContractAddress" | "ClassHash" | "bytes31" => { + let felt = Felt::from_dec_str(value) + .with_context(|| neat_parsing_error_message(value, type_with_path, None))?; + Ok(Self::Felt(felt)) + } + "bool" => Ok(Self::Bool(parse_with_type(value)?)), + "ByteArray" => Ok(Self::ByteArray(ByteArray::from(value))), + _ => { + bail!(neat_parsing_error_message( + value, + type_with_path, + Some(&format!("unsupported type {type_with_path}")) + )) + } + } + } +} + +#[derive(Debug)] +pub(super) struct CalldataTuple(Vec); + +impl CalldataTuple { + pub fn new(arguments: Vec) -> Self { + Self(arguments) + } +} + +#[derive(Debug)] +pub(super) enum AllowedCalldataArguments { + Struct(CalldataStruct), + ArrayMacro(CalldataArrayMacro), + Enum(CalldataEnum), + // TODO rename to BasicType or smth + SingleArgument(CalldataSingleArgument), + Tuple(CalldataTuple), +} + +impl CairoSerialize for CalldataSingleArgument { + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/serialization-of-cairo-types/ + fn serialize(&self, output: &mut BufferWriter) { + match self { + CalldataSingleArgument::Bool(value) => value.serialize(output), + CalldataSingleArgument::U8(value) => value.serialize(output), + CalldataSingleArgument::U16(value) => value.serialize(output), + CalldataSingleArgument::U32(value) => value.serialize(output), + CalldataSingleArgument::U64(value) => value.serialize(output), + CalldataSingleArgument::U128(value) => value.serialize(output), + CalldataSingleArgument::U256(value) => value.serialize(output), + CalldataSingleArgument::U512(value) => value.serialize(output), + CalldataSingleArgument::I8(value) => value.serialize(output), + CalldataSingleArgument::I16(value) => value.serialize(output), + CalldataSingleArgument::I32(value) => value.serialize(output), + CalldataSingleArgument::I64(value) => value.serialize(output), + CalldataSingleArgument::I128(value) => value.serialize(output), + CalldataSingleArgument::Felt(value) => value.serialize(output), + CalldataSingleArgument::ByteArray(value) => value.serialize(output), + }; + } +} + +impl CairoSerialize for CalldataStructField { + // Every argument serialized in order of occurrence + fn serialize(&self, output: &mut BufferWriter) { + self.0.serialize(output); + } +} + +impl CairoSerialize for CalldataStruct { + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/serialization-of-cairo-types/#serialization_of_structs + fn serialize(&self, output: &mut BufferWriter) { + self.0.iter().for_each(|field| field.serialize(output)); + } +} + +impl CairoSerialize for CalldataTuple { + fn serialize(&self, output: &mut BufferWriter) { + self.0.iter().for_each(|field| field.serialize(output)); + } +} + +impl CairoSerialize for CalldataArrayMacro { + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/serialization-of-cairo-types/#serialization_of_arrays + fn serialize(&self, output: &mut BufferWriter) { + self.0.len().serialize(output); + self.0.iter().for_each(|field| field.serialize(output)); + } +} + +impl CairoSerialize for CalldataEnum { + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/serialization-of-cairo-types/#serialization_of_enums + fn serialize(&self, output: &mut BufferWriter) { + self.position.serialize(output); + if self.argument.is_some() { + self.argument.as_ref().unwrap().serialize(output); + } + } +} +impl CairoSerialize for AllowedCalldataArguments { + fn serialize(&self, output: &mut BufferWriter) { + match self { + AllowedCalldataArguments::Struct(value) => value.serialize(output), + AllowedCalldataArguments::ArrayMacro(value) => value.serialize(output), + AllowedCalldataArguments::Enum(value) => value.serialize(output), + AllowedCalldataArguments::SingleArgument(value) => value.serialize(output), + AllowedCalldataArguments::Tuple(value) => value.serialize(output), + } + } +} diff --git a/crates/sncast/src/helpers/data_transformer/mod.rs b/crates/sncast/src/helpers/data_transformer/mod.rs new file mode 100644 index 0000000000..00a1c860eb --- /dev/null +++ b/crates/sncast/src/helpers/data_transformer/mod.rs @@ -0,0 +1,3 @@ +pub mod calldata_representation; +pub mod sierra_abi; +pub mod transformer; diff --git a/crates/sncast/src/helpers/data_transformer/sierra_abi.rs b/crates/sncast/src/helpers/data_transformer/sierra_abi.rs new file mode 100644 index 0000000000..75e454daca --- /dev/null +++ b/crates/sncast/src/helpers/data_transformer/sierra_abi.rs @@ -0,0 +1,625 @@ +use crate::helpers::data_transformer::calldata_representation::{ + AllowedCalldataArguments, CalldataArrayMacro, CalldataEnum, CalldataSingleArgument, + CalldataStruct, CalldataStructField, CalldataTuple, +}; +use anyhow::{bail, ensure, Context, Result}; +use cairo_lang_parser::utils::SimpleParserDatabase; +use cairo_lang_syntax::node::ast::PathSegment::Simple; +use cairo_lang_syntax::node::ast::{ + ArgClause, ArgList, Expr, ExprFunctionCall, ExprInlineMacro, ExprListParenthesized, ExprPath, + ExprStructCtorCall, ExprUnary, OptionStructArgExpr, PathSegment, StructArg, TerminalFalse, + TerminalLiteralNumber, TerminalShortString, TerminalString, TerminalTrue, UnaryOperator, + WrappedArgList, +}; +use cairo_lang_syntax::node::{Terminal, Token}; +use itertools::Itertools; +use regex::Regex; +use starknet::core::types::contract::{AbiEntry, AbiEnum, AbiNamedMember, AbiStruct}; +use std::collections::HashSet; +use std::ops::Neg; + +pub(super) fn build_representation( + expression: Expr, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, +) -> Result { + match expression { + Expr::StructCtorCall(item) => item.transform(expected_type, abi, db), + Expr::Literal(item) => item.transform(expected_type, abi, db), + Expr::Unary(item) => item.transform(expected_type, abi, db), + Expr::ShortString(item) => item.transform(expected_type, abi, db), + Expr::String(item) => item.transform(expected_type, abi, db), + Expr::False(item) => item.transform(expected_type, abi, db), + Expr::True(item) => item.transform(expected_type, abi, db), + Expr::Path(item) => item.transform(expected_type, abi, db), + Expr::FunctionCall(item) => item.transform(expected_type, abi, db), + Expr::InlineMacro(item) => item.transform(expected_type, abi, db), + Expr::Tuple(item) => item.transform(expected_type, abi, db), + _ => { + bail!(r#"Invalid argument type: unsupported expression for type "{expected_type}""#) + } + } +} + +trait SupportedCalldataKind { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result; +} + +impl SupportedCalldataKind for ExprStructCtorCall { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let struct_path: Vec = split(&self.path(db), db)?; + let struct_path_joined = struct_path.clone().join("::"); + + validate_path_argument(expected_type, &struct_path, &struct_path_joined)?; + + let structs_from_abi = find_all_structs(abi); + let struct_abi_definition = find_valid_enum_or_struct(structs_from_abi, &struct_path)?; + + let struct_args = self.arguments(db).arguments(db).elements(db); + + let struct_args_with_values = get_struct_arguments_with_values(&struct_args, db) + .context("Found invalid expression in struct argument")?; + + if struct_args_with_values.len() != struct_abi_definition.members.len() { + bail!( + r#"Invalid number of struct arguments in struct "{}", expected {} arguments, found {}"#, + struct_path_joined, + struct_abi_definition.members.len(), + struct_args.len() + ) + } + + // validate if all arguments' names have corresponding names in abi + if struct_args_with_values + .iter() + .map(|(arg_name, _)| arg_name.clone()) + .collect::>() + != struct_abi_definition + .members + .iter() + .map(|x| x.name.clone()) + .collect::>() + { + // TODO add message which arguments are invalid + bail!( + r#"Arguments in constructor invocation for struct {} do not match struct arguments in ABI"#, + expected_type + ) + } + + let fields = struct_args_with_values + .into_iter() + .map(|(arg_name, expr)| { + let abi_entry = struct_abi_definition + .members + .iter() + .find(|&abi_member| abi_member.name == arg_name) + .expect("Arg name should be in ABI - it is checked before with HashSets"); + Ok(CalldataStructField::new(build_representation( + expr, + &abi_entry.r#type, + abi, + db, + )?)) + }) + .collect::>>()?; + + Ok(AllowedCalldataArguments::Struct(CalldataStruct::new( + fields, + ))) + } +} + +impl SupportedCalldataKind for TerminalLiteralNumber { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let (value, suffix) = self + .numeric_value_and_suffix(db) + .with_context(|| format!("Couldn't parse value: {}", self.text(db)))?; + + let proper_param_type = match suffix { + None => expected_type, + Some(ref suffix) => suffix.as_str(), + }; + + Ok(AllowedCalldataArguments::SingleArgument( + CalldataSingleArgument::try_new(proper_param_type, &value.to_string())?, + )) + } +} + +impl SupportedCalldataKind for ExprUnary { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let (value, suffix) = match self.expr(db) { + Expr::Literal(literal_number) => literal_number + .numeric_value_and_suffix(db) + .with_context(|| format!("Couldn't parse value: {}", literal_number.text(db))), + _ => bail!("Invalid expression with unary operator, only numbers allowed"), + }?; + + let proper_param_type = match suffix { + None => expected_type, + Some(ref suffix) => suffix.as_str(), + }; + + match self.op(db) { + UnaryOperator::Not(_) => bail!( + "Invalid unary operator in expression !{} , only - allowed, got !", + value + ), + UnaryOperator::Desnap(_) => bail!( + "Invalid unary operator in expression *{} , only - allowed, got *", + value + ), + UnaryOperator::BitNot(_) => bail!( + "Invalid unary operator in expression ~{} , only - allowed, got ~", + value + ), + UnaryOperator::At(_) => bail!( + "Invalid unary operator in expression @{} , only - allowed, got @", + value + ), + UnaryOperator::Minus(_) => {} + } + + Ok(AllowedCalldataArguments::SingleArgument( + CalldataSingleArgument::try_new(proper_param_type, &value.neg().to_string())?, + )) + } +} + +impl SupportedCalldataKind for TerminalShortString { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let value = self + .string_value(db) + .context("Invalid shortstring passed as an argument")?; + + Ok(AllowedCalldataArguments::SingleArgument( + CalldataSingleArgument::try_new(expected_type, &value)?, + )) + } +} + +impl SupportedCalldataKind for TerminalString { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let value = self + .string_value(db) + .context("Invalid string passed as an argument")?; + + Ok(AllowedCalldataArguments::SingleArgument( + CalldataSingleArgument::try_new(expected_type, &value)?, + )) + } +} + +impl SupportedCalldataKind for TerminalFalse { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // Could use terminal_false.boolean_value(db) and simplify try_new() + let value = self.text(db).to_string(); + + Ok(AllowedCalldataArguments::SingleArgument( + CalldataSingleArgument::try_new(expected_type, &value)?, + )) + } +} + +impl SupportedCalldataKind for TerminalTrue { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let value = self.text(db).to_string(); + + Ok(AllowedCalldataArguments::SingleArgument( + CalldataSingleArgument::try_new(expected_type, &value)?, + )) + } +} + +impl SupportedCalldataKind for ExprPath { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // Enums with no value - Enum::Variant + let enum_path_with_variant = split(self, db)?; + let (enum_variant_name, enum_path) = enum_path_with_variant.split_last().unwrap(); + let enum_path_joined = enum_path.join("::"); + + validate_path_argument(expected_type, enum_path, &enum_path_joined)?; + + let (enum_position, enum_variant) = + find_enum_variant_position(enum_variant_name, enum_path, abi)?; + + if enum_variant.r#type != "()" { + bail!( + r#"Couldn't find variant "{}" in enum "{}""#, + enum_variant_name, + enum_path_joined + ) + } + + Ok(AllowedCalldataArguments::Enum(CalldataEnum::new( + enum_position, + None, + ))) + } +} + +impl SupportedCalldataKind for ExprFunctionCall { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // Enums with value - Enum::Variant(10) + let enum_path_with_variant = split(&self.path(db), db)?; + let (enum_variant_name, enum_path) = enum_path_with_variant.split_last().unwrap(); + let enum_path_joined = enum_path.join("::"); + + validate_path_argument(expected_type, enum_path, &enum_path_joined)?; + + let (enum_position, enum_variant) = + find_enum_variant_position(enum_variant_name, enum_path, abi)?; + + // When creating an enum with variant, there can be only one argument. Parsing the + // argument inside ArgList (enum_expr_path_with_value.arguments(db).arguments(db)), + // then popping from the vector and unwrapping safely. + let expr = parse_argument_list(&self.arguments(db).arguments(db), db)? + .pop() + .unwrap(); + let parsed_expr = build_representation(expr, &enum_variant.r#type, abi, db)?; + + Ok(AllowedCalldataArguments::Enum(CalldataEnum::new( + enum_position, + Some(Box::new(parsed_expr)), + ))) + } +} + +impl SupportedCalldataKind for ExprInlineMacro { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // array![] calls + let parsed_exprs = parse_inline_macro(self, db)?; + + let array_element_type_pattern = Regex::new("core::array::Array::<(.*)>").unwrap(); + let abi_argument_type = array_element_type_pattern + .captures(expected_type) + .with_context(|| { + format!(r#"Invalid argument type, expected "{expected_type}", got array"#,) + })? + .get(1) + // TODO better message + .with_context(|| { + format!( + "Couldn't parse array element type from the ABI array parameter: {expected_type}" + ) + })? + .as_str(); + + let arguments = parsed_exprs + .into_iter() + .map(|arg| build_representation(arg, abi_argument_type, abi, db)) + .collect::>>()?; + + Ok(AllowedCalldataArguments::ArrayMacro( + CalldataArrayMacro::new(arguments), + )) + } +} + +impl SupportedCalldataKind for ExprListParenthesized { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // Regex capturing types between the parentheses, e.g.: for "(core::felt252, core::u8)" + // will capture "core::felt252, core::u8" + let tuple_types_pattern = Regex::new(r"\(([^)]+)\)").unwrap(); + let tuple_types: Vec<&str> = tuple_types_pattern + .captures(expected_type) + .with_context(|| { + format!(r#"Invalid argument type, expected "{expected_type}", got tuple"#,) + })? + .get(1) + .map(|x| x.as_str().split(", ").collect()) + .unwrap(); + + let parsed_exprs = self + .expressions(db) + .elements(db) + .into_iter() + .zip(tuple_types) + .map(|(expr, single_param)| build_representation(expr, single_param, abi, db)) + .collect::>>()?; + + Ok(AllowedCalldataArguments::Tuple(CalldataTuple::new( + parsed_exprs, + ))) + } +} + +fn split(path: &ExprPath, db: &SimpleParserDatabase) -> Result> { + path.elements(db) + .iter() + .map(|p| match p { + Simple(segment) => Ok(segment.ident(db).token(db).text(db).to_string()), + PathSegment::WithGenericArgs(_) => { + bail!("Cannot use generic args when specifying struct/enum path") + } + }) + .collect::>>() +} + +fn get_struct_arguments_with_values( + arguments: &[StructArg], + db: &SimpleParserDatabase, +) -> Result> { + arguments + .iter() + .map(|elem| { + match elem { + // Holds info about parameter and argument in struct creation, e.g.: + // in case of "Struct { a: 1, b: 2 }", two separate StructArgSingle hold info + // about "a: 1" and "b: 2" respectively. + StructArg::StructArgSingle(whole_arg) => { + match whole_arg.arg_expr(db) { + // TODO add comment + // dunno what that is + // probably case when there is Struct {a, b} and there are variables a and b + OptionStructArgExpr::Empty(_) => { + bail!( + "Single arg, used {ident}, should be {ident}: value", + ident = whole_arg.identifier(db).text(db) + ) + } + // Holds info about the argument, e.g.: in case of "a: 1" holds info + // about ": 1" + OptionStructArgExpr::StructArgExpr(arg_value_with_colon) => Ok(( + whole_arg.identifier(db).text(db).to_string(), + arg_value_with_colon.expr(db), + )), + } + } + StructArg::StructArgTail(_) => { + bail!("Struct unpack-init with \"..\" operator is not allowed") + } + } + }) + .collect() +} + +fn find_enum_variant_position<'a>( + variant: &String, + path: &[String], + abi: &'a [AbiEntry], +) -> Result<(usize, &'a AbiNamedMember)> { + let enums_from_abi = abi + .iter() + .filter_map(|abi_entry| { + if let AbiEntry::Enum(abi_enum) = abi_entry { + Some(abi_enum) + } else { + None + } + }) + .collect::>(); + + let enum_abi_definition = find_valid_enum_or_struct(enums_from_abi, path)?; + + let position_and_enum_variant = enum_abi_definition + .variants + .iter() + .find_position(|item| item.name == *variant) + .with_context(|| { + format!( + r#"Couldn't find variant "{}" in enum "{}""#, + variant, + path.join("::") + ) + })?; + + Ok(position_and_enum_variant) +} + +fn parse_argument_list(arguments: &ArgList, db: &SimpleParserDatabase) -> Result> { + let arguments = arguments.elements(db); + if arguments + .iter() + .map(|arg| arg.modifiers(db).elements(db)) + .any(|mod_list| !mod_list.is_empty()) + { + bail!("\"ref\" and \"mut\" modifiers are not allowed") + } + + arguments + .iter() + .map(|arg| match arg.arg_clause(db) { + ArgClause::Unnamed(expr) => Ok(expr.value(db)), + ArgClause::Named(_) => { + bail!("Named arguments are not allowed") + } + ArgClause::FieldInitShorthand(_) => { + bail!("Field init shorthands are not allowed") + } + }) + .collect::>>() +} + +fn parse_inline_macro( + invocation: &ExprInlineMacro, + db: &SimpleParserDatabase, +) -> Result> { + match invocation + .path(db) + .elements(db) + .iter() + .last() + .expect("Macro must have a name") + { + Simple(simple) => { + let macro_name = simple.ident(db).text(db); + if macro_name != "array" { + bail!( + r#"Invalid macro name, expected "array![]", got "{}""#, + macro_name + ) + } + } + PathSegment::WithGenericArgs(_) => { + bail!("Invalid path specified: generic args in array![] macro not supported") + } + }; + + let macro_arg_list = match invocation.arguments(db) { + WrappedArgList::BracketedArgList(args) => { + // TODO arglist parsing here + args.arguments(db) + } + WrappedArgList::ParenthesizedArgList(_) | WrappedArgList::BracedArgList(_) => + bail!("`array` macro supports only square brackets: array![]"), + WrappedArgList::Missing(_) => unreachable!("If any type of parentheses is missing, then diagnostics have been reported and whole flow should have already been terminated.") + }; + parse_argument_list(¯o_arg_list, db) +} + +fn find_all_structs(abi: &[AbiEntry]) -> Vec<&AbiStruct> { + abi.iter() + .filter_map(|entry| match entry { + AbiEntry::Struct(r#struct) => Some(r#struct), + _ => None, + }) + .collect() +} + +fn validate_path_argument( + param_type: &str, + path_argument: &[String], + path_argument_joined: &String, +) -> Result<()> { + if *path_argument.last().unwrap() != param_type.split("::").last().unwrap() + && path_argument_joined != param_type + { + bail!( + r#"Invalid argument type, expected "{}", got "{}""#, + param_type, + path_argument_joined + ) + } + Ok(()) +} + +trait EnumOrStruct { + const VARIANT: &'static str; + const VARIANT_CAPITALIZED: &'static str; + fn name(&self) -> String; +} + +impl EnumOrStruct for AbiStruct { + const VARIANT: &'static str = "struct"; + const VARIANT_CAPITALIZED: &'static str = "Struct"; + + fn name(&self) -> String { + self.name.clone() + } +} + +impl EnumOrStruct for AbiEnum { + const VARIANT: &'static str = "enum"; + const VARIANT_CAPITALIZED: &'static str = "Enum"; + + fn name(&self) -> String { + self.name.clone() + } +} + +// 'item' here means enum or struct +fn find_valid_enum_or_struct<'item, T: EnumOrStruct>( + items_from_abi: Vec<&'item T>, + path: &[String], +) -> Result<&'item T> { + // Argument is a module path to an item (module_name::StructName {}) + if path.len() > 1 { + let full_path_item = items_from_abi + .into_iter() + .find(|x| x.name() == path.join("::")); + + ensure!( + full_path_item.is_some(), + r#"{} "{}" not found in ABI"#, + T::VARIANT_CAPITALIZED, + path.join("::") + ); + + return Ok(full_path_item.unwrap()); + } + + // Argument is just the name of the item (Struct {}) + let mut matching_items_from_abi: Vec<&T> = items_from_abi + .into_iter() + .filter(|x| x.name().split("::").last() == path.last().map(String::as_str)) + .collect(); + + ensure!( + !matching_items_from_abi.is_empty(), + r#"{} "{}" not found in ABI"#, + T::VARIANT_CAPITALIZED, + path.join("::") + ); + + ensure!( + matching_items_from_abi.len() == 1, + r#"Found more than one {} "{}" in ABI, please specify a full path to the item"#, + T::VARIANT, + path.join("::") + ); + + Ok(matching_items_from_abi.pop().unwrap()) +} diff --git a/crates/sncast/src/helpers/data_transformer/transformer.rs b/crates/sncast/src/helpers/data_transformer/transformer.rs new file mode 100644 index 0000000000..1aa657e4ee --- /dev/null +++ b/crates/sncast/src/helpers/data_transformer/transformer.rs @@ -0,0 +1,161 @@ +use crate::helpers::data_transformer::sierra_abi::build_representation; +use anyhow::{bail, ensure, Context, Result}; +use cairo_lang_diagnostics::DiagnosticsBuilder; +use cairo_lang_filesystem::ids::{FileKind, FileLongId, VirtualFile}; +use cairo_lang_parser::parser::Parser; +use cairo_lang_parser::utils::SimpleParserDatabase; +use cairo_lang_syntax::node::ast::Expr; +use cairo_lang_utils::Intern; +use conversions::serde::serialize::SerializeToFeltVec; +use itertools::Itertools; +use starknet::core::types::contract::{AbiEntry, AbiFunction, StateMutability}; +use starknet::core::types::{ContractClass, Felt}; +use starknet::core::utils::get_selector_from_name; +use std::collections::HashMap; + +pub fn transform( + calldata: &[String], + class_definition: ContractClass, + function_selector: &Felt, +) -> Result> { + let sierra_class = match class_definition { + ContractClass::Sierra(class) => class, + ContractClass::Legacy(_) => { + bail!("Transformation of Cairo-like expressions is not available for Cairo0 contracts") + } + }; + + let abi: Vec = serde_json::from_str(sierra_class.abi.as_str()) + .context("Couldn't deserialize ABI received from chain")?; + + let selector_function_map = map_selectors_to_functions(&abi); + + let function = selector_function_map + .get(function_selector) + .with_context(|| { + format!( + r#"Function with selector "{function_selector}" not found in ABI of the contract"# + ) + })?; + + let db = SimpleParserDatabase::default(); + + let result_for_cairo_expression_input = + process_as_cairo_expressions(calldata, function, &abi, &db) + .context("Error while processing Cairo-like calldata"); + + if result_for_cairo_expression_input.is_ok() { + return result_for_cairo_expression_input; + } + + let result_for_serialized_input = process_as_serialized(calldata, &abi, &db) + .context("Error while processing serialized calldata"); + + match result_for_serialized_input { + Err(_) => result_for_cairo_expression_input, + ok => ok, + } +} + +fn process_as_cairo_expressions( + calldata: &[String], + function: &AbiFunction, + abi: &[AbiEntry], + db: &SimpleParserDatabase, +) -> Result> { + let n_inputs = function.inputs.len(); + let n_arguments = calldata.len(); + + ensure!( + n_inputs == n_arguments, + "Invalid number of arguments: passed {}, expected {}", + n_arguments, + n_inputs, + ); + + function + .inputs + .iter() + .zip(calldata) + .map(|(parameter, value)| { + let expr = parse(value, db)?; + let representation = build_representation(expr, ¶meter.r#type, abi, db)?; + Ok(representation.serialize_to_vec()) + }) + .flatten_ok() + .collect::>() +} + +fn process_as_serialized( + calldata: &[String], + abi: &[AbiEntry], + db: &SimpleParserDatabase, +) -> Result> { + calldata + .iter() + .map(|expression| { + let expr = parse(expression, db)?; + let representation = build_representation(expr, "felt252", abi, db)?; + Ok(representation.serialize_to_vec()) + }) + .flatten_ok() + .collect::>() +} + +fn map_selectors_to_functions(abi: &[AbiEntry]) -> HashMap { + let mut map = HashMap::new(); + + for abi_entry in abi { + match abi_entry { + AbiEntry::Function(func) => { + map.insert( + get_selector_from_name(func.name.as_str()).unwrap(), + func.clone(), + ); + } + AbiEntry::Constructor(constructor) => { + // Transparency of constructors and other functions + map.insert( + get_selector_from_name(constructor.name.as_str()).unwrap(), + AbiFunction { + name: constructor.name.clone(), + inputs: constructor.inputs.clone(), + outputs: vec![], + state_mutability: StateMutability::View, + }, + ); + } + AbiEntry::Interface(interface) => { + map.extend(map_selectors_to_functions(&interface.items)); + } + _ => {} + } + } + + map +} + +fn parse(source: &str, db: &SimpleParserDatabase) -> Result { + let file = FileLongId::Virtual(VirtualFile { + parent: None, + name: "parser_input".into(), + content: source.to_string().into(), + code_mappings: [].into(), + kind: FileKind::Expr, + }) + .intern(db); + + let mut diagnostics = DiagnosticsBuilder::default(); + let expression = Parser::parse_file_expr(db, &mut diagnostics, file, source); + let diagnostics = diagnostics.build(); + + if diagnostics.check_error_free().is_err() { + bail!( + "Invalid Cairo expression found in input calldata \"{}\":\n{}", + source, + diagnostics.format(db) + ) + } + + Ok(expression) +} diff --git a/crates/sncast/src/helpers/fee.rs b/crates/sncast/src/helpers/fee.rs index cbfae691c0..ddf78fa1c8 100644 --- a/crates/sncast/src/helpers/fee.rs +++ b/crates/sncast/src/helpers/fee.rs @@ -169,6 +169,22 @@ pub enum FeeSettings { }, } +impl From for FeeSettings { + fn from(value: ScriptFeeSettings) -> Self { + match value { + ScriptFeeSettings::Eth { max_fee } => FeeSettings::Eth { max_fee }, + ScriptFeeSettings::Strk { + max_gas, + max_gas_unit_price, + .. + } => FeeSettings::Strk { + max_gas, + max_gas_unit_price, + }, + } + } +} + pub trait PayableTransaction { fn error_message(&self, token: &str, version: &str) -> String; fn validate(&self) -> Result<()>; diff --git a/crates/sncast/src/helpers/mod.rs b/crates/sncast/src/helpers/mod.rs index b9ac734f9e..9453a22e45 100644 --- a/crates/sncast/src/helpers/mod.rs +++ b/crates/sncast/src/helpers/mod.rs @@ -2,6 +2,7 @@ pub mod block_explorer; pub mod braavos; pub mod configuration; pub mod constants; +pub mod data_transformer; pub mod error; pub mod fee; pub mod rpc; diff --git a/crates/sncast/src/lib.rs b/crates/sncast/src/lib.rs index 1ef493b6ea..ded59fa225 100644 --- a/crates/sncast/src/lib.rs +++ b/crates/sncast/src/lib.rs @@ -266,6 +266,19 @@ pub async fn get_account<'a>( Ok(account) } +pub async fn get_contract_class( + class_hash: Felt, + provider: &JsonRpcClient, +) -> Result { + provider + .get_class(BlockId::Tag(BlockTag::Latest), class_hash) + .await + .map_err(handle_rpc_error) + .context(format!( + "Couldn't retrieve contract class with hash: {class_hash:#x}" + )) +} + async fn build_account( account_data: AccountData, chain_id: Felt, diff --git a/crates/sncast/src/main.rs b/crates/sncast/src/main.rs index 93e5ae202d..ce218485ef 100644 --- a/crates/sncast/src/main.rs +++ b/crates/sncast/src/main.rs @@ -6,6 +6,7 @@ use crate::starknet_commands::{ }; use anyhow::{Context, Result}; use configuration::load_global_config; +use sncast::helpers::data_transformer::transformer::transform; use sncast::response::explorer_link::print_block_explorer_link_if_allowed; use sncast::response::print::{print_command_result, OutputFormat}; @@ -20,9 +21,10 @@ use sncast::helpers::scarb_utils::{ }; use sncast::response::errors::handle_starknet_command_error; use sncast::{ - chain_id_to_network_name, get_account, get_block_id, get_chain_id, get_default_state_file_name, - NumbersFormat, ValidatedWaitParams, WaitForTx, + chain_id_to_network_name, get_account, get_block_id, get_chain_id, get_class_hash_by_address, + get_contract_class, get_default_state_file_name, NumbersFormat, ValidatedWaitParams, WaitForTx, }; +use starknet::accounts::ConnectedAccount; use starknet::core::utils::get_selector_from_name; use starknet::providers::Provider; use starknet_commands::account::list::print_account_list; @@ -224,6 +226,7 @@ async fn run_async_command( let provider = deploy.rpc.get_provider(&config).await?; deploy.validate()?; + let account = get_account( &config.account, &config.accounts_file, @@ -232,9 +235,36 @@ async fn run_async_command( ) .await?; - let result = starknet_commands::deploy::deploy(deploy, &account, wait_config) - .await - .map_err(handle_starknet_command_error); + let fee_settings = deploy + .fee_args + .clone() + .fee_token(deploy.token_from_version()) + .try_into_fee_settings(&provider, account.block_id()) + .await?; + + let constructor_calldata = deploy.constructor_calldata; + + let selector = get_selector_from_name("constructor").unwrap(); + + let contract_class = get_contract_class(deploy.class_hash, &provider).await?; + + let serialized_calldata = match constructor_calldata { + Some(ref data) => transform(data, contract_class, &selector)?, + None => vec![], + }; + + let result = starknet_commands::deploy::deploy( + deploy.class_hash, + &serialized_calldata, + deploy.salt, + deploy.unique, + fee_settings, + deploy.nonce, + &account, + wait_config, + ) + .await + .map_err(handle_starknet_command_error); print_command_result("deploy", &result, numbers_format, output_format)?; print_block_explorer_link_if_allowed( @@ -252,11 +282,26 @@ async fn run_async_command( let block_id = get_block_id(&call.block_id)?; + let class_hash = get_class_hash_by_address(&provider, call.contract_address) + .await? + .with_context(|| { + format!( + "Couldn't retrieve class hash of a contract with address {:#x}", + call.contract_address + ) + })?; + + let contract_class = get_contract_class(class_hash, &provider).await?; + + let entry_point_selector = get_selector_from_name(&call.function) + .context("Failed to convert entry point selector to FieldElement")?; + + let calldata = transform(&call.calldata, contract_class, &entry_point_selector)?; + let result = starknet_commands::call::call( call.contract_address, - get_selector_from_name(&call.function) - .context("Failed to convert entry point selector to FieldElement")?, - call.calldata, + entry_point_selector, + calldata, &provider, block_id.as_ref(), ) @@ -268,10 +313,22 @@ async fn run_async_command( } Commands::Invoke(invoke) => { - let provider = invoke.rpc.get_provider(&config).await?; - invoke.validate()?; + let fee_token = invoke.token_from_version(); + + let Invoke { + contract_address, + function, + calldata, + fee_args, + rpc, + nonce, + .. + } = invoke; + + let provider = rpc.get_provider(&config).await?; + let account = get_account( &config.account, &config.accounts_file, @@ -279,10 +336,27 @@ async fn run_async_command( config.keystore, ) .await?; + + let fee_args = fee_args.fee_token(fee_token); + + let selector = get_selector_from_name(&function) + .context("Failed to convert entry point selector to FieldElement")?; + + let class_hash = get_class_hash_by_address(&provider, contract_address) + .await + .with_context(|| format!("Couldn't retrieve class hash of a contract with address {contract_address:#x}"))? + .with_context(|| format!("Couldn't retrieve class hash of a contract with address {contract_address:#x}"))?; + + let contract_class = get_contract_class(class_hash, &provider).await?; + + let calldata = transform(&calldata, contract_class, &selector)?; + let result = starknet_commands::invoke::invoke( - invoke.clone(), - get_selector_from_name(&invoke.function) - .context("Failed to convert entry point selector to FieldElement")?, + contract_address, + calldata, + nonce, + fee_args, + selector, &account, wait_config, ) diff --git a/crates/sncast/src/starknet_commands/call.rs b/crates/sncast/src/starknet_commands/call.rs index 54ae9f3e35..4e048e421f 100644 --- a/crates/sncast/src/starknet_commands/call.rs +++ b/crates/sncast/src/starknet_commands/call.rs @@ -18,9 +18,9 @@ pub struct Call { #[clap(short, long)] pub function: String, - /// Arguments of the called function (list of hex) + /// Arguments of the called function, either entirely serialized or entirely written as Cairo-like expression strings #[clap(short, long, value_delimiter = ' ', num_args = 1..)] - pub calldata: Vec, + pub calldata: Vec, /// Block identifier on which call should be performed. /// Possible values: pending, latest, block hash (0x prefixed string) @@ -45,6 +45,7 @@ pub async fn call( entry_point_selector, calldata, }; + let res = provider.call(function_call, block_id).await; match res { diff --git a/crates/sncast/src/starknet_commands/deploy.rs b/crates/sncast/src/starknet_commands/deploy.rs index b2a5ab96ce..12151f84c0 100644 --- a/crates/sncast/src/starknet_commands/deploy.rs +++ b/crates/sncast/src/starknet_commands/deploy.rs @@ -23,9 +23,9 @@ pub struct Deploy { #[clap(short = 'g', long)] pub class_hash: Felt, - /// Calldata for the contract constructor + /// Calldata for the contract constructor, either entirely serialized or entirely written as Cairo-like expression strings #[clap(short, long, value_delimiter = ' ', num_args = 1..)] - pub constructor_calldata: Vec, + pub constructor_calldata: Option>, /// Salt for the address #[clap(short, long)] @@ -61,29 +61,27 @@ impl_payable_transaction!(Deploy, token_not_supported_for_deployment, DeployVersion::V3 => FeeToken::Strk ); +#[allow(clippy::ptr_arg, clippy::too_many_arguments)] pub async fn deploy( - deploy: Deploy, + class_hash: Felt, + calldata: &Vec, + salt: Option, + unique: bool, + fee_settings: FeeSettings, + nonce: Option, account: &SingleOwnerAccount<&JsonRpcClient, LocalWallet>, wait_config: WaitForTx, ) -> Result { - let fee_settings = deploy - .fee_args - .clone() - .fee_token(deploy.token_from_version()) - .try_into_fee_settings(account.provider(), account.block_id()) - .await?; - - let salt = extract_or_generate_salt(deploy.salt); - let factory = ContractFactory::new(deploy.class_hash, account); + let salt = extract_or_generate_salt(salt); + let factory = ContractFactory::new(class_hash, account); let result = match fee_settings { FeeSettings::Eth { max_fee } => { - let execution = - factory.deploy_v1(deploy.constructor_calldata.clone(), salt, deploy.unique); + let execution = factory.deploy_v1(calldata.clone(), salt, unique); let execution = match max_fee { None => execution, Some(max_fee) => execution.max_fee(max_fee), }; - let execution = match deploy.nonce { + let execution = match nonce { None => execution, Some(nonce) => execution.nonce(nonce), }; @@ -93,8 +91,7 @@ pub async fn deploy( max_gas, max_gas_unit_price, } => { - let execution = - factory.deploy_v3(deploy.constructor_calldata.clone(), salt, deploy.unique); + let execution = factory.deploy_v3(calldata.clone(), salt, unique); let execution = match max_gas { None => execution, @@ -104,7 +101,7 @@ pub async fn deploy( None => execution, Some(max_gas_unit_price) => execution.gas_price(max_gas_unit_price), }; - let execution = match deploy.nonce { + let execution = match nonce { None => execution, Some(nonce) => execution.nonce(nonce), }; @@ -119,9 +116,9 @@ pub async fn deploy( DeployResponse { contract_address: get_udc_deployed_address( salt, - deploy.class_hash, - &udc_uniqueness(deploy.unique, account.address()), - &deploy.constructor_calldata, + class_hash, + &udc_uniqueness(unique, account.address()), + calldata, ), transaction_hash: result.transaction_hash, }, diff --git a/crates/sncast/src/starknet_commands/invoke.rs b/crates/sncast/src/starknet_commands/invoke.rs index 99558c8e47..d8ec012434 100644 --- a/crates/sncast/src/starknet_commands/invoke.rs +++ b/crates/sncast/src/starknet_commands/invoke.rs @@ -1,5 +1,6 @@ use anyhow::{anyhow, Result}; use clap::{Args, ValueEnum}; + use sncast::helpers::error::token_not_supported_for_invoke; use sncast::helpers::fee::{FeeArgs, FeeSettings, FeeToken, PayableTransaction}; use sncast::helpers::rpc::RpcArgs; @@ -24,9 +25,9 @@ pub struct Invoke { #[clap(short, long)] pub function: String, - /// Calldata for the invoked function + /// Calldata for the invoked function, either entirely serialized or entirely written as Cairo-like expression strings #[clap(short, long, value_delimiter = ' ', num_args = 1..)] - pub calldata: Vec, + pub calldata: Vec, #[clap(flatten)] pub fee_args: FeeArgs, @@ -55,23 +56,21 @@ impl_payable_transaction!(Invoke, token_not_supported_for_invoke, ); pub async fn invoke( - invoke: Invoke, + contract_address: Felt, + calldata: Vec, + nonce: Option, + fee_args: FeeArgs, function_selector: Felt, account: &SingleOwnerAccount<&JsonRpcClient, LocalWallet>, wait_config: WaitForTx, ) -> Result { - let fee_args = invoke - .fee_args - .clone() - .fee_token(invoke.token_from_version()); - let call = Call { - to: invoke.contract_address, + to: contract_address, selector: function_selector, - calldata: invoke.calldata.clone(), + calldata, }; - execute_calls(account, vec![call], fee_args, invoke.nonce, wait_config).await + execute_calls(account, vec![call], fee_args, nonce, wait_config).await } pub async fn execute_calls( diff --git a/crates/sncast/src/starknet_commands/script/run.rs b/crates/sncast/src/starknet_commands/script/run.rs index ce7220afe3..6a42273766 100644 --- a/crates/sncast/src/starknet_commands/script/run.rs +++ b/crates/sncast/src/starknet_commands/script/run.rs @@ -1,6 +1,4 @@ use crate::starknet_commands::declare::Declare; -use crate::starknet_commands::deploy::Deploy; -use crate::starknet_commands::invoke::Invoke; use crate::starknet_commands::{call, declare, deploy, invoke, tx_status}; use crate::{get_account, WaitForTx}; use anyhow::{anyhow, Context, Result}; @@ -45,6 +43,7 @@ use sncast::state::hashing::{ }; use sncast::state::state_file::StateManager; use starknet::accounts::{Account, SingleOwnerAccount}; +use starknet::core::types::Felt; use starknet::core::types::{BlockId, BlockTag::Pending}; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; @@ -104,15 +103,16 @@ impl<'a> ExtensionLogic for CastScriptExtension<'a> { "call" => { let contract_address = input_reader.read()?; let function_selector = input_reader.read()?; - let calldata_felts = input_reader.read()?; + let calldata = input_reader.read::>()?; let call_result = self.tokio_runtime.block_on(call::call( contract_address, function_selector, - calldata_felts, + calldata, self.provider, &BlockId::Tag(Pending), )); + Ok(CheatcodeHandlingResult::from_serializable(call_result)) } "declare" => { @@ -152,29 +152,19 @@ impl<'a> ExtensionLogic for CastScriptExtension<'a> { selector, &declare_result, )?; + Ok(CheatcodeHandlingResult::from_serializable(declare_result)) } "deploy" => { let class_hash = input_reader.read()?; - let constructor_calldata = input_reader.read()?; + let constructor_calldata = input_reader.read::>()?; let salt = input_reader.read()?; let unique = input_reader.read()?; - let fee_args = input_reader.read::()?.into(); + let fee_args: ScriptFeeSettings = input_reader.read()?; let nonce = input_reader.read()?; - let deploy = Deploy { - class_hash, - constructor_calldata, - salt, - unique, - fee_args, - nonce, - version: None, - rpc: RpcArgs::default(), - }; - let deploy_tx_id = - generate_deploy_tx_id(class_hash, &deploy.constructor_calldata, salt, unique); + generate_deploy_tx_id(class_hash, &constructor_calldata, salt, unique); if let Some(success_output) = self.state.get_output_if_success(deploy_tx_id.as_str()) @@ -183,7 +173,12 @@ impl<'a> ExtensionLogic for CastScriptExtension<'a> { } let deploy_result = self.tokio_runtime.block_on(deploy::deploy( - deploy, + class_hash, + &constructor_calldata, + salt, + unique, + fee_args.into(), + nonce, self.account()?, WaitForTx { wait: true, @@ -202,20 +197,12 @@ impl<'a> ExtensionLogic for CastScriptExtension<'a> { "invoke" => { let contract_address = input_reader.read()?; let function_selector = input_reader.read()?; - let calldata: Vec<_> = input_reader.read()?; + + let calldata = input_reader.read::>()?; + let fee_args = input_reader.read::()?.into(); let nonce = input_reader.read()?; - let invoke = Invoke { - contract_address, - function: String::new(), - calldata: calldata.clone(), - fee_args, - nonce, - version: None, - rpc: RpcArgs::default(), - }; - let invoke_tx_id = generate_invoke_tx_id(contract_address, function_selector, &calldata); @@ -226,7 +213,10 @@ impl<'a> ExtensionLogic for CastScriptExtension<'a> { } let invoke_result = self.tokio_runtime.block_on(invoke::invoke( - invoke, + contract_address, + calldata, + nonce, + fee_args, function_selector, self.account()?, WaitForTx { diff --git a/crates/sncast/tests/e2e/call.rs b/crates/sncast/tests/e2e/call.rs index 72f401820f..d651e18ecb 100644 --- a/crates/sncast/tests/e2e/call.rs +++ b/crates/sncast/tests/e2e/call.rs @@ -78,14 +78,11 @@ async fn test_contract_does_not_exist() { ]; let snapbox = runner(&args); - let output = snapbox.assert().success(); + let output = snapbox.assert().failure(); assert_stderr_contains( output, - indoc! {r" - command: call - error: There is no contract at the specified address - "}, + r"Error: Couldn't retrieve class hash of a contract with address 0x1", ); } @@ -104,14 +101,11 @@ fn test_wrong_function_name() { ]; let snapbox = runner(&args); - let output = snapbox.assert().success(); + let output = snapbox.assert().failure(); assert_stderr_contains( output, - indoc! {r" - command: call - error: An error occurred [..]Entry point[..]not found in contract[..] - "}, + r#"Error: Function with selector "[..]" not found in ABI of the contract"#, ); } diff --git a/crates/sncast/tests/e2e/declare.rs b/crates/sncast/tests/e2e/declare.rs index cc382b80d5..f0ecc8e24c 100644 --- a/crates/sncast/tests/e2e/declare.rs +++ b/crates/sncast/tests/e2e/declare.rs @@ -7,7 +7,7 @@ use crate::helpers::fixtures::{ use crate::helpers::runner::runner; use configuration::CONFIG_FILENAME; use indoc::indoc; -use shared::test_utils::output_assert::{assert_stderr_contains, assert_stdout_contains}; +use shared::test_utils::output_assert::{assert_stderr_contains, assert_stdout_contains, AsOutput}; use sncast::helpers::constants::{ARGENT_CLASS_HASH, BRAAVOS_CLASS_HASH, OZ_CLASS_HASH}; use sncast::AccountType; use starknet::core::types::Felt; @@ -255,7 +255,11 @@ async fn test_happy_case_specify_package() { let snapbox = runner(&args).current_dir(tempdir.path()); - let output = snapbox.assert().success().get_output().stdout.clone(); + let output = snapbox.assert().success(); + + println!("{}\n{}", output.as_stdout(), output.as_stderr()); + + let output = output.get_output().stdout.clone(); let hash = get_transaction_hash(&output); let receipt = get_transaction_receipt(hash).await; diff --git a/crates/sncast/tests/e2e/deploy.rs b/crates/sncast/tests/e2e/deploy.rs index c95c62da99..57ddb98a3b 100644 --- a/crates/sncast/tests/e2e/deploy.rs +++ b/crates/sncast/tests/e2e/deploy.rs @@ -8,7 +8,7 @@ use crate::helpers::fixtures::{ }; use crate::helpers::runner::runner; use indoc::indoc; -use shared::test_utils::output_assert::{assert_stderr_contains, assert_stdout_contains}; +use shared::test_utils::output_assert::{assert_stderr_contains, assert_stdout_contains, AsOutput}; use sncast::helpers::constants::{ARGENT_CLASS_HASH, BRAAVOS_CLASS_HASH, OZ_CLASS_HASH}; use sncast::AccountType; use starknet::core::types::{Felt, TransactionReceipt::Deploy}; @@ -298,17 +298,19 @@ fn test_wrong_calldata() { "--class-hash", CONSTRUCTOR_WITH_PARAMS_CONTRACT_CLASS_HASH_SEPOLIA, "--constructor-calldata", - "0x1 0x1", + "0x1 0x2 0x3 0x4", ]; let snapbox = runner(&args); let output = snapbox.assert().success(); + println!("{}\n{}", output.as_stdout(), output.as_stderr()); + assert_stderr_contains( output, indoc! {r" command: deploy - error: An error occurred in the called contract[..]Failed to deserialize param #2[..] + error: An error occurred in the called contract[..]('Input too long for arguments')[..] "}, ); } @@ -330,13 +332,12 @@ async fn test_contract_not_declared() { ]; let snapbox = runner(&args); - let output = snapbox.assert().success(); + let output = snapbox.assert().failure(); assert_stderr_contains( output, indoc! {r" - command: deploy - error: An error occurred in the called contract[..]Class with hash[..]is not declared[..] + Error: Couldn't retrieve contract class with hash: 0x1 "}, ); } diff --git a/crates/sncast/tests/e2e/invoke.rs b/crates/sncast/tests/e2e/invoke.rs index 555a88bdf0..e7f5ac3a50 100644 --- a/crates/sncast/tests/e2e/invoke.rs +++ b/crates/sncast/tests/e2e/invoke.rs @@ -278,13 +278,12 @@ async fn test_contract_does_not_exist() { ]; let snapbox = runner(&args); - let output = snapbox.assert().success(); + let output = snapbox.assert().failure(); assert_stderr_contains( output, indoc! {r" - command: invoke - error: An error occurred in the called contract[..]Requested contract address[..]is not deployed[..] + Error: Couldn't retrieve class hash of a contract with address 0x1 "}, ); } @@ -308,14 +307,11 @@ fn test_wrong_function_name() { ]; let snapbox = runner(&args); - let output = snapbox.assert().success(); + let output = snapbox.assert().failure(); assert_stderr_contains( output, - indoc! {r" - command: invoke - error: An error occurred in the called contract[..]Entry point[..]not found in contract[..] - "}, + r#"Error: Function with selector "[..]" not found in ABI of the contract"#, ); } diff --git a/crates/sncast/tests/e2e/main_tests.rs b/crates/sncast/tests/e2e/main_tests.rs index 65fd34651e..d1a3738d35 100644 --- a/crates/sncast/tests/e2e/main_tests.rs +++ b/crates/sncast/tests/e2e/main_tests.rs @@ -27,14 +27,11 @@ async fn test_happy_case_from_sncast_config() { ]; let snapbox = runner(&args).current_dir(tempdir.path()); - let output = snapbox.assert().success(); + let output = snapbox.assert().failure(); assert_stderr_contains( output, - indoc! {r" - command: call - error: There is no contract at the specified address - "}, + "Error: Couldn't retrieve class hash of a contract with address 0x0", ); } @@ -55,14 +52,11 @@ async fn test_happy_case_from_cli_no_scarb() { ]; let snapbox = runner(&args); - let output = snapbox.assert().success(); + let output = snapbox.assert().failure(); assert_stderr_contains( output, - indoc! {r" - command: call - error: There is no contract at the specified address - "}, + "Error: Couldn't retrieve class hash of a contract with address 0x0", ); } diff --git a/crates/sncast/tests/integration/data_transformer.rs b/crates/sncast/tests/integration/data_transformer.rs new file mode 100644 index 0000000000..8b6f0101e1 --- /dev/null +++ b/crates/sncast/tests/integration/data_transformer.rs @@ -0,0 +1,1022 @@ +use core::fmt; +use itertools::Itertools; +use primitive_types::U256; +use shared::rpc::create_rpc_client; +use sncast::helpers::data_transformer::transformer::transform; +use starknet::core::types::{BlockId, BlockTag, ContractClass, Felt}; +use starknet::core::utils::get_selector_from_name; +use starknet::providers::Provider; +use std::ops::Not; +use test_case::test_case; +use tokio::sync::OnceCell; + +const RPC_ENDPOINT: &str = "http://188.34.188.184:7070/rpc/v0_7"; + +// https://sepolia.starkscan.co/class/0x02a9b456118a86070a8c116c41b02e490f3dcc9db3cad945b4e9a7fd7cec9168#code +const TEST_CLASS_HASH: Felt = + Felt::from_hex_unchecked("0x02a9b456118a86070a8c116c41b02e490f3dcc9db3cad945b4e9a7fd7cec9168"); + +static CLASS: OnceCell = OnceCell::const_new(); + +// 2^128 + 3 +// const BIG_NUMBER: &str = "340282366920938463463374607431768211459"; + +async fn init_class() -> ContractClass { + let client = create_rpc_client(RPC_ENDPOINT).unwrap(); + + client + .get_class(BlockId::Tag(BlockTag::Latest), TEST_CLASS_HASH) + .await + .unwrap() +} + +trait Contains { + fn assert_contains(&self, value: T); +} + +impl Contains<&str> for anyhow::Error { + fn assert_contains(&self, value: &str) { + self.chain() + .any(|err| err.to_string().contains(value)) + .not() + .then(|| panic!("{value:?}\nnot found in\n{self:#?}")); + } +} + +#[tokio::test] +async fn test_function_not_found() { + let simulated_cli_input = vec![String::from("'some_felt'")]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("nonexistent_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + assert!(output.is_err()); + output.unwrap_err().assert_contains( + format!(r#"Function with selector "{selector}" not found in ABI of the contract"#).as_str(), + ); +} + +#[tokio::test] +async fn test_happy_case_numeric_type_suffix() -> anyhow::Result<()> { + let simulated_cli_input = vec![String::from("1010101_u32")]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("unsigned_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector)?; + + assert_eq!(output, vec![Felt::from(1_010_101_u32)]); + + Ok(()) +} + +#[tokio::test] +async fn test_invalid_numeric_type_suffix() { + let simulated_cli_input = vec![String::from("1_u10")]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("simple_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + assert!(output.is_err()); + + output + .unwrap_err() + .assert_contains(r#"Failed to parse value "1" into type "u10": unsupported type u10"#); +} + +#[tokio::test] +async fn test_invalid_cairo_expression() { + let simulated_cli_input = vec![String::from("some_invalid_expression:")]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("simple_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + output + .unwrap_err() + .assert_contains("Invalid Cairo expression found in input calldata"); +} + +#[tokio::test] +async fn test_invalid_argument_number() { + let simulated_cli_input = vec!["0x123", "'some_obsolete_argument'", "10"] + .into_iter() + .map(String::from) + .collect_vec(); + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("simple_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + output + .unwrap_err() + .assert_contains("Invalid number of arguments: passed 3, expected 1"); +} + +#[tokio::test] +async fn test_happy_case_simple_cairo_expressions_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("100")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![100.into()]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_simple_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x64")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![100.into()]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_u256_function_cairo_expressions_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![U256::MAX.to_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("u256_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0xffffffffffffffffffffffffffffffff"), + Felt::from_hex_unchecked("0xffffffffffffffffffffffffffffffff"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_u256_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x2137"), String::from("0x0")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("u256_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x2137"), + Felt::from_hex_unchecked("0x0"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_signed_function_cairo_expressions_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("-273")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("signed_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from(-273i16)]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_signed_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![Felt::from(-273i16).to_hex_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("signed_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from(-273i16)]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +// Problem: Although transformer fails to process the given input as `i32`, itthen succeeds to interpret it as `felt252` +// Overflow checks will not work for functions having the same serialized and Cairo-like calldata length. +// User must provide a type suffix or get the invoke-time error +#[ignore = "Impossible to pass with the current solution"] +#[tokio::test] +async fn test_signed_fn_overflow() { + let simulated_cli_input = vec![(i32::MAX as u64 + 1).to_string()]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("signed_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + output + .unwrap_err() + .assert_contains(r#"Failed to parse value "2147483648" into type "i32""#); +} + +#[tokio::test] +async fn test_signed_fn_overflow_with_type_suffix() { + let simulated_cli_input = vec![format!("{}_i32", i32::MAX as u64 + 1)]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("signed_fn").unwrap(); + + let result = transform(&simulated_cli_input, contract_class, &selector); + + result + .unwrap_err() + .assert_contains(r#"Failed to parse value "2147483648" into type "i32""#); +} + +#[tokio::test] +async fn test_happy_case_unsigned_function_cairo_expressions_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![u32::MAX.to_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("unsigned_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from(u32::MAX)]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_unsigned_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![Felt::from(u32::MAX).to_hex_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("unsigned_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from(u32::MAX)]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_tuple_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("(2137_felt252, 1_u8, Enum::One)")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("tuple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![2137.into(), 1.into(), 0.into()]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_tuple_function_with_nested_struct_cairo_expression_input( +) -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from( + "(123, 234, Enum::Three(NestedStructWithField {a: SimpleStruct {a: 345}, b: 456 }))", + )]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("tuple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![123, 234, 2, 345, 456] + .into_iter() + .map(Felt::from) + .collect(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_tuple_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("0x859"), + String::from("0x1"), + String::from("0x0"), + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("tuple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![2137.into(), 1.into(), 0.into()]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_complex_function_cairo_expressions_input() -> anyhow::Result<()> { + let max_u256 = U256::max_value().to_string(); + + let simulated_cli_input = vec![ + "array![array![0x2137, 0x420], array![0x420, 0x2137]]", + "8_u8", + "-270", + "\"some_string\"", + "(0x69, 100)", + "true", + &max_u256, + ] + .into_iter() + .map(String::from) + .collect_vec(); + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let result = transform( + &simulated_cli_input, + contract_class, + &get_selector_from_name("complex_fn").unwrap(), + )?; + + // Manually serialized in Cairo + let expected_output: Vec = vec![ + "2", + "2", + "8503", + "1056", + "2", + "1056", + "8503", + "8", + "3618502788666131213697322783095070105623107215331596699973092056135872020211", + "0", + "139552669935068984642203239", + "11", + "105", + "100", + "1", + "340282366920938463463374607431768211455", + "340282366920938463463374607431768211455", + ] + .into_iter() + .map(Felt::from_dec_str) + .collect::>() + .unwrap(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_complex_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + // Input identical to `[..]complex_function_cairo_expressions_input` + let input: Vec = [ + "0x2", + "0x2", + "0x2137", + "0x420", + "0x2", + "0x420", + "0x2137", + "0x8", + "0x800000000000010fffffffffffffffffffffffffffffffffffffffffffffef3", + "0x0", + "0x736f6d655f737472696e67", + "0xb", + "0x69", + "0x64", + "0x1", + "0xffffffffffffffffffffffffffffffff", + "0xffffffffffffffffffffffffffffffff", + ] + .into_iter() + .map(String::from) + .collect(); + + let result = transform( + &input, + contract_class, + &get_selector_from_name("complex_fn").unwrap(), + )?; + + let expected_output: Vec = vec![ + "2", + "2", + "8503", + "1056", + "2", + "1056", + "8503", + "8", + "3618502788666131213697322783095070105623107215331596699973092056135872020211", + "0", + "139552669935068984642203239", + "11", + "105", + "100", + "1", + "340282366920938463463374607431768211455", + "340282366920938463463374607431768211455", + ] + .into_iter() + .map(Felt::from_dec_str) + .collect::>() + .unwrap(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_simple_struct_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("SimpleStruct {a: 0x12}")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from_hex_unchecked("0x12")]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_simple_struct_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x12")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from_hex_unchecked("0x12")]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_simple_struct_function_invalid_struct_argument() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from(r#"SimpleStruct {a: "string"}"#)]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + ); + + result + .unwrap_err() + .assert_contains(r#"Failed to parse value "string" into type "core::felt252""#); +} + +#[tokio::test] +async fn test_simple_struct_function_invalid_struct_name() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("InvalidStructName {a: 0x10}")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + ); + + result + .unwrap_err() + .assert_contains(r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got "InvalidStructName""#); +} + +#[test_case(r#""string_argument""#, r#"Failed to parse value "string_argument" into type "data_transformer_contract::SimpleStruct""# ; "string")] +#[test_case("'shortstring'", r#"Failed to parse value "shortstring" into type "data_transformer_contract::SimpleStruct""# ; "shortstring")] +#[test_case("true", r#"Failed to parse value "true" into type "data_transformer_contract::SimpleStruct""# ; "bool")] +#[test_case("array![0x1, 2, 0x3, 04]", r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got array"# ; "array")] +#[test_case("(1, array![2], 0x3)", r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got tuple"# ; "tuple")] +#[test_case("My::Enum", r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got "My""# ; "enum_variant")] +#[test_case("core::path::My::Enum(10)", r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got "core::path::My""# ; "enum_variant_with_path")] +#[tokio::test] +async fn test_simple_struct_function_cairo_expression_input_invalid_argument_type( + input: &str, + error_message: &str, +) { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![input.to_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + ); + + result.unwrap_err().assert_contains(error_message); +} + +#[tokio::test] +async fn test_happy_case_nested_struct_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from( + "NestedStructWithField { a: SimpleStruct { a: 0x24 }, b: 96 }", + )]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("nested_struct_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x24"), + Felt::from_hex_unchecked("0x60"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_nested_struct_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x24"), String::from("0x60")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x24"), + Felt::from_hex_unchecked("0x60"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_empty_variant_cairo_expression_input() -> anyhow::Result<()> +{ + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("Enum::One")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![Felt::ZERO]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_empty_variant_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x0")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![Felt::ZERO]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_one_argument_variant_cairo_expression_input( +) -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("Enum::Two(128)")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x80"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_one_argument_variant_serialized_input() -> anyhow::Result<()> +{ + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x1"), String::from("0x80")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x80"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_nested_struct_variant_cairo_expression_input( +) -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from( + "Enum::Three(NestedStructWithField { a: SimpleStruct { a: 123 }, b: 234 })", + )]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x2"), + Felt::from_hex_unchecked("0x7b"), + Felt::from_hex_unchecked("0xea"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_nested_struct_variant_serialized_input() -> anyhow::Result<()> +{ + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("0x2"), + String::from("0x7b"), + String::from("0xea"), + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x2"), + Felt::from_hex_unchecked("0x7b"), + Felt::from_hex_unchecked("0xea"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_enum_function_invalid_variant_cairo_expression_input() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("Enum::InvalidVariant")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + ); + + result + .unwrap_err() + .assert_contains(r#"Couldn't find variant "InvalidVariant" in enum "Enum""#); +} + +#[tokio::test] +async fn test_happy_case_complex_struct_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let data = concat!( + r#"ComplexStruct {"#, + r#" a: NestedStructWithField {"#, + r#" a: SimpleStruct { a: 1 },"#, + r#" b: 2"#, + r#" },"#, + r#" b: 3, c: 4, d: 5,"#, + r#" e: Enum::Two(6),"#, + r#" f: "seven","#, + r#" g: array![8, 9],"#, + r#" h: 10, i: (11, 12)"#, + r#"}"#, + ); + + let input = vec![String::from(data)]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("complex_struct_fn").unwrap(), + )?; + + let expected_output = vec![ + // a: NestedStruct + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x2"), + // b: felt252 + Felt::from_hex_unchecked("0x3"), + // c: u8 + Felt::from_hex_unchecked("0x4"), + // d: i32 + Felt::from_hex_unchecked("0x5"), + // e: Enum + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x6"), + // f: ByteArray + Felt::from_hex_unchecked("0x0"), + Felt::from_hex_unchecked("0x736576656e"), + Felt::from_hex_unchecked("0x5"), + // g: Array + Felt::from_hex_unchecked("0x2"), + Felt::from_hex_unchecked("0x8"), + Felt::from_hex_unchecked("0x9"), + // h: u256 + Felt::from_hex_unchecked("0xa"), + Felt::from_hex_unchecked("0x0"), + // i: (i128, u128) + Felt::from_hex_unchecked("0xb"), + Felt::from_hex_unchecked("0xc"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_complex_struct_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let felts = vec![ + // a: NestedStruct + "0x1", + "0x2", + // b: felt252 + "0x3", + // c: u8 + "0x4", + // d: i32 + "0x5", + // e: Enum + "0x1", + "0x6", + // f: ByteArray + "0x0", + "0x736576656e", + "0x5", + // g: Array + "0x2", + "0x8", + "0x9", + // h: u256 + "0xa", + "0x0", + // i: (i128, u128) + "0xb", + "0xc", + ]; + + let input = felts.clone().into_iter().map(String::from).collect_vec(); + + let result = transform( + &input, + contract_class, + &get_selector_from_name("complex_struct_fn").unwrap(), + )?; + + let expected_output = felts + .into_iter() + .map(Felt::from_hex_unchecked) + .collect_vec(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +// TODO add similar test but with enums +// - take existing contract code +// - find/create a library with an enum +// - add to project as a dependency +// - create enum with the same name in your contract code +#[tokio::test] +async fn test_external_struct_function_ambiguous_struct_name_cairo_expression_input() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("BitArray { bit: 23 }"), + String::from("BitArray { data: array![0], current: 1, read_pos: 2, write_pos: 3 }"), + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("external_struct_fn").unwrap(), + ); + + result.unwrap_err().assert_contains( + r#"Found more than one struct "BitArray" in ABI, please specify a full path to the item"#, + ); +} + +#[tokio::test] +async fn test_happy_case_external_struct_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("data_transformer_contract::BitArray { bit: 23 }"), + String::from("alexandria_data_structures::bit_array::BitArray { data: array![0], current: 1, read_pos: 2, write_pos: 3 }") + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("external_struct_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x17"), + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x0"), + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x2"), + Felt::from_hex_unchecked("0x3"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_external_struct_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let felts = vec!["0x17", "0x1", "0x0", "0x1", "0x2", "0x3"]; + + let input = felts.clone().into_iter().map(String::from).collect_vec(); + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = felts + .into_iter() + .map(Felt::from_hex_unchecked) + .collect_vec(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_external_struct_function_invalid_path_to_external_struct() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("something::BitArray { bit: 23 }"), + String::from("BitArray { data: array![0], current: 1, read_pos: 2, write_pos: 3 }"), + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("external_struct_fn").unwrap(), + ); + + result + .unwrap_err() + .assert_contains(r#"Struct "something::BitArray" not found in ABI"#); +} + +#[tokio::test] +async fn test_happy_case_contract_constructor() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x123")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("constructor").unwrap(), + )?; + + let expected_output = vec![Felt::from_hex_unchecked("0x123")]; + + assert_eq!(result, expected_output); + + Ok(()) +} diff --git a/crates/sncast/tests/integration/mod.rs b/crates/sncast/tests/integration/mod.rs index ccbb1d5946..8525e9a9f6 100644 --- a/crates/sncast/tests/integration/mod.rs +++ b/crates/sncast/tests/integration/mod.rs @@ -1,3 +1,4 @@ +pub mod data_transformer; mod fee; mod lib_tests; mod wait_for_tx;