Skip to content

Commit

Permalink
feat(error): Add custom error handling to improve robustness
Browse files Browse the repository at this point in the history
Introduce `Error` enum for consistent error management and integrate `thiserror` for better error descriptions. Update functions to use the new error handling, replacing `anyhow::Result` with `Result<T, Error>`, and refactor tests to unwrap results for simplicity.
  • Loading branch information
shuhuiluo committed Sep 13, 2024
1 parent 47d1d97 commit 7317571
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 94 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ include = ["src/**/*.rs"]
[dependencies]
alloy = { version = "0.3", features = ["contract", "rpc-types"] }
anyhow = "1"
thiserror = { version = "1.0", optional = true }

[features]
default = []
std = ["alloy/std"]
std = ["alloy/std", "thiserror"]

[dev-dependencies]
alloy = { version = "0.3", features = ["transport-http"] }
Expand Down
16 changes: 9 additions & 7 deletions src/caller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ macro_rules! call_ephemeral_contract {
None => $deploy_builder,
};
match deploy_builder.call_raw().await {
Err(Error::TransportError(err)) => match err {
TransportError::ErrorResp(payload) => {
let data: Bytes = payload.as_revert_data().unwrap();
Ok(<$call_type>::abi_decode_returns(data.as_ref(), true).unwrap())
Err(ContractError::TransportError(TransportError::ErrorResp(payload))) => {
match payload.as_revert_data() {
Some(data) => Ok(<$call_type as SolCall>::abi_decode_returns(
data.as_ref(),
true,
)?),
None => Err(Error::InvalidRevertData),
}
_ => panic!("should be an error response: {:?}", err),
},
Err(err) => Err(err),
}
Err(err) => Err(Error::ContractError(err)),
Ok(_) => panic!("deployment should revert"),
}
}};
Expand Down
31 changes: 31 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use alloy::{contract::Error as ContractError, sol_types::Error as AbiError};

#[derive(Debug)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum Error {
/// An error occurred retrieving the revert data.
#[cfg_attr(feature = "std", error("Invalid revert data"))]
InvalidRevertData,

/// An error occurred ABI encoding or decoding.
#[cfg_attr(feature = "std", error("{0}"))]
AbiError(AbiError),

/// An error occurred interacting with a contract over RPC.
#[cfg_attr(feature = "std", error("{0}"))]
ContractError(ContractError),
}

impl From<AbiError> for Error {
#[inline]
fn from(e: AbiError) -> Self {
Self::AbiError(e)
}
}

impl From<ContractError> for Error {
#[inline]
fn from(e: ContractError) -> Self {
Self::ContractError(e)
}
}
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ extern crate alloc;
#[allow(warnings)]
pub mod bindings;
pub mod caller;
pub mod error;
pub mod pool_lens;
pub mod position_lens;
pub mod storage_lens;
Expand All @@ -28,5 +29,5 @@ pub mod storage_lens;
mod tests;

