Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add gossip and eviction to the mempool #200

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
12 changes: 6 additions & 6 deletions src/consensus/proposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use prost::Message;
use std::collections::BTreeMap;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::mpsc;
use tokio::sync::{broadcast, mpsc};
use tokio::time::Instant;
use tokio::{select, time};
use tonic::Request;
Expand Down Expand Up @@ -55,7 +55,7 @@ pub struct ShardProposer {
shard_id: SnapchainShard,
address: Address,
proposed_chunks: BTreeMap<ShardHash, FullProposal>,
tx_decision: mpsc::Sender<ShardChunk>,
tx_decision: broadcast::Sender<ShardChunk>,
engine: ShardEngine,
propose_value_delay: Duration,
statsd_client: StatsdClientWrapper,
Expand All @@ -67,7 +67,7 @@ impl ShardProposer {
shard_id: SnapchainShard,
engine: ShardEngine,
statsd_client: StatsdClientWrapper,
tx_decision: mpsc::Sender<ShardChunk>,
tx_decision: broadcast::Sender<ShardChunk>,
propose_value_delay: Duration,
) -> ShardProposer {
ShardProposer {
Expand All @@ -82,7 +82,7 @@ impl ShardProposer {
}

async fn publish_new_shard_chunk(&self, shard_chunk: &ShardChunk) {
let _ = &self.tx_decision.send(shard_chunk.clone()).await;
let _ = &self.tx_decision.send(shard_chunk.clone());
}
}

Expand Down Expand Up @@ -241,7 +241,7 @@ pub struct BlockProposer {
address: Address,
proposed_blocks: BTreeMap<ShardHash, FullProposal>,
pending_chunks: BTreeMap<u64, Vec<ShardChunk>>,
shard_decision_rx: mpsc::Receiver<ShardChunk>,
shard_decision_rx: broadcast::Receiver<ShardChunk>,
num_shards: u32,
block_tx: Option<mpsc::Sender<Block>>,
engine: BlockEngine,
Expand All @@ -252,7 +252,7 @@ impl BlockProposer {
pub fn new(
address: Address,
shard_id: SnapchainShard,
shard_decision_rx: mpsc::Receiver<ShardChunk>,
shard_decision_rx: broadcast::Receiver<ShardChunk>,
num_shards: u32,
block_tx: Option<mpsc::Sender<Block>>,
engine: BlockEngine,
Expand Down
19 changes: 15 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::net::SocketAddr;
use std::process;
use std::time::Duration;
use tokio::signal::ctrl_c;
use tokio::sync::mpsc;
use tokio::sync::{broadcast, mpsc};
use tokio::{select, time};
use tonic::transport::Server;
use tracing::{error, info, warn};
Expand Down Expand Up @@ -105,9 +105,14 @@ async fn main() -> Result<(), Box<dyn Error>> {
);

let (system_tx, mut system_rx) = mpsc::channel::<SystemMessage>(100);
let (mempool_tx, mempool_rx) = mpsc::channel(app_config.mempool.queue_size as usize);

let gossip_result =
SnapchainGossip::create(keypair.clone(), app_config.gossip, system_tx.clone());
let gossip_result = SnapchainGossip::create(
keypair.clone(),
app_config.gossip,
system_tx.clone(),
mempool_tx.clone(),
);

if let Err(e) = gossip_result {
error!(error = ?e, "Failed to create SnapchainGossip");
Expand All @@ -130,11 +135,14 @@ async fn main() -> Result<(), Box<dyn Error>> {
let _ = Metrics::register(registry);

let (messages_request_tx, messages_request_rx) = mpsc::channel(100);
let (shard_decision_tx, shard_decision_rx) = broadcast::channel(100);

let node = SnapchainNode::create(
keypair.clone(),
app_config.consensus.clone(),
Some(app_config.rpc_address.clone()),
gossip_tx.clone(),
shard_decision_tx,
None,
messages_request_tx,
block_store.clone(),
Expand All @@ -144,12 +152,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
)
.await;

let (mempool_tx, mempool_rx) = mpsc::channel(app_config.mempool.queue_size as usize);
let mut mempool = Mempool::new(
1024,
mempool_rx,
messages_request_rx,
app_config.consensus.num_shards,
node.shard_stores.clone(),
gossip_tx.clone(),
shard_decision_rx,
statsd_client.clone(),
);
tokio::spawn(async move { mempool.run().await });

Expand Down
174 changes: 148 additions & 26 deletions src/mempool/mempool.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
use std::collections::{BTreeMap, HashMap};

use serde::{Deserialize, Serialize};
use tokio::{
sync::{mpsc, oneshot},
time::Instant,
};
use tokio::sync::{broadcast, mpsc, oneshot};

use crate::storage::{
db::RocksDbTransactionBatch,
store::{
account::{
get_message_by_key, make_message_primary_key, make_ts_hash, type_to_set_postfix,
UserDataStore,
use crate::{
core::types::SnapchainValidatorContext,
network::gossip::GossipEvent,
proto::{self, ShardChunk},
storage::{
db::RocksDbTransactionBatch,
store::{
account::{
get_message_by_key, make_message_primary_key, make_ts_hash, type_to_set_postfix,
UserDataStore,
},
engine::MempoolMessage,
stores::Stores,
},
engine::MempoolMessage,
stores::Stores,
},
utils::statsd_wrapper::StatsdClientWrapper,
};

use super::routing::{MessageRouter, ShardRouter};
use tracing::error;
use tracing::{error, warn};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
Expand All @@ -34,7 +37,57 @@ impl Default for Config {

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MempoolKey {
inserted_at: Instant,
timestamp: u64,
identity: String,
}

impl MempoolKey {
pub fn new(timestamp: u64, identity: String) -> Self {
MempoolKey {
timestamp,
identity,
}
}

pub fn identity(self) -> String {
self.identity
}
}

impl proto::Message {
pub fn mempool_key(&self) -> MempoolKey {
if let Some(data) = &self.data {
// TODO: Consider revisiting choice of timestamp here as backdated messages currently are prioritized.
return MempoolKey::new(data.timestamp as u64, self.hex_hash());
aditiharini marked this conversation as resolved.
Show resolved Hide resolved
}
todo!();
}
}

impl proto::ValidatorMessage {
    /// Derives the mempool key for a validator message.
    ///
    /// An fname transfer that carries a proof is keyed by the proof timestamp
    /// and the transfer id. Otherwise an on-chain event is keyed by its block
    /// timestamp and the hex-encoded transaction hash immediately followed by
    /// the log index.
    ///
    /// # Panics
    ///
    /// Hits `todo!()` when the message has neither an fname transfer with a
    /// proof nor an on-chain event.
    pub fn mempool_key(&self) -> MempoolKey {
        // An fname transfer only yields a key when its proof is present;
        // otherwise fall through to the on-chain event.
        let fname_key = self.fname_transfer.as_ref().and_then(|transfer| {
            transfer
                .proof
                .as_ref()
                .map(|proof| MempoolKey::new(proof.timestamp, transfer.id.to_string()))
        });
        if let Some(key) = fname_key {
            return key;
        }

        match &self.on_chain_event {
            Some(event) => {
                let identity = format!(
                    "{}{}",
                    hex::encode(&event.transaction_hash),
                    event.log_index
                );
                MempoolKey::new(event.block_timestamp, identity)
            }
            None => todo!(),
        }
    }
}

impl MempoolMessage {
pub fn mempool_key(&self) -> MempoolKey {
match self {
MempoolMessage::UserMessage(msg) => msg.mempool_key(),
MempoolMessage::ValidatorMessage(msg) => msg.mempool_key(),
}
}
}

pub struct MempoolMessagesRequest {
Expand All @@ -44,28 +97,40 @@ pub struct MempoolMessagesRequest {
}

/// Pool of pending messages, bucketed per shard, fed by `mempool_rx` and
/// pruned as decided shard chunks arrive on `shard_decision_rx`.
pub struct Mempool {
/// Size limit for one shard's map: `insert` drops an incoming message when
/// the shard's existing map already holds at least this many entries.
capacity_per_shard: usize,
/// Stores keyed by shard id.
// NOTE(review): presumably consulted when validating messages — the use is
// not visible in this hunk; confirm.
shard_stores: HashMap<u32, Stores>,
/// Maps a message's fid (together with `num_shards`) to the shard id whose
/// map the message is inserted into.
message_router: Box<dyn MessageRouter>,
/// Shard count passed to `message_router.route_message`.
num_shards: u32,
/// Incoming messages; each one received in `run` is handed to `insert`.
mempool_rx: mpsc::Receiver<MempoolMessage>,
/// Incoming `MempoolMessagesRequest`s.
// NOTE(review): the handling code is outside this hunk; presumably these are
// requests to pull messages out of the pool — confirm.
messages_request_rx: mpsc::Receiver<MempoolMessagesRequest>,
/// Pending messages: shard id -> messages ordered by their `MempoolKey`.
messages: HashMap<u32, BTreeMap<MempoolKey, MempoolMessage>>,
/// Outbound gossip channel; `insert` sends a `BroadcastMempoolMessage` event
/// for each accepted user message.
gossip_tx: mpsc::Sender<GossipEvent<SnapchainValidatorContext>>,
/// Decided shard chunks; `run` removes the user and system messages of each
/// received chunk from that shard's map.
shard_decision_rx: broadcast::Receiver<ShardChunk>,
/// Client used to emit the mempool size gauge and insert/remove counters.
statsd_client: StatsdClientWrapper,
}

impl Mempool {
pub fn new(
capacity_per_shard: usize,
mempool_rx: mpsc::Receiver<MempoolMessage>,
messages_request_rx: mpsc::Receiver<MempoolMessagesRequest>,
num_shards: u32,
shard_stores: HashMap<u32, Stores>,
gossip_tx: mpsc::Sender<GossipEvent<SnapchainValidatorContext>>,
shard_decision_rx: broadcast::Receiver<ShardChunk>,
statsd_client: StatsdClientWrapper,
) -> Self {
Mempool {
capacity_per_shard,
shard_stores,
num_shards,
mempool_rx,
message_router: Box::new(ShardRouter {}),
messages: HashMap::new(),
messages_request_rx,
gossip_tx,
shard_decision_rx,
statsd_client,
}
}

Expand Down Expand Up @@ -160,6 +225,60 @@ impl Mempool {
return true;
}

async fn insert(&mut self, message: MempoolMessage) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's emit some stats here: the number of inserts tagged by success/failure and message type, and a gauge for the current size of the mempool by shard.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if you want additional stats or tweaks to the formatting

// TODO(aditi): Maybe we don't need to run validations here?
if self.message_is_valid(&message) {
let fid = message.fid();
let shard_id = self.message_router.route_message(fid, self.num_shards);
// TODO(aditi): We need a size limit on the mempool and we need to figure out what to do if it's exceeded
match self.messages.get_mut(&shard_id) {
None => {
let mut messages = BTreeMap::new();
messages.insert(message.mempool_key(), message.clone());
self.messages.insert(shard_id, messages);
self.statsd_client
.gauge_with_shard(shard_id, "mempool.size", 1);
}
Some(messages) => {
if messages.len() >= self.capacity_per_shard {
aditiharini marked this conversation as resolved.
Show resolved Hide resolved
// For now, mempool messages are dropped here if the mempool is full.
warn!(
fid = message.fid(),
identity = message.mempool_key().identity(),
"Message dropped due to mempool being over capacity"
);
return;
}
messages.insert(message.mempool_key(), message.clone());
self.statsd_client.gauge_with_shard(
shard_id,
"mempool.size",
messages.len() as u64,
);
}
}

self.statsd_client
.count_with_shard(shard_id, "mempool.insert.success", 1);

match message {
MempoolMessage::UserMessage(_) => {
let result = self
.gossip_tx
.send(GossipEvent::BroadcastMempoolMessage(message))
.await;

if let Err(e) = result {
warn!("Failed to gossip message {:?}", e);
}
}
_ => {}
}
} else {
self.statsd_client.count("mempool.insert.failure", 1);
}
}

pub async fn run(&mut self) {
loop {
tokio::select! {
Expand All @@ -172,19 +291,22 @@ impl Mempool {
}
message = self.mempool_rx.recv() => {
if let Some(message) = message {
// TODO(aditi): Maybe we don't need to run validations here?
if self.message_is_valid(&message) {
let fid = message.fid();
let shard_id = self.message_router.route_message(fid, self.num_shards);
// TODO(aditi): We need a size limit on the mempool and we need to figure out what to do if it's exceeded
match self.messages.get_mut(&shard_id) {
None => {
let mut messages = BTreeMap::new();
messages.insert(MempoolKey { inserted_at: Instant::now()}, message.clone());
self.messages.insert(shard_id, messages);
self.insert(message).await;
}
}
chunk = self.shard_decision_rx.recv() => {
aditiharini marked this conversation as resolved.
Show resolved Hide resolved
if let Ok(chunk) = chunk {
let header = chunk.header.expect("Expects chunk to have a header");
let height = header.height.expect("Expects header to have a height");
if let Some(mempool) = self.messages.get_mut(&height.shard_index) {
for transaction in chunk.transactions {
for user_message in transaction.user_messages {
mempool.remove(&user_message.mempool_key());
self.statsd_client.count_with_shard(height.shard_index, "mempool.remove.success", 1);
}
Some(messages) => {
messages.insert(MempoolKey { inserted_at: Instant::now()}, message.clone());
for system_message in transaction.system_messages {
mempool.remove(&system_message.mempool_key());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should record the same stats as above when removing from the mempool as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.statsd_client.count_with_shard(height.shard_index, "mempool.remove.success", 1);
}
}
}
Expand Down
Loading