Skip to content

Commit

Permalink
Move message-queue to a fully binary representation (#454)
Browse files (browse the repository at this point in the history)
* Move message-queue to a fully binary representation

Additionally, adds a timeout to the message-queue test.

* coordinator clippy

* Remove contention for the message-queue socket by using per-request sockets

* clippy
  • Loading branch information
kayabaNerve authored Nov 26, 2023
1 parent c6c7468 commit b79cf8a
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 241 deletions.
4 changes: 0 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion coordinator/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ async fn handle_processor_messages<D: Db, Pro: Processors, P: P2p>(
mut db: D,
key: Zeroizing<<Ristretto as Ciphersuite>::F>,
serai: Arc<Serai>,
mut processors: Pro,
processors: Pro,
p2p: P,
cosign_channel: mpsc::UnboundedSender<CosignedBlock>,
network: NetworkId,
Expand Down
8 changes: 4 additions & 4 deletions coordinator/src/processors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ pub struct Message {
/// Messaging interface shown in the `coordinator/src/processors.rs` hunk: `send` a message
/// for a network, `recv` the next `Message` for a network, and `ack` a received `Message`.
// NOTE(review): `recv` and `ack` are listed twice because this view interleaves the diff's
// removed (`&mut self`) and added (`&self`) lines; presumably only the `&self` pair remains
// in the actual file — confirm against the repository.
#[async_trait::async_trait]
pub trait Processors: 'static + Send + Sync + Clone {
/// Sends `msg` (anything convertible into a `CoordinatorMessage`) for `network`.
async fn send(&self, network: NetworkId, msg: impl Send + Into<CoordinatorMessage>);
async fn recv(&mut self, network: NetworkId) -> Message;
async fn ack(&mut self, msg: Message);
/// Returns a `Message` for `network`.
async fn recv(&self, network: NetworkId) -> Message;
/// Acknowledges `msg`.
async fn ack(&self, msg: Message);
}

#[async_trait::async_trait]
Expand All @@ -28,7 +28,7 @@ impl Processors for Arc<MessageQueue> {
let msg = borsh::to_vec(&msg).unwrap();
self.queue(metadata, msg).await;
}
async fn recv(&mut self, network: NetworkId) -> Message {
async fn recv(&self, network: NetworkId) -> Message {
let msg = self.next(Service::Processor(network)).await;
assert_eq!(msg.from, Service::Processor(network));

Expand All @@ -40,7 +40,7 @@ impl Processors for Arc<MessageQueue> {

return Message { id, network, msg };
}
async fn ack(&mut self, msg: Message) {
async fn ack(&self, msg: Message) {
MessageQueue::ack(self, Service::Processor(msg.network), msg.id).await
}
}
4 changes: 2 additions & 2 deletions coordinator/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ impl Processors for MemProcessors {
let processor = processors.entry(network).or_insert_with(VecDeque::new);
processor.push_back(msg.into());
}
async fn recv(&mut self, _: NetworkId) -> Message {
async fn recv(&self, _: NetworkId) -> Message {
todo!()
}
async fn ack(&mut self, _: Message) {
async fn ack(&self, _: Message) {
todo!()
}
}
Expand Down
11 changes: 3 additions & 8 deletions message-queue/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ rustdoc-args = ["--cfg", "docsrs"]
[dependencies]
# Macros
once_cell = { version = "1", default-features = false }
serde = { version = "1", default-features = false, features = ["std", "derive"] }

# Encoders
hex = { version = "0.4", default-features = false, features = ["std"] }
borsh = { version = "1", default-features = false, features = ["std", "derive", "de_strict_order"] }
serde_json = { version = "1", default-features = false, features = ["std"] }

# Libs
zeroize = { version = "1", default-features = false, features = ["std"] }
Expand All @@ -37,16 +35,13 @@ log = { version = "0.4", default-features = false, features = ["std"] }
env_logger = { version = "0.10", default-features = false, features = ["humantime"] }

# Uses a single threaded runtime since this shouldn't ever be CPU-bound
tokio = { version = "1", default-features = false, features = ["rt", "time", "macros"] }
tokio = { version = "1", default-features = false, features = ["rt", "time", "io-util", "net", "macros"] }

serai-db = { path = "../common/db", features = ["rocksdb"], optional = true }

serai-env = { path = "../common/env" }

serai-primitives = { path = "../substrate/primitives", features = ["borsh", "serde"] }

jsonrpsee = { version = "0.16", default-features = false, features = ["server"], optional = true }
simple-request = { path = "../common/request", default-features = false }
serai-primitives = { path = "../substrate/primitives", features = ["borsh"] }

[features]
binaries = ["serai-db", "jsonrpsee"]
binaries = ["serai-db"]
183 changes: 94 additions & 89 deletions message-queue/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@ use ciphersuite::{
};
use schnorr_signatures::SchnorrSignature;

use serde::{Serialize, Deserialize};

use simple_request::{hyper::Request, Client};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};

use serai_env as env;

use crate::{Service, Metadata, QueuedMessage, message_challenge, ack_challenge};
#[rustfmt::skip]
use crate::{Service, Metadata, QueuedMessage, MessageQueueRequest, message_challenge, ack_challenge};

/// Client-side handle for the message queue: holds the service identity, a Ristretto key
/// pair, and the address that requests connect to.
// NOTE(review): this view interleaves the diff's removed and added lines. `client: Client`
// appears to be the removed field (the one-line struct literal in `new` does not set it) —
// confirm against the actual file.
pub struct MessageQueue {
/// The service given at construction; `next` and `ack` send it as the `to` of their requests.
pub service: Service,
/// Private scalar, wrapped in `Zeroizing`.
priv_key: Zeroizing<<Ristretto as Ciphersuite>::F>,
/// Computed in `new` as `Ristretto::generator() * priv_key`.
pub_key: <Ristretto as Ciphersuite>::G,
client: Client,
/// Address handed to `TcpStream::connect`; `new` appends ":2287" when it contains no ':'.
url: String,
}

Expand All @@ -37,17 +38,8 @@ impl MessageQueue {
if !url.contains(':') {
url += ":2287";
}
if !url.starts_with("http://") {
url = "http://".to_string() + &url;
}

MessageQueue {
service,
pub_key: Ristretto::generator() * priv_key.deref(),
priv_key,
client: Client::with_connection_pool(),
url,
}
MessageQueue { service, pub_key: Ristretto::generator() * priv_key.deref(), priv_key, url }
}

pub fn from_env(service: Service) -> MessageQueue {
Expand All @@ -72,60 +64,14 @@ impl MessageQueue {
Self::new(service, url, priv_key)
}

async fn json_call(&self, method: &'static str, params: serde_json::Value) -> serde_json::Value {
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
struct JsonRpcRequest {
jsonrpc: &'static str,
method: &'static str,
params: serde_json::Value,
id: u64,
}

let mut res = loop {
// Make the request
match self
.client
.request(
Request::post(&self.url)
.header("Content-Type", "application/json")
.body(
serde_json::to_vec(&JsonRpcRequest {
jsonrpc: "2.0",
method,
params: params.clone(),
id: 0,
})
.unwrap()
.into(),
)
.unwrap(),
)
.await
{
Ok(req) => {
// Get the response
match req.body().await {
Ok(res) => break res,
Err(e) => {
dbg!(e);
}
}
}
Err(e) => {
dbg!(e);
}
}

// Sleep for a second before trying again
tokio::time::sleep(core::time::Duration::from_secs(1)).await;
/// Writes `msg` to `socket` as a length-prefixed frame: the borsh encoding's length as a
/// little-endian `u32`, followed by the encoded bytes.
///
/// Returns `false` if either write fails and `true` once both have succeeded. Panics if
/// borsh serialization fails or the encoded length does not fit in a `u32` (both are
/// `unwrap`ped).
#[must_use]
async fn send(socket: &mut TcpStream, msg: MessageQueueRequest) -> bool {
let msg = borsh::to_vec(&msg).unwrap();
// Length prefix first, so the peer knows how many bytes follow.
let Ok(_) = socket.write_all(&u32::try_from(msg.len()).unwrap().to_le_bytes()).await else {
return false;
};

// NOTE(review): the `json` lines below look like the removed tail of the old `json_call`,
// interleaved here by the diff view; they are presumably not part of `send` — confirm.
let json: serde_json::Value =
serde_json::from_reader(&mut res).expect("message-queue returned invalid JSON");
if json.get("result").is_none() {
panic!("call failed: {json}");
}
json
// Then the encoded message itself.
let Ok(_) = socket.write_all(&msg).await else { return false };
true
}

pub async fn queue(&self, metadata: Metadata, msg: Vec<u8>) {
Expand All @@ -146,30 +92,76 @@ impl MessageQueue {
)
.serialize();

let json = self.json_call("queue", serde_json::json!([metadata, msg, sig])).await;
if json.get("result") != Some(&serde_json::Value::Bool(true)) {
panic!("failed to queue message: {json}");
let msg = MessageQueueRequest::Queue { meta: metadata, msg, sig };
let mut first = true;
loop {
// Sleep, so we don't hammer re-attempts
if !first {
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
}
first = false;

let Ok(mut socket) = TcpStream::connect(&self.url).await else { continue };
if !Self::send(&mut socket, msg.clone()).await {
continue;
}
if socket.read_u8().await.ok() != Some(1) {
continue;
}
break;
}
}

pub async fn next(&self, from: Service) -> QueuedMessage {
loop {
let json = self.json_call("next", serde_json::json!([from, self.service])).await;

// Convert from a Value to a type via reserialization
let msg: Option<QueuedMessage> = serde_json::from_str(
&serde_json::to_string(
&json.get("result").expect("successful JSON RPC call didn't have result"),
)
.unwrap(),
)
.expect("next didn't return an Option<QueuedMessage>");

// If there wasn't a message, check again in 1s
let Some(msg) = msg else {
tokio::time::sleep(core::time::Duration::from_secs(1)).await;
let msg = MessageQueueRequest::Next { from, to: self.service };
let mut first = true;
'outer: loop {
if !first {
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
continue;
}
first = false;

let Ok(mut socket) = TcpStream::connect(&self.url).await else { continue };

loop {
if !Self::send(&mut socket, msg.clone()).await {
continue 'outer;
}
let Ok(status) = socket.read_u8().await else {
continue 'outer;
};
// If there wasn't a message, check again in 1s
if status == 0 {
tokio::time::sleep(core::time::Duration::from_secs(1)).await;
continue;
}
assert_eq!(status, 1);
break;
}

// Timeout after 5 seconds in case there's an issue with the length handling
let Ok(msg) = tokio::time::timeout(core::time::Duration::from_secs(5), async {
// Read the message length
let Ok(len) = socket.read_u32_le().await else {
return vec![];
};
let mut buf = vec![0; usize::try_from(len).unwrap()];
// Read the message
let Ok(_) = socket.read_exact(&mut buf).await else {
return vec![];
};
buf
})
.await
else {
continue;
};
if msg.is_empty() {
continue;
}

let msg: QueuedMessage = borsh::from_slice(msg.as_slice()).unwrap();

// Verify the message
// Verify the sender is sane
Expand Down Expand Up @@ -202,9 +194,22 @@ impl MessageQueue {
)
.serialize();

let json = self.json_call("ack", serde_json::json!([from, self.service, id, sig])).await;
if json.get("result") != Some(&serde_json::Value::Bool(true)) {
panic!("failed to ack message {id}: {json}");
let msg = MessageQueueRequest::Ack { from, to: self.service, id, sig };
let mut first = true;
loop {
if !first {
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
}
first = false;

let Ok(mut socket) = TcpStream::connect(&self.url).await else { continue };
if !Self::send(&mut socket, msg.clone()).await {
continue;
}
if socket.read_u8().await.ok() != Some(1) {
continue;
}
break;
}
}
}
Loading

0 comments on commit b79cf8a

Please sign in to comment.