pub mod prelude {
pub use super::{bindings::*, pool_lens::*, position_lens::*, storage_lens::*};
pub use super::{error::Error, pool_lens::*, position_lens::*, storage_lens::*};
}
78 changes: 40 additions & 38 deletions src/pool_lens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,29 @@
use crate::{
bindings::{
ephemeralgetpopulatedticksinrange::{
EphemeralGetPopulatedTicksInRange,
EphemeralGetPopulatedTicksInRange::{
getPopulatedTicksInRangeCall, getPopulatedTicksInRangeReturn,
EphemeralGetPopulatedTicksInRangeInstance,
},
PoolUtils::PopulatedTick,
},
ephemeralpoolpositions::{
EphemeralPoolPositions::EphemeralPoolPositionsInstance, PoolUtils::PositionKey,
},
ephemeralpoolpositions::{EphemeralPoolPositions, PoolUtils::PositionKey},
ephemeralpoolslots::{
EphemeralPoolSlots::{getSlotsCall, getSlotsReturn, EphemeralPoolSlotsInstance},
EphemeralPoolSlots,
EphemeralPoolSlots::{getSlotsCall, getSlotsReturn},
PoolUtils::Slot,
},
ephemeralpooltickbitmap::EphemeralPoolTickBitmap::EphemeralPoolTickBitmapInstance,
ephemeralpoolticks::EphemeralPoolTicks::EphemeralPoolTicksInstance,
ephemeralpooltickbitmap::EphemeralPoolTickBitmap,
ephemeralpoolticks::EphemeralPoolTicks,
},
call_ephemeral_contract,
error::Error,
};
use alloc::vec::Vec;
use alloy::{
contract::Error,
contract::Error as ContractError,
eips::BlockId,
primitives::{aliases::I24, Address, Bytes},
primitives::{aliases::I24, Address},
providers::Provider,
sol_types::SolCall,
transports::{Transport, TransportError},
Expand All @@ -54,14 +54,13 @@ pub async fn get_populated_ticks_in_range<T, P>(
tick_upper: I24,
provider: P,
block_id: Option<BlockId>,
) -> Result<(Vec<PopulatedTick>, I24)>
) -> Result<(Vec<PopulatedTick>, I24), Error>
where
T: Transport + Clone,
P: Provider<T>,
{
let deploy_builder = EphemeralGetPopulatedTicksInRangeInstance::deploy_builder(
provider, pool, tick_lower, tick_upper,
);
let deploy_builder =
EphemeralGetPopulatedTicksInRange::deploy_builder(provider, pool, tick_lower, tick_upper);
match call_ephemeral_contract!(deploy_builder, getPopulatedTicksInRangeCall, block_id) {
Ok(getPopulatedTicksInRangeReturn {
populatedTicks,
Expand All @@ -73,7 +72,7 @@ where
.collect(),
tickSpacing,
)),
Err(err) => Err(err.into()),
Err(err) => Err(err),
}
}

Expand Down Expand Up @@ -103,15 +102,12 @@ pub async fn get_static_slots<T, P>(
pool: Address,
provider: P,
block_id: Option<BlockId>,
) -> Result<Vec<Slot>>
) -> Result<Vec<Slot>, Error>
where
T: Transport + Clone,
P: Provider<T>,
{
get_pool_storage!(
EphemeralPoolSlotsInstance::deploy_builder(provider, pool),
block_id
)
get_pool_storage!(EphemeralPoolSlots::deploy_builder(provider, pool), block_id)
}

