diff --git a/crates/dyn-abi/src/eip712/typed_data.rs b/crates/dyn-abi/src/eip712/typed_data.rs index 8671f0775..12c53dc3c 100644 --- a/crates/dyn-abi/src/eip712/typed_data.rs +++ b/crates/dyn-abi/src/eip712/typed_data.rs @@ -538,7 +538,7 @@ mod tests { let typed_data: TypedData = serde_json::from_value(json).unwrap(); - assert_eq!(typed_data.eip712_signing_hash(), Err(Error::CircularDependency("Mail".into())),); + assert_eq!(typed_data.eip712_signing_hash(), Err(Error::CircularDependency("Mail".into()))); } #[test] @@ -677,7 +677,7 @@ mod tests { let s = MyStruct { name: "hello".to_string(), otherThing: "world".to_string() }; let typed_data = TypedData::from_struct(&s, None); - assert_eq!(typed_data.encode_type().unwrap(), "MyStruct(string name,string otherThing)",); + assert_eq!(typed_data.encode_type().unwrap(), "MyStruct(string name,string otherThing)"); assert!(typed_data.resolver.contains_type_name("EIP712Domain")); } diff --git a/crates/json-abi/src/abi.rs b/crates/json-abi/src/abi.rs index 55d24c627..ffed8f1d2 100644 --- a/crates/json-abi/src/abi.rs +++ b/crates/json-abi/src/abi.rs @@ -208,20 +208,8 @@ impl JsonAbi { /// /// See [`to_sol`](JsonAbi::to_sol) for more information. pub fn to_sol_raw(&self, name: &str, out: &mut String, config: Option) { - let len = self.len(); - out.reserve(len * 128); - - out.push_str("interface "); - if !name.is_empty() { - out.push_str(name); - out.push(' '); - } - out.push('{'); - if len > 0 { - out.push('\n'); - SolPrinter::new(out, config.unwrap_or_default()).print(self); - } - out.push('}'); + out.reserve(self.len() * 128); + SolPrinter::new(out, name, config.unwrap_or_default()).print(self); } /// Deduplicates all functions, errors, and events which have the same name and inputs. diff --git a/crates/json-abi/src/to_sol.rs b/crates/json-abi/src/to_sol.rs index 0a6de1d31..472164657 100644 --- a/crates/json-abi/src/to_sol.rs +++ b/crates/json-abi/src/to_sol.rs @@ -2,7 +2,11 @@ use crate::{ item::{Constructor, Error, Event, Fallback, Function, Receive}, EventParam, InternalType, JsonAbi, Param, StateMutability, }; -use alloc::{collections::BTreeSet, string::String, vec::Vec}; +use alloc::{ + collections::{BTreeMap, BTreeSet}, + string::String, + vec::Vec, +}; use core::{ cmp::Ordering, ops::{Deref, DerefMut}, @@ -13,6 +17,7 @@ use core::{ #[allow(missing_copy_implementations)] // Future-proofing pub struct ToSolConfig { print_constructors: bool, + enums_as_udvt: bool, } impl Default for ToSolConfig { @@ -26,7 +31,7 @@ impl ToSolConfig { /// Creates a new configuration with default settings. #[inline] pub const fn new() -> Self { - Self { print_constructors: false } + Self { print_constructors: false, enums_as_udvt: true } } /// Sets whether to print constructors. Default: `false`. @@ -35,6 +40,14 @@ impl ToSolConfig { self.print_constructors = yes; self } + + /// Sets whether to print `enum`s as user-defined value types (UDVTs) instead of `uint8`. + /// Default: `true`. + #[inline] + pub const fn enums_as_udvt(mut self, yes: bool) -> Self { + self.enums_as_udvt = yes; + self + } } pub(crate) trait ToSol { @@ -45,9 +58,12 @@ pub(crate) struct SolPrinter<'a> { /// The buffer to write to. s: &'a mut String, + /// The name of the current library/interface being printed. + name: &'a str, + /// Whether to emit `memory` when printing parameters. /// This is set to `true` when printing functions so that we emit valid Solidity. - emit_param_location: bool, + print_param_location: bool, /// Configuration. config: ToSolConfig, @@ -70,26 +86,22 @@ impl DerefMut for SolPrinter<'_> { } impl<'a> SolPrinter<'a> { - #[inline] - pub(crate) fn new(s: &'a mut String, config: ToSolConfig) -> Self { - Self { s, emit_param_location: false, config } + pub(crate) fn new(s: &'a mut String, name: &'a str, config: ToSolConfig) -> Self { + Self { s, name, print_param_location: false, config } } - #[inline] - pub(crate) fn print(&mut self, value: &T) { - value.to_sol(self); + pub(crate) fn print(&mut self, abi: &'a JsonAbi) { + abi.to_sol_root(self); } - #[inline] fn indent(&mut self) { self.push_str(" "); } } -impl ToSol for JsonAbi { +impl JsonAbi { #[allow(unknown_lints, for_loops_over_fallibles)] - #[inline] - fn to_sol(&self, out: &mut SolPrinter<'_>) { + fn to_sol_root<'a>(&'a self, out: &mut SolPrinter<'a>) { macro_rules! fmt { ($iter:expr) => { let mut any = false; @@ -105,9 +117,35 @@ impl ToSol for JsonAbi { }; } - let mut its = InternalTypes::new(); + let mut its = InternalTypes::new(out.name, out.config.enums_as_udvt); its.visit_abi(self); - fmt!(its.0); + + for (name, its) in &its.other { + if its.is_empty() { + continue; + } + out.push_str("library "); + out.push_str(name); + out.push_str(" {\n"); + let prev = core::mem::replace(&mut out.name, name); + for it in its { + out.indent(); + it.to_sol(out); + out.push('\n'); + } + out.name = prev; + out.push_str("}\n\n"); + } + + out.push_str("interface "); + if !out.name.is_empty() { + out.s.push_str(out.name); + out.push(' '); + } + out.push('{'); + out.push('\n'); + + fmt!(its.this_its); fmt!(self.errors()); fmt!(self.events()); if out.config.print_constructors { @@ -117,17 +155,23 @@ impl ToSol for JsonAbi { fmt!(self.receive); fmt!(self.functions()); out.pop(); // trailing newline + + out.push('}'); } } /// Recursively collects internal structs, enums, and UDVTs from an ABI's items. -struct InternalTypes<'a>(BTreeSet>); +struct InternalTypes<'a> { + name: &'a str, + this_its: BTreeSet>, + other: BTreeMap<&'a String, BTreeSet>>, + enums_as_udvt: bool, +} impl<'a> InternalTypes<'a> { #[allow(clippy::missing_const_for_fn)] - #[inline] - fn new() -> Self { - Self(BTreeSet::new()) + fn new(name: &'a str, enums_as_udvt: bool) -> Self { + Self { name, this_its: BTreeSet::new(), other: BTreeMap::new(), enums_as_udvt } } fn visit_abi(&mut self, abi: &'a JsonAbi) { @@ -176,22 +220,37 @@ impl<'a> InternalTypes<'a> { ) { match internal_type { None | Some(InternalType::AddressPayable(_) | InternalType::Contract(_)) => {} - Some(InternalType::Struct { contract: _, ty }) => { - self.0.insert(It::new(ty, ItKind::Struct(components))); + Some(InternalType::Struct { contract, ty }) => { + self.extend_one(contract, It::new(ty, ItKind::Struct(components))); } - Some(InternalType::Enum { contract: _, ty }) => { - self.0.insert(It::new(ty, ItKind::Enum)); + Some(InternalType::Enum { contract, ty }) => { + if self.enums_as_udvt { + self.extend_one(contract, It::new(ty, ItKind::Enum)); + } } - Some(it @ InternalType::Other { contract: _, ty }) => { + Some(it @ InternalType::Other { contract, ty }) => { // `Other` is a UDVT if it's not a basic Solidity type and not an array if let Some(it) = it.other_specifier() { if it.try_basic_solidity().is_err() && !it.is_array() { - self.0.insert(It::new(ty, ItKind::Udvt(real_ty))); + self.extend_one(contract, It::new(ty, ItKind::Udvt(real_ty))); } } } } } + + fn extend_one(&mut self, contract: &'a Option, it: It<'a>) { + let contract = contract.as_ref(); + if let Some(contract) = contract { + if contract == self.name { + self.this_its.insert(it); + } else { + self.other.entry(contract).or_default().insert(it); + } + } else { + self.this_its.insert(it); + } + } } /// An internal ABI type. @@ -419,7 +478,7 @@ impl ToSol for AbiFunction<'_, IN> { self.kw, AbiFunctionKw::Function | AbiFunctionKw::Fallback | AbiFunctionKw::Receive ) { - out.emit_param_location = true; + out.print_param_location = true; } out.push_str(self.kw.as_str()); @@ -466,7 +525,7 @@ impl ToSol for AbiFunction<'_, IN> { out.push(';'); - out.emit_param_location = false; + out.print_param_location = false; } } @@ -497,22 +556,25 @@ fn param( components: &[Param], out: &mut SolPrinter<'_>, ) { + let mut contract_name = None::<&str>; let mut type_name = type_name; let storage; if let Some(it) = internal_type { - type_name = match it { + (contract_name, type_name) = match it { InternalType::Contract(s) => { - if let Some(start) = s.find('[') { + let ty = if let Some(start) = s.find('[') { storage = format!("address{}", &s[start..]); &storage } else { "address" - } + }; + (None, ty) } - InternalType::AddressPayable(ty) - | InternalType::Struct { ty, .. } - | InternalType::Enum { ty, .. } - | InternalType::Other { ty, .. } => ty, + InternalType::Enum { .. } if !out.config.enums_as_udvt => (None, "uint8"), + InternalType::AddressPayable(ty) => (None, &ty[..]), + InternalType::Struct { contract, ty } + | InternalType::Enum { contract, ty } + | InternalType::Other { contract, ty } => (contract.as_deref(), &ty[..]), }; }; @@ -525,7 +587,7 @@ fn param( // tuple types `(T, U, V, ...)`, but it's valid for `sol!`. out.push('('); // Don't emit `memory` for tuple components because `sol!` can't parse them. - let prev = core::mem::replace(&mut out.emit_param_location, false); + let prev = core::mem::replace(&mut out.print_param_location, false); for (i, component) in components.iter().enumerate() { if i > 0 { out.push_str(", "); @@ -539,7 +601,7 @@ fn param( out, ); } - out.emit_param_location = prev; + out.print_param_location = prev; // trailing comma for single-element tuples if components.len() == 1 { out.push(','); @@ -549,7 +611,15 @@ fn param( out.push_str(rest); } // primitive type - _ => out.push_str(type_name), + _ => { + if let Some(contract_name) = contract_name { + if contract_name != out.name { + out.push_str(contract_name); + out.push('.'); + } + } + out.push_str(type_name); + } } // add `memory` if required (functions) @@ -558,7 +628,7 @@ fn param( "bytes" | "string" => true, s => s.ends_with(']') || !components.is_empty(), }; - if out.emit_param_location && is_memory { + if out.print_param_location && is_memory { out.push_str(" memory"); } diff --git a/crates/json-abi/tests/abi.rs b/crates/json-abi/tests/abi.rs index edaa79711..544b8c7e0 100644 --- a/crates/json-abi/tests/abi.rs +++ b/crates/json-abi/tests/abi.rs @@ -101,21 +101,25 @@ fn to_sol_test(path: &str, abi: &JsonAbi, run_solc: bool) { } if run_solc { - let out = Command::new("solc").arg("--abi").arg(&sol_path).output().unwrap(); + let out = Command::new("solc").arg("--combined-json=abi").arg(&sol_path).output().unwrap(); let stdout = String::from_utf8_lossy(&out.stdout); let stderr = String::from_utf8_lossy(&out.stderr); let panik = |s| -> ! { panic!("{s}\n\nstdout:\n{stdout}\n\nstderr:\n{stderr}") }; if !out.status.success() { panik("solc failed"); } - let Some(json_str_start) = stdout.find("[{") else { - panik("no JSON"); - }; - let json_str = &stdout[json_str_start..]; - let solc_abi = match serde_json::from_str::(json_str) { - Ok(solc_abi) => solc_abi, + let combined_json = match serde_json::from_str::(stdout.trim()) { + Ok(j) => j, Err(e) => panik(&format!("invalid JSON: {e}")), }; + let (_, contract) = combined_json["contracts"] + .as_object() + .unwrap() + .iter() + .find(|(k, _)| k.contains(&format!(":{name}"))) + .unwrap(); + let solc_abi_str = serde_json::to_string(&contract["abi"]).unwrap(); + let solc_abi: JsonAbi = serde_json::from_str(&solc_abi_str).unwrap(); // Note that we don't compare the ABIs directly since the conversion is lossy, e.g. // `internalType` fields change. diff --git a/crates/json-abi/tests/abi/Abiencoderv2Test.sol b/crates/json-abi/tests/abi/Abiencoderv2Test.sol index ad471b226..2a8b9fbd4 100644 --- a/crates/json-abi/tests/abi/Abiencoderv2Test.sol +++ b/crates/json-abi/tests/abi/Abiencoderv2Test.sol @@ -1,8 +1,10 @@ -interface Abiencoderv2Test { +library Hello { struct Person { string name; uint256 age; } +} - function defaultPerson() external pure returns (Person memory); +interface Abiencoderv2Test { + function defaultPerson() external pure returns (Hello.Person memory); } \ No newline at end of file diff --git a/crates/json-abi/tests/abi/AggregationRouterV5.sol b/crates/json-abi/tests/abi/AggregationRouterV5.sol index f1163f982..ac8f19cef 100644 --- a/crates/json-abi/tests/abi/AggregationRouterV5.sol +++ b/crates/json-abi/tests/abi/AggregationRouterV5.sol @@ -1,4 +1,16 @@ -interface AggregationRouterV5 { +library GenericRouter { + struct SwapDescription { + address srcToken; + address dstToken; + address payable srcReceiver; + address payable dstReceiver; + uint256 amount; + uint256 minReturnAmount; + uint256 flags; + } +} + +library OrderLib { struct Order { uint256 salt; address makerAsset; @@ -11,6 +23,9 @@ interface AggregationRouterV5 { uint256 offsets; bytes interactions; } +} + +library OrderRFQLib { struct OrderRFQ { uint256 info; address makerAsset; @@ -20,16 +35,9 @@ interface AggregationRouterV5 { uint256 makingAmount; uint256 takingAmount; } - struct SwapDescription { - address srcToken; - address dstToken; - address payable srcReceiver; - address payable dstReceiver; - uint256 amount; - uint256 minReturnAmount; - uint256 flags; - } +} +interface AggregationRouterV5 { error AccessDenied(); error AdvanceNonceFailed(); error AlreadyFilled(); @@ -89,24 +97,24 @@ interface AggregationRouterV5 { function advanceNonce(uint8 amount) external; function and(uint256 offsets, bytes memory data) external view returns (bool); function arbitraryStaticCall(address target, bytes memory data) external view returns (uint256); - function cancelOrder(Order memory order) external returns (uint256 orderRemaining, bytes32 orderHash); + function cancelOrder(OrderLib.Order memory order) external returns (uint256 orderRemaining, bytes32 orderHash); function cancelOrderRFQ(uint256 orderInfo) external; function cancelOrderRFQ(uint256 orderInfo, uint256 additionalMask) external; - function checkPredicate(Order memory order) external view returns (bool); + function checkPredicate(OrderLib.Order memory order) external view returns (bool); function clipperSwap(address clipperExchange, address srcToken, address dstToken, uint256 inputAmount, uint256 outputAmount, uint256 goodUntil, bytes32 r, bytes32 vs) external payable returns (uint256 returnAmount); function clipperSwapTo(address clipperExchange, address payable recipient, address srcToken, address dstToken, uint256 inputAmount, uint256 outputAmount, uint256 goodUntil, bytes32 r, bytes32 vs) external payable returns (uint256 returnAmount); function clipperSwapToWithPermit(address clipperExchange, address payable recipient, address srcToken, address dstToken, uint256 inputAmount, uint256 outputAmount, uint256 goodUntil, bytes32 r, bytes32 vs, bytes memory permit) external returns (uint256 returnAmount); function destroy() external; function eq(uint256 value, bytes memory data) external view returns (bool); - function fillOrder(Order memory order, bytes memory signature, bytes memory interaction, uint256 makingAmount, uint256 takingAmount, uint256 skipPermitAndThresholdAmount) external payable returns (uint256, uint256, bytes32); - function fillOrderRFQ(OrderRFQ memory order, bytes memory signature, uint256 flagsAndAmount) external payable returns (uint256, uint256, bytes32); - function fillOrderRFQCompact(OrderRFQ memory order, bytes32 r, bytes32 vs, uint256 flagsAndAmount) external payable returns (uint256 filledMakingAmount, uint256 filledTakingAmount, bytes32 orderHash); - function fillOrderRFQTo(OrderRFQ memory order, bytes memory signature, uint256 flagsAndAmount, address target) external payable returns (uint256 filledMakingAmount, uint256 filledTakingAmount, bytes32 orderHash); - function fillOrderRFQToWithPermit(OrderRFQ memory order, bytes memory signature, uint256 flagsAndAmount, address target, bytes memory permit) external returns (uint256, uint256, bytes32); - function fillOrderTo(Order memory order_, bytes memory signature, bytes memory interaction, uint256 makingAmount, uint256 takingAmount, uint256 skipPermitAndThresholdAmount, address target) external payable returns (uint256 actualMakingAmount, uint256 actualTakingAmount, bytes32 orderHash); - function fillOrderToWithPermit(Order memory order, bytes memory signature, bytes memory interaction, uint256 makingAmount, uint256 takingAmount, uint256 skipPermitAndThresholdAmount, address target, bytes memory permit) external returns (uint256, uint256, bytes32); + function fillOrder(OrderLib.Order memory order, bytes memory signature, bytes memory interaction, uint256 makingAmount, uint256 takingAmount, uint256 skipPermitAndThresholdAmount) external payable returns (uint256, uint256, bytes32); + function fillOrderRFQ(OrderRFQLib.OrderRFQ memory order, bytes memory signature, uint256 flagsAndAmount) external payable returns (uint256, uint256, bytes32); + function fillOrderRFQCompact(OrderRFQLib.OrderRFQ memory order, bytes32 r, bytes32 vs, uint256 flagsAndAmount) external payable returns (uint256 filledMakingAmount, uint256 filledTakingAmount, bytes32 orderHash); + function fillOrderRFQTo(OrderRFQLib.OrderRFQ memory order, bytes memory signature, uint256 flagsAndAmount, address target) external payable returns (uint256 filledMakingAmount, uint256 filledTakingAmount, bytes32 orderHash); + function fillOrderRFQToWithPermit(OrderRFQLib.OrderRFQ memory order, bytes memory signature, uint256 flagsAndAmount, address target, bytes memory permit) external returns (uint256, uint256, bytes32); + function fillOrderTo(OrderLib.Order memory order_, bytes memory signature, bytes memory interaction, uint256 makingAmount, uint256 takingAmount, uint256 skipPermitAndThresholdAmount, address target) external payable returns (uint256 actualMakingAmount, uint256 actualTakingAmount, bytes32 orderHash); + function fillOrderToWithPermit(OrderLib.Order memory order, bytes memory signature, bytes memory interaction, uint256 makingAmount, uint256 takingAmount, uint256 skipPermitAndThresholdAmount, address target, bytes memory permit) external returns (uint256, uint256, bytes32); function gt(uint256 value, bytes memory data) external view returns (bool); - function hashOrder(Order memory order) external view returns (bytes32); + function hashOrder(OrderLib.Order memory order) external view returns (bytes32); function increaseNonce() external; function invalidatorForOrderRFQ(address maker, uint256 slot) external view returns (uint256); function lt(uint256 value, bytes memory data) external view returns (bool); @@ -120,7 +128,7 @@ interface AggregationRouterV5 { function renounceOwnership() external; function rescueFunds(address token, uint256 amount) external; function simulate(address target, bytes memory data) external; - function swap(address executor, SwapDescription memory desc, bytes memory permit, bytes memory data) external payable returns (uint256 returnAmount, uint256 spentAmount); + function swap(address executor, GenericRouter.SwapDescription memory desc, bytes memory permit, bytes memory data) external payable returns (uint256 returnAmount, uint256 spentAmount); function timestampBelow(uint256 time) external view returns (bool); function timestampBelowAndNonceEquals(uint256 timeNonceAccount) external view returns (bool); function transferOwnership(address newOwner) external; diff --git a/crates/json-abi/tests/abi/BalancerV2Vault.sol b/crates/json-abi/tests/abi/BalancerV2Vault.sol index fdf99f15d..6fa7e7b41 100644 --- a/crates/json-abi/tests/abi/BalancerV2Vault.sol +++ b/crates/json-abi/tests/abi/BalancerV2Vault.sol @@ -1,4 +1,4 @@ -interface BalancerV2Vault { +library IVault { type PoolBalanceOpKind is uint8; type PoolSpecialization is uint8; type SwapKind is uint8; @@ -49,7 +49,9 @@ interface BalancerV2Vault { address sender; address payable recipient; } +} +interface BalancerV2Vault { event AuthorizerChanged(address indexed newAuthorizer); event ExternalBalanceTransfer(address indexed token, address indexed sender, address recipient, uint256 amount); event FlashLoan(address indexed recipient, address indexed token, uint256 amount, uint256 feeAmount); @@ -57,7 +59,7 @@ interface BalancerV2Vault { event PausedStateChanged(bool paused); event PoolBalanceChanged(bytes32 indexed poolId, address indexed liquidityProvider, address[] tokens, int256[] deltas, uint256[] protocolFeeAmounts); event PoolBalanceManaged(bytes32 indexed poolId, address indexed assetManager, address indexed token, int256 cashDelta, int256 managedDelta); - event PoolRegistered(bytes32 indexed poolId, address indexed poolAddress, PoolSpecialization specialization); + event PoolRegistered(bytes32 indexed poolId, address indexed poolAddress, IVault.PoolSpecialization specialization); event RelayerApprovalChanged(address indexed relayer, address indexed sender, bool approved); event Swap(bytes32 indexed poolId, address indexed tokenIn, address indexed tokenOut, uint256 amountIn, uint256 amountOut); event TokensDeregistered(bytes32 indexed poolId, address[] tokens); @@ -66,9 +68,9 @@ interface BalancerV2Vault { receive() external payable; function WETH() external view returns (address); - function batchSwap(SwapKind kind, BatchSwapStep[] memory swaps, address[] memory assets, FundManagement memory funds, int256[] memory limits, uint256 deadline) external payable returns (int256[] memory assetDeltas); + function batchSwap(IVault.SwapKind kind, IVault.BatchSwapStep[] memory swaps, address[] memory assets, IVault.FundManagement memory funds, int256[] memory limits, uint256 deadline) external payable returns (int256[] memory assetDeltas); function deregisterTokens(bytes32 poolId, address[] memory tokens) external; - function exitPool(bytes32 poolId, address sender, address payable recipient, ExitPoolRequest memory request) external; + function exitPool(bytes32 poolId, address sender, address payable recipient, IVault.ExitPoolRequest memory request) external; function flashLoan(address recipient, address[] memory tokens, uint256[] memory amounts, bytes memory userData) external; function getActionId(bytes4 selector) external view returns (bytes32); function getAuthorizer() external view returns (address); @@ -76,19 +78,19 @@ interface BalancerV2Vault { function getInternalBalance(address user, address[] memory tokens) external view returns (uint256[] memory balances); function getNextNonce(address user) external view returns (uint256); function getPausedState() external view returns (bool paused, uint256 pauseWindowEndTime, uint256 bufferPeriodEndTime); - function getPool(bytes32 poolId) external view returns (address, PoolSpecialization); + function getPool(bytes32 poolId) external view returns (address, IVault.PoolSpecialization); function getPoolTokenInfo(bytes32 poolId, address token) external view returns (uint256 cash, uint256 managed, uint256 lastChangeBlock, address assetManager); function getPoolTokens(bytes32 poolId) external view returns (address[] memory tokens, uint256[] memory balances, uint256 lastChangeBlock); function getProtocolFeesCollector() external view returns (address); function hasApprovedRelayer(address user, address relayer) external view returns (bool); - function joinPool(bytes32 poolId, address sender, address recipient, JoinPoolRequest memory request) external payable; - function managePoolBalance(PoolBalanceOp[] memory ops) external; - function manageUserBalance(UserBalanceOp[] memory ops) external payable; - function queryBatchSwap(SwapKind kind, BatchSwapStep[] memory swaps, address[] memory assets, FundManagement memory funds) external returns (int256[] memory); - function registerPool(PoolSpecialization specialization) external returns (bytes32); + function joinPool(bytes32 poolId, address sender, address recipient, IVault.JoinPoolRequest memory request) external payable; + function managePoolBalance(IVault.PoolBalanceOp[] memory ops) external; + function manageUserBalance(IVault.UserBalanceOp[] memory ops) external payable; + function queryBatchSwap(IVault.SwapKind kind, IVault.BatchSwapStep[] memory swaps, address[] memory assets, IVault.FundManagement memory funds) external returns (int256[] memory); + function registerPool(IVault.PoolSpecialization specialization) external returns (bytes32); function registerTokens(bytes32 poolId, address[] memory tokens, address[] memory assetManagers) external; function setAuthorizer(address newAuthorizer) external; function setPaused(bool paused) external; function setRelayerApproval(address sender, address relayer, bool approved) external; - function swap(SingleSwap memory singleSwap, FundManagement memory funds, uint256 limit, uint256 deadline) external payable returns (uint256 amountCalculated); + function swap(IVault.SingleSwap memory singleSwap, IVault.FundManagement memory funds, uint256 limit, uint256 deadline) external payable returns (uint256 amountCalculated); } \ No newline at end of file diff --git a/crates/json-abi/tests/abi/EventWithStruct.sol b/crates/json-abi/tests/abi/EventWithStruct.sol index 616bc45a6..d710163b0 100644 --- a/crates/json-abi/tests/abi/EventWithStruct.sol +++ b/crates/json-abi/tests/abi/EventWithStruct.sol @@ -1,8 +1,10 @@ -interface EventWithStruct { +library MyContract { struct MyStruct { uint256 a; uint256 b; } +} - event MyEvent(MyStruct, uint256 c); +interface EventWithStruct { + event MyEvent(MyContract.MyStruct, uint256 c); } \ No newline at end of file diff --git a/crates/json-abi/tests/abi/GnosisSafe.sol b/crates/json-abi/tests/abi/GnosisSafe.sol index 0dc1d37ef..f40c41e7d 100644 --- a/crates/json-abi/tests/abi/GnosisSafe.sol +++ b/crates/json-abi/tests/abi/GnosisSafe.sol @@ -1,6 +1,8 @@ -interface GnosisSafe { +library Enum { type Operation is uint8; +} +interface GnosisSafe { event AddedOwner(address owner); event ApproveHash(bytes32 indexed approvedHash, address indexed owner); event ChangedMasterCopy(address masterCopy); @@ -26,21 +28,21 @@ interface GnosisSafe { function disableModule(address prevModule, address module) external; function domainSeparator() external view returns (bytes32); function enableModule(address module) external; - function encodeTransactionData(address to, uint256 value, bytes memory data, Operation operation, uint256 safeTxGas, uint256 baseGas, uint256 gasPrice, address gasToken, address refundReceiver, uint256 _nonce) external view returns (bytes memory); - function execTransaction(address to, uint256 value, bytes memory data, Operation operation, uint256 safeTxGas, uint256 baseGas, uint256 gasPrice, address gasToken, address payable refundReceiver, bytes memory signatures) external returns (bool success); - function execTransactionFromModule(address to, uint256 value, bytes memory data, Operation operation) external returns (bool success); - function execTransactionFromModuleReturnData(address to, uint256 value, bytes memory data, Operation operation) external returns (bool success, bytes memory returnData); + function encodeTransactionData(address to, uint256 value, bytes memory data, Enum.Operation operation, uint256 safeTxGas, uint256 baseGas, uint256 gasPrice, address gasToken, address refundReceiver, uint256 _nonce) external view returns (bytes memory); + function execTransaction(address to, uint256 value, bytes memory data, Enum.Operation operation, uint256 safeTxGas, uint256 baseGas, uint256 gasPrice, address gasToken, address payable refundReceiver, bytes memory signatures) external returns (bool success); + function execTransactionFromModule(address to, uint256 value, bytes memory data, Enum.Operation operation) external returns (bool success); + function execTransactionFromModuleReturnData(address to, uint256 value, bytes memory data, Enum.Operation operation) external returns (bool success, bytes memory returnData); function getMessageHash(bytes memory message) external view returns (bytes32); function getModules() external view returns (address[] memory); function getModulesPaginated(address start, uint256 pageSize) external view returns (address[] memory array, address next); function getOwners() external view returns (address[] memory); function getThreshold() external view returns (uint256); - function getTransactionHash(address to, uint256 value, bytes memory data, Operation operation, uint256 safeTxGas, uint256 baseGas, uint256 gasPrice, address gasToken, address refundReceiver, uint256 _nonce) external view returns (bytes32); + function getTransactionHash(address to, uint256 value, bytes memory data, Enum.Operation operation, uint256 safeTxGas, uint256 baseGas, uint256 gasPrice, address gasToken, address refundReceiver, uint256 _nonce) external view returns (bytes32); function isOwner(address owner) external view returns (bool); function isValidSignature(bytes memory _data, bytes memory _signature) external returns (bytes4); function nonce() external view returns (uint256); function removeOwner(address prevOwner, address owner, uint256 _threshold) external; - function requiredTxGas(address to, uint256 value, bytes memory data, Operation operation) external returns (uint256); + function requiredTxGas(address to, uint256 value, bytes memory data, Enum.Operation operation) external returns (uint256); function setFallbackHandler(address handler) external; function setup(address[] memory _owners, uint256 _threshold, address to, bytes memory data, address fallbackHandler, address paymentToken, uint256 payment, address payable paymentReceiver) external; function signMessage(bytes memory _data) external; diff --git a/crates/json-abi/tests/abi/LargeStruct.sol b/crates/json-abi/tests/abi/LargeStruct.sol index f80925319..5c6bcc6de 100644 --- a/crates/json-abi/tests/abi/LargeStruct.sol +++ b/crates/json-abi/tests/abi/LargeStruct.sol @@ -1,4 +1,4 @@ -interface LargeStruct { +library Many { struct Info { uint128 x; int24 y; @@ -10,6 +10,8 @@ interface LargeStruct { uint256 e; uint256 f; } +} - function getById(bytes32 id) external view returns (Info memory); +interface LargeStruct { + function getById(bytes32 id) external view returns (Many.Info memory); } \ No newline at end of file diff --git a/crates/json-abi/tests/abi/LargeStructs.sol b/crates/json-abi/tests/abi/LargeStructs.sol index 3648acc51..597f8a459 100644 --- a/crates/json-abi/tests/abi/LargeStructs.sol +++ b/crates/json-abi/tests/abi/LargeStructs.sol @@ -1,4 +1,4 @@ -interface LargeStructs { +library IReader { struct AssetStorage { bytes32 symbol; address tokenAddress; @@ -73,9 +73,11 @@ interface LargeStructs { uint96 entryPrice; uint128 entryFunding; } +} - function getChainStorage() external returns (ChainStorage memory chain); +interface LargeStructs { + function getChainStorage() external returns (IReader.ChainStorage memory chain); function getOrders(uint64[] memory orderIds) external pure returns (bytes32[3][] memory orders, bool[] memory isExist); - function getSubAccounts(bytes32[] memory subAccountIds) external pure returns (SubAccountState[] memory subAccounts); - function getSubAccountsAndOrders(bytes32[] memory subAccountIds, uint64[] memory orderIds) external pure returns (SubAccountState[] memory subAccounts, bytes32[3][] memory orders, bool[] memory isOrderExist); + function getSubAccounts(bytes32[] memory subAccountIds) external pure returns (IReader.SubAccountState[] memory subAccounts); + function getSubAccountsAndOrders(bytes32[] memory subAccountIds, uint64[] memory orderIds) external pure returns (IReader.SubAccountState[] memory subAccounts, bytes32[3][] memory orders, bool[] memory isOrderExist); } \ No newline at end of file diff --git a/crates/json-abi/tests/abi/LargeTuple.sol b/crates/json-abi/tests/abi/LargeTuple.sol index 5771b08ed..e4a8add22 100644 --- a/crates/json-abi/tests/abi/LargeTuple.sol +++ b/crates/json-abi/tests/abi/LargeTuple.sol @@ -1,4 +1,4 @@ -interface LargeTuple { +library Contract { struct Response { bytes output1; bytes output2; @@ -14,6 +14,8 @@ interface LargeTuple { bytes output12; bytes output13; } +} - function doSomething(uint160 input) external view returns (Response memory); +interface LargeTuple { + function doSomething(uint160 input) external view returns (Contract.Response memory); } \ No newline at end of file diff --git a/crates/sol-macro-expander/src/expand/contract.rs b/crates/sol-macro-expander/src/expand/contract.rs index dab0bd086..ffa1a15bc 100644 --- a/crates/sol-macro-expander/src/expand/contract.rs +++ b/crates/sol-macro-expander/src/expand/contract.rs @@ -28,7 +28,7 @@ use syn::{parse_quote, Attribute, Result}; /// } /// } /// ``` -pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result { +pub(super) fn expand(cx: &mut ExpCtxt<'_>, contract: &ItemContract) -> Result { let ItemContract { name, body, .. } = contract; let (sol_attrs, attrs) = contract.split_attrs()?; diff --git a/crates/sol-macro-expander/src/expand/mod.rs b/crates/sol-macro-expander/src/expand/mod.rs index 3dd463750..5928588b8 100644 --- a/crates/sol-macro-expander/src/expand/mod.rs +++ b/crates/sol-macro-expander/src/expand/mod.rs @@ -48,36 +48,103 @@ pub fn expand(ast: File) -> Result { ExpCtxt::new(&ast).expand() } +/// Mapping namespace -> ident -> T +/// +/// Keeps namespaced items. Namespace `None` represents global namespace (top-level items). +/// Namespace `Some(ident)` represents items declared inside of a contract. +#[derive(Debug, Clone)] +pub struct NamespacedMap(pub IndexMap, IndexMap>); + +impl Default for NamespacedMap { + fn default() -> Self { + Self(Default::default()) + } +} + +impl NamespacedMap { + /// Inserts an item into the map. + pub fn insert(&mut self, namespace: Option, name: SolIdent, value: T) { + self.0.entry(namespace).or_default().insert(name, value); + } + + /// Given [SolPath] and current namespace, resolves item + pub fn resolve(&self, path: &SolPath, current_namespace: &Option) -> Option<&T> { + // If path contains two components, its `Contract.Something` where `Contract` is a namespace + if path.len() == 2 { + self.get_by_name_and_namespace(&Some(path.first().clone()), path.last()) + } else { + // If there's only one component, this is either global item, or item declared in the + // current namespace. + // + // NOTE: This does not account for inheritance + self.get_by_name_and_namespace(&None, path.last()) + .or_else(|| self.get_by_name_and_namespace(current_namespace, path.last())) + } + } + + fn get_by_name_and_namespace( + &self, + namespace: &Option, + name: &SolIdent, + ) -> Option<&T> { + self.0.get(namespace).and_then(|vals| vals.get(name)) + } +} + +impl NamespacedMap { + /// Inserts an item into the map if it does not exist and returns a mutable reference to it. + pub fn get_or_insert_default(&mut self, namespace: Option, name: SolIdent) -> &mut T { + self.0.entry(namespace).or_default().entry(name).or_default() + } +} + /// The expansion context. #[derive(Debug)] pub struct ExpCtxt<'ast> { - all_items: Vec<&'ast Item>, - custom_types: IndexMap, + /// Keeps items along with optional parent contract holding their definition. + all_items: NamespacedMap<&'ast Item>, + custom_types: NamespacedMap, /// `name => item` - overloaded_items: IndexMap>>, - /// `signature => new_name` - overloads: IndexMap, + overloaded_items: NamespacedMap>>, + /// `namespace => signature => new_name` + overloads: IndexMap, IndexMap>, attrs: SolAttrs, crates: ExternCrates, ast: &'ast File, + + /// Current namespace. Switched during AST traversal and expansion of different contracts. + current_namespace: Option, } // expand impl<'ast> ExpCtxt<'ast> { fn new(ast: &'ast File) -> Self { Self { - all_items: Vec::new(), - custom_types: IndexMap::new(), - overloaded_items: IndexMap::new(), + all_items: Default::default(), + custom_types: Default::default(), + overloaded_items: Default::default(), overloads: IndexMap::new(), attrs: SolAttrs::default(), crates: ExternCrates::default(), ast, + current_namespace: None, } } + /// Sets the current namespace for the duration of the closure. + fn with_namespace( + &mut self, + namespace: Option, + mut f: impl FnMut(&mut Self) -> O, + ) -> O { + let prev = std::mem::replace(&mut self.current_namespace, namespace); + let res = f(self); + self.current_namespace = prev; + res + } + fn expand(mut self) -> Result { let mut abort = false; let mut tokens = TokenStream::new(); @@ -88,7 +155,7 @@ impl<'ast> ExpCtxt<'ast> { self.visit_file(self.ast); - if self.all_items.len() > 1 { + if !self.all_items.0.is_empty() { self.resolve_custom_types(); if self.mk_overloads_map().is_err() { abort = true; @@ -110,9 +177,11 @@ impl<'ast> ExpCtxt<'ast> { Ok(tokens) } - fn expand_item(&self, item: &Item) -> Result { + fn expand_item(&mut self, item: &Item) -> Result { match item { - Item::Contract(contract) => contract::expand(self, contract), + Item::Contract(contract) => self.with_namespace(Some(contract.name.clone()), |this| { + contract::expand(this, contract) + }), Item::Enum(enumm) => r#enum::expand(self, enumm), Item::Error(error) => error::expand(self, error), Item::Event(event) => event::expand(self, event), @@ -138,16 +207,18 @@ impl<'ast> ExpCtxt<'ast> { fn mk_types_map(&mut self) { let mut map = std::mem::take(&mut self.custom_types); - map.reserve(self.all_items.len()); - for &item in &self.all_items { - let (name, ty) = match item { - Item::Contract(c) => (&c.name, c.as_type()), - Item::Enum(e) => (&e.name, e.as_type()), - Item::Struct(s) => (&s.name, s.as_type()), - Item::Udt(u) => (&u.name, u.ty.clone()), - _ => continue, - }; - map.insert(name.clone(), ty); + for (namespace, items) in &self.all_items.0 { + for (name, item) in items { + let ty = match item { + Item::Contract(c) => c.as_type(), + Item::Enum(e) => e.as_type(), + Item::Struct(s) => s.as_type(), + Item::Udt(u) => u.ty.clone(), + _ => continue, + }; + + map.insert(namespace.clone(), name.clone(), ty); + } } self.custom_types = map; } @@ -155,81 +226,92 @@ impl<'ast> ExpCtxt<'ast> { fn resolve_custom_types(&mut self) { self.mk_types_map(); let map = self.custom_types.clone(); - for ty in self.custom_types.values_mut() { - let mut i = 0; - ty.visit_mut(|ty| { + for (namespace, custom_types) in &mut self.custom_types.0 { + for ty in custom_types.values_mut() { + let mut i = 0; + ty.visit_mut(|ty| { + if i >= RESOLVE_LIMIT { + return; + } + let ty @ Type::Custom(_) = ty else { return }; + let Type::Custom(name) = &*ty else { unreachable!() }; + let Some(resolved) = map.resolve(name, namespace) else { + return; + }; + ty.clone_from(resolved); + i += 1; + }); if i >= RESOLVE_LIMIT { - return; + abort!( + ty.span(), + "failed to resolve types.\n\ + This is likely due to an infinitely recursive type definition.\n\ + If you believe this is a bug, please file an issue at \ + https://github.com/alloy-rs/core/issues/new/choose" + ); } - let ty @ Type::Custom(_) = ty else { return }; - let Type::Custom(name) = &*ty else { unreachable!() }; - let Some(resolved) = map.get(name.last()) else { - return; - }; - ty.clone_from(resolved); - i += 1; - }); - if i >= RESOLVE_LIMIT { - abort!( - ty.span(), - "failed to resolve types.\n\ - This is likely due to an infinitely recursive type definition.\n\ - If you believe this is a bug, please file an issue at \ - https://github.com/alloy-rs/core/issues/new/choose" - ); } } } fn mk_overloads_map(&mut self) -> std::result::Result<(), ()> { - let all_orig_names: Vec<_> = - self.overloaded_items.values().flatten().filter_map(|f| f.name()).collect(); let mut overloads_map = std::mem::take(&mut self.overloads); - let mut failed = false; - - for functions in self.overloaded_items.values().filter(|fs| fs.len() >= 2) { - // check for same parameters - for (i, &a) in functions.iter().enumerate() { - for &b in functions.iter().skip(i + 1) { - if a.eq_by_types(b) { - failed = true; - emit_error!( - a.span(), - "{} with same name and parameter types defined twice", - a.desc(); - - note = b.span() => "other declaration is here"; - ); + for namespace in &self.overloaded_items.0.keys().cloned().collect::>() { + let mut failed = false; + + self.with_namespace(namespace.clone(), |this| { + let overloaded_items = this.overloaded_items.0.get(namespace).unwrap(); + let all_orig_names: Vec<_> = + overloaded_items.values().flatten().filter_map(|f| f.name()).collect(); + + for functions in overloaded_items.values().filter(|fs| fs.len() >= 2) { + // check for same parameters + for (i, &a) in functions.iter().enumerate() { + for &b in functions.iter().skip(i + 1) { + if a.eq_by_types(b) { + failed = true; + emit_error!( + a.span(), + "{} with same name and parameter types defined twice", + a.desc(); + + note = b.span() => "other declaration is here"; + ); + } + } } - } - } - for (i, &item) in functions.iter().enumerate() { - let Some(old_name) = item.name() else { - continue; - }; - let new_name = format!("{old_name}_{i}"); - if let Some(other) = all_orig_names.iter().find(|x| x.0 == new_name) { - failed = true; - emit_error!( - old_name.span(), - "{} `{old_name}` is overloaded, \ - but the generated name `{new_name}` is already in use", - item.desc(); - - note = other.span() => "other declaration is here"; - ) + for (i, &item) in functions.iter().enumerate() { + let Some(old_name) = item.name() else { + continue; + }; + let new_name = format!("{old_name}_{i}"); + if let Some(other) = all_orig_names.iter().find(|x| x.0 == new_name) { + failed = true; + emit_error!( + old_name.span(), + "{} `{old_name}` is overloaded, \ + but the generated name `{new_name}` is already in use", + item.desc(); + + note = other.span() => "other declaration is here"; + ) + } + + overloads_map + .entry(namespace.clone()) + .or_default() + .insert(item.signature(this), new_name); + } } + }); - overloads_map.insert(item.signature(self), new_name); + if failed { + return Err(()); } } - if failed { - return Err(()); - } - self.overloads = overloads_map; Ok(()) } @@ -237,15 +319,23 @@ impl<'ast> ExpCtxt<'ast> { impl<'ast> Visit<'ast> for ExpCtxt<'ast> { fn visit_item(&mut self, item: &'ast Item) { - self.all_items.push(item); - ast::visit::visit_item(self, item); + if let Some(name) = item.name() { + self.all_items.insert(self.current_namespace.clone(), name.clone(), item) + } + + if let Item::Contract(contract) = item { + self.with_namespace(Some(contract.name.clone()), |this| { + ast::visit::visit_item(this, item); + }); + } else { + ast::visit::visit_item(self, item); + } } fn visit_item_function(&mut self, function: &'ast ItemFunction) { if let Some(name) = &function.name { self.overloaded_items - .entry(name.as_string()) - .or_default() + .get_or_insert_default(self.current_namespace.clone(), name.clone()) .push(OverloadedItem::Function(function)); } ast::visit::visit_item_function(self, function); @@ -253,16 +343,14 @@ impl<'ast> Visit<'ast> for ExpCtxt<'ast> { fn visit_item_event(&mut self, event: &'ast ItemEvent) { self.overloaded_items - .entry(event.name.as_string()) - .or_default() + .get_or_insert_default(self.current_namespace.clone(), event.name.clone()) .push(OverloadedItem::Event(event)); ast::visit::visit_item_event(self, event); } fn visit_item_error(&mut self, error: &'ast ItemError) { self.overloaded_items - .entry(error.name.as_string()) - .or_default() + .get_or_insert_default(self.current_namespace.clone(), error.name.clone()) .push(OverloadedItem::Error(error)); ast::visit::visit_item_error(self, error); } @@ -347,8 +435,7 @@ impl<'ast> ExpCtxt<'ast> { } fn try_item(&self, name: &SolPath) -> Option<&Item> { - let name = name.last(); - self.all_items.iter().copied().find(|item| item.name() == Some(name)) + self.all_items.resolve(name, &self.current_namespace).copied() } /// Recursively resolves the given type by constructing a new one. @@ -371,7 +458,7 @@ impl<'ast> ExpCtxt<'ast> { } fn try_custom_type(&self, name: &SolPath) -> Option<&Type> { - self.custom_types.get(name.last()) + self.custom_types.resolve(name, &self.current_namespace) } /// Returns the name of the function, adjusted for overloads. @@ -385,7 +472,7 @@ impl<'ast> ExpCtxt<'ast> { fn overloaded_name(&self, item: OverloadedItem<'ast>) -> SolIdent { let original_ident = item.name().expect("item has no name"); let sig = item.signature(self); - match self.overloads.get(&sig) { + match self.overloads.get(&self.current_namespace).and_then(|m| m.get(&sig)) { Some(name) => SolIdent::new_spanned(name, original_ident.span()), None => original_ident.clone(), } @@ -526,7 +613,7 @@ impl<'ast> ExpCtxt<'ast> { for param in params { param.ty.visit(|ty| { if let Type::Custom(name) = ty { - if !self.custom_types.contains_key(name.last()) { + if self.try_custom_type(name).is_none() { let note = (!errored).then(|| { errored = true; "Custom types must be declared inside of the same scope they are referenced in,\n\ diff --git a/crates/sol-macro-expander/src/expand/ty.rs b/crates/sol-macro-expander/src/expand/ty.rs index e2e714e8e..78f14e3ff 100644 --- a/crates/sol-macro-expander/src/expand/ty.rs +++ b/crates/sol-macro-expander/src/expand/ty.rs @@ -27,7 +27,7 @@ pub fn expand_rust_type(ty: &Type, crates: &ExternCrates) -> TokenStream { } /// The [`expand_type`] recursive implementation. -pub fn rec_expand_type(ty: &Type, crates: &ExternCrates, tokens: &mut TokenStream) { +pub(super) fn rec_expand_type(ty: &Type, crates: &ExternCrates, tokens: &mut TokenStream) { let alloy_sol_types = &crates.sol_types; let tts = match *ty { Type::Address(span, _) => quote_spanned! {span=> #alloy_sol_types::sol_data::Address }, @@ -80,14 +80,17 @@ pub fn rec_expand_type(ty: &Type, crates: &ExternCrates, tokens: &mut TokenStrea ::core::compile_error!("Mapping types are not supported here") }, - Type::Custom(ref custom) => return custom.to_tokens(tokens), + Type::Custom(ref custom) => { + let segments = custom.iter(); + quote_spanned! {custom.span()=> #(#segments)::* } + } }; tokens.extend(tts); } // IMPORTANT: Keep in sync with `sol-types/src/types/data_type.rs` /// The [`expand_rust_type`] recursive implementation. -pub fn rec_expand_rust_type(ty: &Type, crates: &ExternCrates, tokens: &mut TokenStream) { +pub(super) fn rec_expand_rust_type(ty: &Type, crates: &ExternCrates, tokens: &mut TokenStream) { let alloy_sol_types = &crates.sol_types; let tts = match *ty { Type::Address(span, _) => quote_spanned! {span=> #alloy_sol_types::private::Address }, diff --git a/crates/sol-macro-input/src/json.rs b/crates/sol-macro-input/src/json.rs index 0579de2ba..8d451a9d5 100644 --- a/crates/sol-macro-input/src/json.rs +++ b/crates/sol-macro-input/src/json.rs @@ -1,8 +1,8 @@ use crate::{SolInput, SolInputKind}; use alloy_json_abi::{ContractObject, JsonAbi, ToSolConfig}; -use proc_macro2::{Ident, TokenStream}; -use quote::{quote, TokenStreamExt}; -use syn::Result; +use proc_macro2::{Ident, TokenStream, TokenTree}; +use quote::quote; +use syn::{AttrStyle, Result}; impl SolInput { /// Normalize JSON ABI inputs into Sol inputs. @@ -18,7 +18,42 @@ impl SolInput { let mut abi = abi.ok_or_else(|| syn::Error::new(name.span(), "ABI not found in JSON"))?; let sol = abi_to_sol(&name, &mut abi); - let sol_interface_tokens = tokens_for_sol(&name, &sol)?; + let mut all_tokens = tokens_for_sol(&name, &sol)?.into_iter(); + + let (inner_attrs, attrs) = attrs + .into_iter() + .partition::, _>(|attr| matches!(attr.style, AttrStyle::Inner(_))); + + let derives = + attrs.iter().filter(|attr| attr.path().is_ident("derive")).collect::>(); + + let mut library_tokens_iter = all_tokens + .by_ref() + .take_while(|tt| !matches!(tt, TokenTree::Ident(id) if id == "interface")) + .skip_while(|tt| matches!(tt, TokenTree::Ident(id) if id == "library")) + .peekable(); + + let library_tokens = library_tokens_iter.by_ref(); + + let mut libraries = Vec::new(); + + while library_tokens.peek().is_some() { + let sol_library_tokens: TokenStream = std::iter::once(TokenTree::Ident(id("library"))) + .chain( + library_tokens + .take_while(|tt| !matches!(tt, TokenTree::Ident(id) if id == "library")), + ) + .collect(); + + let tokens = quote! { + #(#derives)* + #sol_library_tokens + }; + + libraries.push(tokens); + } + let sol_interface_tokens: TokenStream = + std::iter::once(TokenTree::Ident(id("interface"))).chain(all_tokens).collect(); let bytecode = bytecode.map(|bytes| { let s = bytes.to_string(); quote!(bytecode = #s,) @@ -43,6 +78,9 @@ Generated by the following Solidity interface... json_s = serde_json::to_string_pretty(&abi).unwrap() ); let tokens = quote! { + #(#inner_attrs)* + #(#libraries)* + #(#attrs_iter)* #[doc = #doc_str] #[sol(#bytecode #deployed_bytecode)] @@ -80,16 +118,16 @@ pub fn tokens_for_sol(name: &Ident, sol: &str) -> Result { ); syn::Error::new(name.span(), msg) }; - let brace_idx = sol.find('{').ok_or_else(|| mk_err("missing `{`"))?; - let tts = - syn::parse_str::(&sol[brace_idx..]).map_err(|e| mk_err(&e.to_string()))?; - - let mut tokens = TokenStream::new(); - // append `name` manually for the span - tokens.append(id("interface")); - tokens.append(name.clone()); - tokens.extend(tts); - Ok(tokens) + let tts = syn::parse_str::(sol).map_err(|e| mk_err(&e.to_string()))?; + Ok(tts + .into_iter() + .map(|mut tt| { + if matches!(&tt, TokenTree::Ident(id) if id == name) { + tt.set_span(name.span()); + } + tt + }) + .collect()) } #[inline] @@ -102,7 +140,6 @@ fn id(s: impl AsRef) -> Ident { #[cfg(test)] mod tests { use super::*; - use ast::Item; use std::path::{Path, PathBuf}; #[test] @@ -122,98 +159,7 @@ mod tests { } } - #[allow(clippy::single_match)] fn parse_test(s: &str, path: &str) { - let (c, name) = expand_test(s, path); - match name { - "Udvts" => { - assert_eq!(c.name, "Udvts"); - assert_eq!(c.body.len(), 12, "{}, {:#?}", c.body.len(), c); - let [Item::Udt(a), Item::Udt(b), Item::Udt(c), rest @ ..] = &c.body[..] else { - for item in &c.body { - eprintln!("{item:?}\n"); - } - panic!(); - }; - - assert_eq!(a.name, "ItemType"); - assert_eq!(a.ty.to_string(), "bytes32"); - - assert_eq!(b.name, "OrderType"); - assert_eq!(b.ty.to_string(), "uint256"); - - assert_eq!(c.name, "Side"); - assert_eq!(c.ty.to_string(), "bool"); - - rest[..8].iter().for_each(|item| assert!(matches!(item, Item::Struct(_)))); - - let last = &rest[8]; - assert!(rest[9..].is_empty()); - let Item::Function(f) = last else { panic!("{last:#?}") }; - assert_eq!(f.name.as_ref().unwrap(), "fulfillAvailableAdvancedOrders"); - assert!(f.attributes.contains(&ast::FunctionAttribute::Mutability( - ast::Mutability::Payable(Default::default()) - ))); - assert!(f.attributes.contains(&ast::FunctionAttribute::Visibility( - ast::Visibility::External(Default::default()) - ))); - - let args = &f.parameters; - assert_eq!(args.len(), 7); - - assert_eq!(args[0].ty.to_string(), "AdvancedOrder[]"); - assert_eq!(args[0].name.as_ref().unwrap(), "a"); - assert_eq!(args[1].ty.to_string(), "CriteriaResolver[]"); - assert_eq!(args[1].name.as_ref().unwrap(), "b"); - assert_eq!(args[2].ty.to_string(), "FulfillmentComponent[][]"); - assert_eq!(args[2].name.as_ref().unwrap(), "c"); - assert_eq!(args[3].ty.to_string(), "FulfillmentComponent[][]"); - assert_eq!(args[3].name.as_ref().unwrap(), "d"); - assert_eq!(args[4].ty.to_string(), "bytes32"); - assert_eq!(args[4].name.as_ref().unwrap(), "fulfillerConduitKey"); - assert_eq!(args[5].ty.to_string(), "address"); - assert_eq!(args[5].name.as_ref().unwrap(), "recipient"); - assert_eq!(args[6].ty.to_string(), "uint256"); - assert_eq!(args[6].name.as_ref().unwrap(), "maximumFulfilled"); - - let returns = &f.returns.as_ref().unwrap().returns; - assert_eq!(returns.len(), 2); - - assert_eq!(returns[0].ty.to_string(), "bool[]"); - assert_eq!(returns[0].name.as_ref().unwrap(), "e"); - assert_eq!(returns[1].ty.to_string(), "Execution[]"); - assert_eq!(returns[1].name.as_ref().unwrap(), "f"); - } - "EnumsInLibraryFunctions" => { - assert_eq!(c.name, "EnumsInLibraryFunctions"); - assert_eq!(c.body.len(), 5); - let [Item::Udt(the_enum), Item::Function(f_array), Item::Function(f_arrays), Item::Function(f_dyn_array), Item::Function(f_just_enum)] = - &c.body[..] - else { - panic!("{c:#?}"); - }; - - assert_eq!(the_enum.name, "TheEnum"); - assert_eq!(the_enum.ty.to_string(), "uint8"); - - let function_tests = [ - (f_array, "enumArray", "TheEnum[2]"), - (f_arrays, "enumArrays", "TheEnum[][69][]"), - (f_dyn_array, "enumDynArray", "TheEnum[]"), - (f_just_enum, "enum_", "TheEnum"), - ]; - for (f, name, ty) in function_tests { - assert_eq!(f.name.as_ref().unwrap(), name); - assert_eq!(f.parameters.type_strings().collect::>(), [ty]); - let ret = &f.returns.as_ref().expect("no returns").returns; - assert_eq!(ret.type_strings().collect::>(), [ty]); - } - } - _ => {} - } - } - - fn expand_test<'a>(s: &str, path: &'a str) -> (ast::ItemContract, &'a str) { let mut abi: JsonAbi = serde_json::from_str(s).unwrap(); let name = Path::new(path).file_stem().unwrap().to_str().unwrap(); @@ -231,7 +177,7 @@ mod tests { } }; - let ast = match syn::parse2::(tokens.clone()) { + let _ast = match syn::parse2::(tokens.clone()) { Ok(ast) => ast, Err(e) => { let spath = write_tmp_sol(name, &sol); @@ -242,18 +188,9 @@ mod tests { emitted tokens: {}", spath.display(), tpath.display(), - ) + ); } }; - - let mut items = ast.items.into_iter(); - let Some(Item::Contract(c)) = items.next() else { - panic!("first item is not a contract"); - }; - let next = items.next(); - assert!(next.is_none(), "AST does not contain exactly one item: {next:#?}, {items:#?}"); - assert!(!c.body.is_empty(), "generated contract is empty"); - (c, name) } fn write_tmp_sol(name: &str, contents: &str) -> PathBuf { diff --git a/crates/sol-macro/doctests/json.rs b/crates/sol-macro/doctests/json.rs index 5e5aa7f1a..4ccfcf54d 100644 --- a/crates/sol-macro/doctests/json.rs +++ b/crates/sol-macro/doctests/json.rs @@ -7,7 +7,7 @@ sol!( "inputs": [ { "name": "bar", "type": "uint256" }, { - "internalType": "struct MyJsonContract.MyStruct", + "internalType": "struct MyStruct", "name": "baz", "type": "tuple", "components": [ @@ -38,5 +38,5 @@ sol! { #[test] fn abigen() { - assert_eq!(MyJsonContract1::fooCall::SIGNATURE, MyJsonContract2::fooCall::SIGNATURE,); + assert_eq!(MyJsonContract1::fooCall::SIGNATURE, MyJsonContract2::fooCall::SIGNATURE); } diff --git a/crates/sol-type-parser/src/state_mutability.rs b/crates/sol-type-parser/src/state_mutability.rs index ed4fc109a..3dc34f495 100644 --- a/crates/sol-type-parser/src/state_mutability.rs +++ b/crates/sol-type-parser/src/state_mutability.rs @@ -1,6 +1,7 @@ #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +#[cfg(feature = "serde")] const COMPAT_ERROR: &str = "state mutability cannot be both `payable` and `constant`"; /// A JSON ABI function's state mutability. diff --git a/crates/sol-types/Cargo.toml b/crates/sol-types/Cargo.toml index f5c6894ff..7f1d4e667 100644 --- a/crates/sol-types/Cargo.toml +++ b/crates/sol-types/Cargo.toml @@ -51,7 +51,7 @@ trybuild = "1.0" [features] default = ["std"] -std = ["alloy-primitives/std", "hex/std", "serde?/std"] +std = ["alloy-primitives/std", "hex/std", "alloy-json-abi?/std", "serde?/std"] json = ["dep:alloy-json-abi", "alloy-sol-macro/json"] eip712-serde = ["dep:serde", "alloy-primitives/serde"] arbitrary = ["alloy-primitives/arbitrary"] diff --git a/crates/sol-types/tests/macros/sol/json.rs b/crates/sol-types/tests/macros/sol/json.rs index 4f88e5837..af4041f98 100644 --- a/crates/sol-types/tests/macros/sol/json.rs +++ b/crates/sol-types/tests/macros/sol/json.rs @@ -69,17 +69,18 @@ fn seaport() { ); } +// https://etherscan.io/address/0x1111111254eeb25477b68fb85ed929f73a960582#code +sol!( + #[sol(docs = false)] + #[derive(Debug)] + AggregationRouterV5, + "../json-abi/tests/abi/AggregationRouterV5.json" +); + // Handle multiple identical error objects in the JSON ABI // https://github.com/alloy-rs/core/issues/344 #[test] fn aggregation_router_v5() { - // https://etherscan.io/address/0x1111111254eeb25477b68fb85ed929f73a960582#code - sol!( - #[sol(docs = false)] - AggregationRouterV5, - "../json-abi/tests/abi/AggregationRouterV5.json" - ); - assert_eq!( ::SIGNATURE, "ETHTransferFailed()" @@ -142,11 +143,12 @@ fn uniswap_v2_factory() { }; } +sol!(GnosisSafe, "../json-abi/tests/abi/GnosisSafe.json"); + // Fully qualify `SolInterface::NAME` which conflicted with the `NAME` call // https://github.com/alloy-rs/core/issues/361 #[test] fn gnosis_safe() { - sol!(GnosisSafe, "../json-abi/tests/abi/GnosisSafe.json"); let GnosisSafe::NAMECall {} = GnosisSafe::NAMECall {}; let GnosisSafe::NAMEReturn { _0: _ } = GnosisSafe::NAMEReturn { _0: String::new() }; } @@ -202,13 +204,17 @@ fn zrx_token() { assert_eq!(ZRXToken::approveCall::SIGNATURE, "approve(address,uint256)"); } +// https://etherscan.io/address/0xBA12222222228d8Ba445958a75a0704d566BF2C8#code +sol!( + #![sol(all_derives)] + BalancerV2Vault, + "../json-abi/tests/abi/BalancerV2Vault.json" +); + // Handle contract **array** types in JSON ABI // https://github.com/alloy-rs/core/issues/585 #[test] fn balancer_v2_vault() { - // https://etherscan.io/address/0xBA12222222228d8Ba445958a75a0704d566BF2C8#code - sol!(BalancerV2Vault, "../json-abi/tests/abi/BalancerV2Vault.json"); - let _ = BalancerV2Vault::PoolBalanceChanged { poolId: B256::ZERO, liquidityProvider: Address::ZERO, diff --git a/crates/sol-types/tests/macros/sol/mod.rs b/crates/sol-types/tests/macros/sol/mod.rs index 498901846..ccecedb0d 100644 --- a/crates/sol-types/tests/macros/sol/mod.rs +++ b/crates/sol-types/tests/macros/sol/mod.rs @@ -925,3 +925,77 @@ fn contract_derive_default() { let MyContract::e2 {} = MyContract::e2::default(); let MyContract::c {} = MyContract::c::default(); } + +#[test] +fn contract_namespaces() { + mod inner { + alloy_sol_types::sol! { + library LibA { + struct Struct { + uint64 field64; + } + } + + library LibB { + struct Struct { + uint128 field128; + } + } + + contract Contract { + LibA.Struct internal aValue; + LibB.Struct internal bValue; + + constructor( + LibA.Struct memory aValue_, + LibB.Struct memory bValue_ + ) + { + aValue = aValue_; + bValue = bValue_; + } + + function fn( + LibA.Struct memory aValue_, + LibB.Struct memory bValue_ + ) public + { + aValue = aValue_; + bValue = bValue_; + } + } + } + } + + let _ = inner::Contract::fnCall { + aValue_: inner::LibA::Struct { field64: 0 }, + bValue_: inner::LibB::Struct { field128: 0 }, + }; + assert_eq!(inner::Contract::fnCall::SIGNATURE, "fn((uint64),(uint128))"); +} + +// https://github.com/alloy-rs/core/pull/694#issuecomment-2274263880 +#[test] +fn regression_overloads() { + sol! { + contract Vm { + struct Wallet { + uint stuff; + } + + /// Gets the nonce of an account. + function getNonce(address account) external view returns (uint64 nonce); + + /// Get the nonce of a `Wallet`. + function getNonce(Wallet calldata wallet) external returns (uint64 nonce); + } + } + + let _ = Vm::getNonce_0Call { account: Address::ZERO }; + let _ = Vm::getNonce_0Return { nonce: 0 }; + assert_eq!(Vm::getNonce_0Call::SIGNATURE, "getNonce(address)"); + + let _ = Vm::getNonce_1Call { wallet: Vm::Wallet { stuff: U256::ZERO } }; + let _ = Vm::getNonce_1Return { nonce: 0 }; + assert_eq!(Vm::getNonce_1Call::SIGNATURE, "getNonce((uint256))"); +}