Skip to content

Commit

Permalink
refactor: use EntityInfos for PoolOperation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-miao committed Jan 25, 2024
1 parent ec66719 commit d83e4c2
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 54 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

95 changes: 42 additions & 53 deletions crates/pool/src/mempool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ use ethers::types::{Address, H256, U256};
use mockall::automock;
use rundler_sim::{EntityInfos, MempoolConfig, PrecheckSettings, SimulationSettings};
use rundler_types::{Entity, EntityType, EntityUpdate, UserOperation, ValidTimeRange};
use strum::IntoEnumIterator;
use tonic::async_trait;
pub(crate) use uo_pool::UoPool;

Expand Down Expand Up @@ -215,9 +214,11 @@ pub struct PaymasterMetadata {
impl PoolOperation {
/// Returns true if the operation contains the given entity.
pub fn contains_entity(&self, entity: &Entity) -> bool {
self.entity_address(entity.kind)
.map(|address| address == entity.address)
.unwrap_or(false)
if let Some(e) = self.entity_infos.get(entity.kind) {
e.address == entity.address
} else {
false
}
}

/// Returns true if the operation requires the given entity to stake.
Expand All @@ -239,48 +240,31 @@ impl PoolOperation {

/// Returns an iterator over all entities that are included in this operation.
pub fn entities(&'_ self) -> impl Iterator<Item = Entity> + '_ {
EntityType::iter().filter_map(|entity| {
self.entity_address(entity)
.map(|address| Entity::new(entity, address))
})
self.entity_infos
.entities()
.map(|(t, entity)| Entity::new(t, entity.address))
}

/// Returns an iterator over all entities that need stake in this operation.
/// Returns an iterator over all entities that need stake in this operation. This can be a subset of entities that are staked in the operation.
pub fn entities_requiring_stake(&'_ self) -> impl Iterator<Item = Entity> + '_ {
EntityType::iter()
.filter(|entity| self.requires_stake(*entity))
.filter_map(|entity| {
self.entity_address(entity)
.map(|address| Entity::new(entity, address))
})
self.entity_infos.entities().filter_map(|(t, entity)| {
if self.requires_stake(t) {
Entity::new(t, entity.address).into()
} else {
None
}
})
}

/// Return all the unstaked entities that are used in this operation.
pub fn unstaked_entities(&'_ self) -> impl Iterator<Item = Entity> + '_ {
let mut unstaked_entities = vec![];
if !self.entity_infos.sender.is_staked {
unstaked_entities.push(Entity::new(
EntityType::Account,
self.entity_infos.sender.address,
))
}
if let Some(factory) = self.entity_infos.factory {
if !factory.is_staked {
unstaked_entities.push(Entity::new(EntityType::Factory, factory.address))
self.entity_infos.entities().filter_map(|(t, entity)| {
if entity.is_staked {
None
} else {
Entity::new(t, entity.address).into()
}
}
if let Some(paymaster) = self.entity_infos.paymaster {
if !paymaster.is_staked {
unstaked_entities.push(Entity::new(EntityType::Paymaster, paymaster.address))
}
}
if let Some(aggregator) = self.entity_infos.aggregator {
if !aggregator.is_staked {
unstaked_entities.push(Entity::new(EntityType::Aggregator, aggregator.address))
}
}

unstaked_entities.into_iter()
})
}

/// Compute the amount of heap memory the PoolOperation takes up.
Expand All @@ -289,19 +273,12 @@ impl PoolOperation {
+ self.uo.heap_size()
+ self.entities_needing_stake.len() * std::mem::size_of::<EntityType>()
}

fn entity_address(&self, entity: EntityType) -> Option<Address> {
match entity {
EntityType::Account => Some(self.uo.sender),
EntityType::Paymaster => self.uo.paymaster(),
EntityType::Factory => self.uo.factory(),
EntityType::Aggregator => self.aggregator,
}
}
}

#[cfg(test)]
mod tests {
use rundler_sim::EntityInfo;

use super::*;

#[test]
Expand All @@ -326,19 +303,31 @@ mod tests {
sim_block_number: 0,
entities_needing_stake: vec![EntityType::Account, EntityType::Aggregator],
account_is_staked: true,
entity_infos: EntityInfos::default(),
entity_infos: EntityInfos {
factory: Some(EntityInfo {
address: factory,
is_staked: false,
}),
sender: EntityInfo {
address: sender,
is_staked: false,
},
paymaster: Some(EntityInfo {
address: paymaster,
is_staked: false,
}),
aggregator: Some(EntityInfo {
address: aggregator,
is_staked: false,
}),
},
};

assert!(po.requires_stake(EntityType::Account));
assert!(!po.requires_stake(EntityType::Paymaster));
assert!(!po.requires_stake(EntityType::Factory));
assert!(po.requires_stake(EntityType::Aggregator));

assert_eq!(po.entity_address(EntityType::Account), Some(sender));
assert_eq!(po.entity_address(EntityType::Paymaster), Some(paymaster));
assert_eq!(po.entity_address(EntityType::Factory), Some(factory));
assert_eq!(po.entity_address(EntityType::Aggregator), Some(aggregator));

let entities = po.entities().collect::<Vec<_>>();
assert_eq!(entities.len(), 4);
for e in entities {
Expand Down
40 changes: 40 additions & 0 deletions crates/pool/src/mempool/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,8 @@ impl PoolMetrics {

#[cfg(test)]
mod tests {
use rundler_sim::{EntityInfo, EntityInfos};

use super::*;

#[test]
Expand Down Expand Up @@ -784,6 +786,10 @@ mod tests {
];
for mut op in ops.into_iter() {
op.aggregator = Some(agg);
op.entity_infos.aggregator = Some(EntityInfo {
address: agg,
is_staked: false,
});
pool.add_operation(op.clone(), None).unwrap();
}
assert_eq!(pool.by_hash.len(), 3);
Expand All @@ -805,6 +811,10 @@ mod tests {
];
for mut op in ops.into_iter() {
op.uo.paymaster_and_data = paymaster.as_bytes().to_vec().into();
op.entity_infos.paymaster = Some(EntityInfo {
address: op.uo.paymaster().unwrap(),
is_staked: false,
});
pool.add_operation(op.clone(), None).unwrap();
}
assert_eq!(pool.by_hash.len(), 3);
Expand Down Expand Up @@ -839,8 +849,20 @@ mod tests {

let mut op = create_op(sender, 0, 1);
op.uo.paymaster_and_data = paymaster.as_bytes().to_vec().into();
op.entity_infos.paymaster = Some(EntityInfo {
address: op.uo.paymaster().unwrap(),
is_staked: false,
});
op.uo.init_code = factory.as_bytes().to_vec().into();
op.entity_infos.factory = Some(EntityInfo {
address: op.uo.factory().unwrap(),
is_staked: false,
});
op.aggregator = Some(aggregator);
op.entity_infos.aggregator = Some(EntityInfo {
address: aggregator,
is_staked: false,
});

let count = 5;
let mut hashes = vec![];
Expand Down Expand Up @@ -937,13 +959,21 @@ mod tests {
let mut po1 = create_op(sender, 0, 10);
po1.uo.max_priority_fee_per_gas = 10.into();
po1.uo.paymaster_and_data = paymaster1.as_bytes().to_vec().into();
po1.entity_infos.paymaster = Some(EntityInfo {
address: po1.uo.paymaster().unwrap(),
is_staked: false,
});
let _ = pool.add_operation(po1, None).unwrap();
assert_eq!(pool.address_count(&paymaster1), 1);

let paymaster2 = Address::random();
let mut po2 = create_op(sender, 0, 11);
po2.uo.max_priority_fee_per_gas = 11.into();
po2.uo.paymaster_and_data = paymaster2.as_bytes().to_vec().into();
po2.entity_infos.paymaster = Some(EntityInfo {
address: po2.uo.paymaster().unwrap(),
is_staked: false,
});
let _ = pool.add_operation(po2.clone(), None).unwrap();

assert_eq!(pool.address_count(&sender), 1);
Expand Down Expand Up @@ -1038,8 +1068,18 @@ mod tests {
sender,
nonce: nonce.into(),
max_fee_per_gas: max_fee_per_gas.into(),

..UserOperation::default()
},
entity_infos: EntityInfos {
factory: None,
sender: EntityInfo {
address: sender,
is_staked: false,
},
paymaster: None,
aggregator: None,
},
..PoolOperation::default()
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/sim/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ reqwest.workspace = true
tokio = { workspace = true, features = ["macros"] }
tracing.workspace = true
url.workspace = true
strum.workspace = true

mockall = {workspace = true, optional = true }

Expand Down
9 changes: 8 additions & 1 deletion crates/sim/src/simulation/simulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use rundler_types::{
contracts::i_entry_point::FailedOp, Entity, EntityType, StorageSlot, UserOperation,
ValidTimeRange,
};
use strum::IntoEnumIterator;

use super::{
mempool::{match_mempools, AllowEntity, AllowRule, MempoolConfig, MempoolMatchResult},
Expand Down Expand Up @@ -798,6 +799,11 @@ impl EntityInfos {
}
}

/// Get iterator over the entities
pub fn entities(&'_ self) -> impl Iterator<Item = (EntityType, EntityInfo)> + '_ {
EntityType::iter().filter_map(|t| self.get(t).map(|info| (t, info)))
}

fn override_is_staked(&mut self, allow_unstaked_addresses: &HashSet<Address>) {
if let Some(mut factory) = self.factory {
factory.override_is_staked(allow_unstaked_addresses)
Expand All @@ -811,7 +817,8 @@ impl EntityInfos {
}
}

fn get(self, entity: EntityType) -> Option<EntityInfo> {
/// Get the EntityInfo of a specific entity
pub fn get(self, entity: EntityType) -> Option<EntityInfo> {
match entity {
EntityType::Factory => self.factory,
EntityType::Account => Some(self.sender),
Expand Down

0 comments on commit d83e4c2

Please sign in to comment.