/// Get the storage slots in the `ticks` mapping between `tick_lower` and `tick_upper`.
Expand All @@ -134,13 +130,13 @@ pub async fn get_ticks_slots<T, P>(
tick_upper: I24,
provider: P,
block_id: Option<BlockId>,
) -> Result<Vec<Slot>>
) -> Result<Vec<Slot>, Error>
where
T: Transport + Clone,
P: Provider<T>,
{
get_pool_storage!(
EphemeralPoolTicksInstance::deploy_builder(provider, pool, tick_lower, tick_upper),
EphemeralPoolTicks::deploy_builder(provider, pool, tick_lower, tick_upper),
block_id
)
}
Expand All @@ -160,13 +156,13 @@ pub async fn get_tick_bitmap_slots<T, P>(
pool: Address,
provider: P,
block_id: Option<BlockId>,
) -> Result<Vec<Slot>>
) -> Result<Vec<Slot>, Error>
where
T: Transport + Clone,
P: Provider<T>,
{
get_pool_storage!(
EphemeralPoolTickBitmapInstance::deploy_builder(provider, pool),
EphemeralPoolTickBitmap::deploy_builder(provider, pool),
block_id
)
}
Expand All @@ -188,13 +184,13 @@ pub async fn get_positions_slots<T, P>(
positions: Vec<PositionKey>,
provider: P,
block_id: Option<BlockId>,
) -> Result<Vec<Slot>>
) -> Result<Vec<Slot>, Error>
where
T: Transport + Clone,
P: Provider<T>,
{
get_pool_storage!(
EphemeralPoolPositionsInstance::deploy_builder(provider, pool, positions),
EphemeralPoolPositions::deploy_builder(provider, pool, positions),
block_id
)
}
Expand All @@ -203,28 +199,35 @@ where
mod tests {
use super::*;
use crate::{
bindings::iuniswapv3pool::IUniswapV3Pool::{IUniswapV3PoolInstance, Mint},
bindings::iuniswapv3pool::{IUniswapV3Pool, IUniswapV3Pool::Mint},
tests::*,
};
use alloy::{primitives::address, rpc::types::Filter, sol_types::SolEvent};
use anyhow::Result;
use futures::future::join_all;

const POOL_ADDRESS: Address = address!("88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640");

#[tokio::test]
async fn test_get_populated_ticks_in_range() -> Result<()> {
async fn test_get_populated_ticks_in_range() {
let provider = PROVIDER.clone();
let pool = IUniswapV3PoolInstance::new(POOL_ADDRESS, provider.clone());
let tick_current = pool.slot0().block(BLOCK_NUMBER).call().await?.tick;
let tick_spacing = pool.tickSpacing().block(BLOCK_NUMBER).call().await?._0;
let pool = IUniswapV3Pool::new(POOL_ADDRESS, provider.clone());
let tick_current = pool.slot0().block(BLOCK_NUMBER).call().await.unwrap().tick;
let tick_spacing = pool
.tickSpacing()
.block(BLOCK_NUMBER)
.call()
.await
.unwrap()
._0;
let (ticks, _) = get_populated_ticks_in_range(
POOL_ADDRESS,
tick_current,
tick_current + (tick_spacing << 8),
provider,
Some(BLOCK_NUMBER),
)
.await?;
.await
.unwrap();
assert!(!ticks.is_empty());
// let mut multicall = Multicall::new(client.clone(), None).await?;
// multicall.add_calls(
Expand Down Expand Up @@ -254,7 +257,6 @@ mod tests {
// assert_eq!(liquidity_gross, _liquidity_gross);
// assert_eq!(liquidity_net, _liquidity_net);
// }
Ok(())
}

async fn verify_slots<T, P>(slots: Vec<Slot>, provider: P)
Expand Down Expand Up @@ -287,7 +289,7 @@ mod tests {
#[tokio::test]
async fn test_get_ticks_slots() {
let provider = PROVIDER.clone();
let pool = IUniswapV3PoolInstance::new(POOL_ADDRESS, provider.clone());
let pool = IUniswapV3Pool::new(POOL_ADDRESS, provider.clone());
let tick_current = pool.slot0().block(BLOCK_NUMBER).call().await.unwrap().tick;
let slots = get_ticks_slots(
POOL_ADDRESS,
Expand All @@ -311,14 +313,14 @@ mod tests {
}

#[tokio::test]
async fn test_get_positions_slots() -> Result<()> {
async fn test_get_positions_slots() {
let provider = PROVIDER.clone();
// create a filter to get the mint events
let filter = Filter::new()
.from_block(BLOCK_NUMBER.as_u64().unwrap() - 10000)
.to_block(BLOCK_NUMBER.as_u64().unwrap())
.event_signature(<Mint as SolEvent>::SIGNATURE_HASH);
let logs = provider.get_logs(&filter).await?;
let logs = provider.get_logs(&filter).await.unwrap();
// decode the logs into position keys
let positions: Vec<_> = logs
.iter()
Expand All @@ -343,8 +345,8 @@ mod tests {
provider.clone(),
Some(BLOCK_NUMBER),
)
.await?;
.await
.unwrap();
verify_slots(slots, provider).await;
Ok(())
}
}
Loading

0 comments on commit 7317571

Please sign in to comment.