Skip to content

Commit

Permalink
WIP: shard routing
Browse files Browse the repository at this point in the history
  • Loading branch information
suurkivi committed Dec 6, 2024
1 parent 71a04c0 commit 6b45a74
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 41 deletions.
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
rpc_shard_stores,
rpc_shard_senders,
statsd_client.clone(),
app_config.consensus.num_shards,
);

let resp = Server::builder()
Expand Down
162 changes: 121 additions & 41 deletions src/network/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::HashMap;

use crate::core::error::HubError;
use crate::proto;
use crate::proto::hub_service_server::HubService;
Expand All @@ -12,6 +10,8 @@ use crate::storage::store::stores::{StoreLimits, Stores};
use crate::storage::store::BlockStore;
use crate::utils::statsd_wrapper::StatsdClientWrapper;
use hex::ToHex;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status};
Expand All @@ -21,7 +21,7 @@ pub struct MyHubService {
block_store: BlockStore,
shard_stores: HashMap<u32, Stores>,
shard_senders: HashMap<u32, Senders>,
message_tx: mpsc::Sender<MempoolMessage>,
num_shards: u32,
statsd_client: StatsdClientWrapper,
}

Expand All @@ -31,54 +31,71 @@ impl MyHubService {
shard_stores: HashMap<u32, Stores>,
shard_senders: HashMap<u32, Senders>,
statsd_client: StatsdClientWrapper,
num_shards: u32,
) -> Self {
// TODO(aditi): This logic will change once a mempool exists
let message_tx = shard_senders.get(&1u32).unwrap().messages_tx.clone();

Self {
block_store,
shard_senders,
shard_stores,
message_tx,
statsd_client,
num_shards,
}
}
}

#[tonic::async_trait]
impl HubService for MyHubService {
async fn submit_message(
async fn submit_message_internal(
&self,
request: Request<proto::Message>,
) -> Result<Response<proto::Message>, Status> {
let start_time = std::time::Instant::now();
message: proto::Message,
bypass_validation: bool,
) -> Result<proto::Message, Status> {
let fid = message.fid();
if fid == 0 {
return Err(Status::invalid_argument(
"no fid or invalid fid".to_string(),
));
}

let hash = request.get_ref().hash.encode_hex::<String>();
info!(hash, "Received call to [submit_message] RPC");
let dst_shard = route_message(fid, self.num_shards);

let message = request.into_inner();
let sender = match self.shard_senders.get(&dst_shard) {
Some(sender) => sender,
None => {
return Err(Status::invalid_argument(
"no shard sender for fid".to_string(),
))
}
};

let stores = self.shard_stores.get(&1u32).unwrap();
// TODO: This is a hack to get around the fact that self cannot be made mutable
let mut readonly_engine = ShardEngine::new(
stores.db.clone(),
stores.trie.clone(),
1,
StoreLimits::default(),
self.statsd_client.clone(),
100,
);
let result = readonly_engine.simulate_message(&message);
let stores = match self.shard_stores.get(&dst_shard) {
Some(sender) => sender,
None => {
return Err(Status::invalid_argument(
"no shard store for fid".to_string(),
))
}
};

if let Err(err) = result {
return Err(Status::invalid_argument(format!(
"Invalid message: {}",
err.to_string()
)));
if !bypass_validation {
// TODO: This is a hack to get around the fact that self cannot be made mutable
let mut readonly_engine = ShardEngine::new(
stores.db.clone(),
stores.trie.clone(),
1,
StoreLimits::default(),
self.statsd_client.clone(),
100,
);
let result = readonly_engine.simulate_message(&message);

if let Err(err) = result {
return Err(Status::invalid_argument(format!(
"Invalid message: {}",
err.to_string()
)));
}
}

let result = self
.message_tx
let result = sender
.messages_tx
.send(MempoolMessage::UserMessage(message.clone()))
.await;

Expand All @@ -94,14 +111,77 @@ impl HubService for MyHubService {
}
}

let elapsed = start_time.elapsed().as_millis();
Ok(message)
}
}

// TODO: find a better place for this
fn route_message(fid: u32, num_shards: u32) -> u32 {
let hash = Sha256::digest(fid.to_be_bytes());
let hash_u32 = u32::from_be_bytes(hash[..4].try_into().unwrap());
(hash_u32 % num_shards) + 1
}

#[tonic::async_trait]
impl HubService for MyHubService {
async fn submit_message_with_options(
&self,
request: Request<proto::SubmitMessageRequest>,
) -> Result<Response<proto::SubmitMessageResponse>, Status> {
let start_time = std::time::Instant::now();

let hash = request
.get_ref()
.message
.as_ref()
.map(|msg| msg.hash.encode_hex::<String>())
.unwrap_or_default();
info!(%hash, "Received call to [submit_message_with_options] RPC");

let proto::SubmitMessageRequest {
message,
bypass_validation,
} = request.into_inner();

let message = match message {
Some(msg) => msg,
None => return Err(Status::invalid_argument("Message is required")),
};

let response = Response::new(message);
let response_message = self
.submit_message_internal(message, bypass_validation.unwrap_or(false))
.await?;

let response = proto::SubmitMessageResponse {
message: Some(response_message),
};

self.statsd_client.time(
"rpc.submit_message_with_options.duration",
start_time.elapsed().as_millis() as u64,
);

Ok(Response::new(response))
}

self.statsd_client
.time("rpc.submit_message.duration", elapsed as u64);
async fn submit_message(
&self,
request: Request<proto::Message>,
) -> Result<Response<proto::Message>, Status> {
let start_time = std::time::Instant::now();

let hash = request.get_ref().hash.encode_hex::<String>();
info!(hash, "Received call to [submit_message] RPC");

let message = request.into_inner();
let response_message = self.submit_message_internal(message, false).await?;

self.statsd_client.time(
"rpc.submit_message.duration",
start_time.elapsed().as_millis() as u64,
);

Ok(response)
Ok(Response::new(response_message))
}

type GetBlocksStream = ReceiverStream<Result<Block, Status>>;
Expand Down Expand Up @@ -198,7 +278,7 @@ impl HubService for MyHubService {
// TODO(aditi): Rethink the channel size
let (server_tx, client_rx) = mpsc::channel::<Result<HubEvent, Status>>(100);
let events_txs = match request.get_ref().shard_index {
Some(shard_id) => match self.shard_senders.get(&(shard_id as u32)) {
Some(shard_id) => match self.shard_senders.get(&(shard_id)) {
None => {
return Err(Status::from_error(Box::new(
HubError::invalid_internal_state("Missing shard event tx"),
Expand Down
2 changes: 2 additions & 0 deletions src/network/server_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ mod tests {
let shard2_senders = Senders::new(msgs_tx.clone());
let stores = HashMap::from([(1, shard1_stores), (2, shard2_stores)]);
let senders = HashMap::from([(1, shard1_senders), (2, shard2_senders)]);
let num_shards = senders.len() as u32;

(
stores.clone(),
Expand All @@ -118,6 +119,7 @@ mod tests {
stores,
senders,
statsd_client,
num_shards,
),
)
}
Expand Down
11 changes: 11 additions & 0 deletions src/proto/rpc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,19 @@ message SubscribeRequest {
optional uint32 shard_index = 5;
}

message SubmitMessageRequest {
Message message = 1;
optional bool bypass_validation = 99;
}

message SubmitMessageResponse {
Message message = 1;
}


service HubService {
rpc SubmitMessage(Message) returns (Message);
rpc SubmitMessageWithOptions(SubmitMessageRequest) returns (SubmitMessageResponse);
rpc GetBlocks(BlocksRequest) returns (stream Block);
rpc GetShardChunks(ShardChunksRequest) returns (ShardChunksResponse);
rpc Subscribe(SubscribeRequest) returns (stream HubEvent);
Expand Down
1 change: 1 addition & 0 deletions tests/consensus_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ impl NodeForTest {
grpc_shard_stores,
grpc_shard_senders,
statsd_client.clone(),
num_shards,
);

let grpc_socket_addr: SocketAddr = addr.parse().unwrap();
Expand Down

0 comments on commit 6b45a74

Please sign in to comment.