diff --git a/common/authenticator-requests/src/error.rs b/common/authenticator-requests/src/error.rs index fc9f8700d5f..1c43da3cadb 100644 --- a/common/authenticator-requests/src/error.rs +++ b/common/authenticator-requests/src/error.rs @@ -22,4 +22,7 @@ pub enum Error { #[error("conversion: {0}")] Conversion(String), + + #[error("failed to serialize response packet: {source}")] + FailedToSerializeResponsePacket { source: Box }, } diff --git a/common/authenticator-requests/src/lib.rs b/common/authenticator-requests/src/lib.rs index dd98cc9f7f0..f2b2fb55ce6 100644 --- a/common/authenticator-requests/src/lib.rs +++ b/common/authenticator-requests/src/lib.rs @@ -1,6 +1,7 @@ // Copyright 2024 - Nym Technologies SA // SPDX-License-Identifier: Apache-2.0 +pub mod traits; pub mod v1; pub mod v2; pub mod v3; diff --git a/common/authenticator-requests/src/traits.rs b/common/authenticator-requests/src/traits.rs new file mode 100644 index 00000000000..cb7cf37ee13 --- /dev/null +++ b/common/authenticator-requests/src/traits.rs @@ -0,0 +1,269 @@ +// Copyright 2024 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +use std::net::IpAddr; + +use nym_credentials_interface::CredentialSpendingData; +use nym_crypto::asymmetric::x25519::PrivateKey; +use nym_service_provider_requests_common::{Protocol, ServiceProviderType}; +use nym_sphinx::addressing::clients::Recipient; +use nym_wireguard_types::PeerPublicKey; + +use crate::{v1, v2, v3, Error}; + +#[derive(Copy, Clone, Debug)] +pub enum AuthenticatorVersion { + V1, + V2, + V3, + UNKNOWN, +} + +impl From for AuthenticatorVersion { + fn from(value: Protocol) -> Self { + if value.service_provider_type != ServiceProviderType::Authenticator { + AuthenticatorVersion::UNKNOWN + } else if value.version == v1::VERSION { + AuthenticatorVersion::V1 + } else if value.version == v2::VERSION { + AuthenticatorVersion::V2 + } else if value.version == v3::VERSION { + AuthenticatorVersion::V3 + } else { + AuthenticatorVersion::UNKNOWN + } + } +} + +pub trait InitMessage { + fn pub_key(&self) -> PeerPublicKey; +} + +impl InitMessage for v1::registration::InitMessage { + fn pub_key(&self) -> PeerPublicKey { + self.pub_key + } +} + +impl InitMessage for v2::registration::InitMessage { + fn pub_key(&self) -> PeerPublicKey { + self.pub_key + } +} + +impl InitMessage for v3::registration::InitMessage { + fn pub_key(&self) -> PeerPublicKey { + self.pub_key + } +} + +pub trait FinalMessage { + fn pub_key(&self) -> PeerPublicKey; + fn verify(&self, private_key: &PrivateKey, nonce: u64) -> Result<(), Error>; + fn private_ip(&self) -> IpAddr; + fn credential(&self) -> Option; +} + +impl FinalMessage for v1::GatewayClient { + fn pub_key(&self) -> PeerPublicKey { + self.pub_key + } + + fn verify(&self, private_key: &PrivateKey, nonce: u64) -> Result<(), Error> { + self.verify(private_key, nonce) + } + + fn private_ip(&self) -> IpAddr { + self.private_ip + } + + fn credential(&self) -> Option { + None + } +} + +impl FinalMessage for v2::registration::FinalMessage { + fn pub_key(&self) -> PeerPublicKey { + self.gateway_client.pub_key + } + + fn verify(&self, private_key: &PrivateKey, nonce: u64) -> Result<(), Error> { + self.gateway_client.verify(private_key, nonce) + } + + fn private_ip(&self) -> IpAddr { + self.gateway_client.private_ip + } + + fn credential(&self) -> Option { + self.credential.clone() + } +} + +impl FinalMessage for v3::registration::FinalMessage { + fn pub_key(&self) -> PeerPublicKey { + self.gateway_client.pub_key + } + + fn verify(&self, private_key: &PrivateKey, nonce: u64) -> Result<(), Error> { + self.gateway_client.verify(private_key, nonce) + } + + fn private_ip(&self) -> IpAddr { + self.gateway_client.private_ip + } + + fn credential(&self) -> Option { + self.credential.clone() + } +} + +pub trait QueryBandwidthMessage { + fn pub_key(&self) -> PeerPublicKey; +} + +impl QueryBandwidthMessage for PeerPublicKey { + fn pub_key(&self) -> PeerPublicKey { + *self + } +} + +pub trait TopUpMessage { + fn pub_key(&self) -> PeerPublicKey; + fn credential(&self) -> CredentialSpendingData; +} + +impl TopUpMessage for v3::topup::TopUpMessage { + fn pub_key(&self) -> PeerPublicKey { + self.pub_key + } + + fn credential(&self) -> CredentialSpendingData { + self.credential.clone() + } +} + +pub enum AuthenticatorRequest { + Initial { + msg: Box, + protocol: Protocol, + reply_to: Recipient, + request_id: u64, + }, + Final { + msg: Box, + protocol: Protocol, + reply_to: Recipient, + request_id: u64, + }, + QueryBandwidth { + msg: Box, + protocol: Protocol, + reply_to: Recipient, + request_id: u64, + }, + TopUpBandwidth { + msg: Box, + protocol: Protocol, + reply_to: Recipient, + request_id: u64, + }, +} + +impl From for AuthenticatorRequest { + fn from(value: v1::request::AuthenticatorRequest) -> Self { + match value.data { + v1::request::AuthenticatorRequestData::Initial(init_message) => Self::Initial { + msg: Box::new(init_message), + protocol: Protocol { + version: value.version, + service_provider_type: ServiceProviderType::Authenticator, + }, + reply_to: value.reply_to, + request_id: value.request_id, + }, + v1::request::AuthenticatorRequestData::Final(gateway_client) => Self::Final { + msg: Box::new(gateway_client), + protocol: Protocol { + version: value.version, + service_provider_type: ServiceProviderType::Authenticator, + }, + reply_to: value.reply_to, + request_id: value.request_id, + }, + v1::request::AuthenticatorRequestData::QueryBandwidth(peer_public_key) => { + Self::QueryBandwidth { + msg: Box::new(peer_public_key), + protocol: Protocol { + version: value.version, + service_provider_type: ServiceProviderType::Authenticator, + }, + reply_to: value.reply_to, + request_id: value.request_id, + } + } + } + } +} + +impl From for AuthenticatorRequest { + fn from(value: v2::request::AuthenticatorRequest) -> Self { + match value.data { + v2::request::AuthenticatorRequestData::Initial(init_message) => Self::Initial { + msg: Box::new(init_message), + protocol: value.protocol, + reply_to: value.reply_to, + request_id: value.request_id, + }, + v2::request::AuthenticatorRequestData::Final(final_message) => Self::Final { + msg: final_message, + protocol: value.protocol, + reply_to: value.reply_to, + request_id: value.request_id, + }, + v2::request::AuthenticatorRequestData::QueryBandwidth(peer_public_key) => { + Self::QueryBandwidth { + msg: Box::new(peer_public_key), + protocol: value.protocol, + reply_to: value.reply_to, + request_id: value.request_id, + } + } + } + } +} + +impl From for AuthenticatorRequest { + fn from(value: v3::request::AuthenticatorRequest) -> Self { + match value.data { + v3::request::AuthenticatorRequestData::Initial(init_message) => Self::Initial { + msg: Box::new(init_message), + protocol: value.protocol, + reply_to: value.reply_to, + request_id: value.request_id, + }, + v3::request::AuthenticatorRequestData::Final(final_message) => Self::Final { + msg: final_message, + protocol: value.protocol, + reply_to: value.reply_to, + request_id: value.request_id, + }, + v3::request::AuthenticatorRequestData::QueryBandwidth(peer_public_key) => { + Self::QueryBandwidth { + msg: Box::new(peer_public_key), + protocol: value.protocol, + reply_to: value.reply_to, + request_id: value.request_id, + } + } + v3::request::AuthenticatorRequestData::TopUpBandwidth(top_up_message) => { + Self::TopUpBandwidth { + msg: top_up_message, + protocol: value.protocol, + reply_to: value.reply_to, + request_id: value.request_id, + } + } + } + } +} diff --git a/common/service-provider-requests-common/src/lib.rs b/common/service-provider-requests-common/src/lib.rs index d13a7156c9b..f9f0564e1d8 100644 --- a/common/service-provider-requests-common/src/lib.rs +++ b/common/service-provider-requests-common/src/lib.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[repr(u8)] pub enum ServiceProviderType { NetworkRequester = 0, diff --git a/service-providers/authenticator/src/error.rs b/service-providers/authenticator/src/error.rs index bfc04c3fbc1..66453693493 100644 --- a/service-providers/authenticator/src/error.rs +++ b/service-providers/authenticator/src/error.rs @@ -88,6 +88,9 @@ pub enum AuthenticatorError { #[error("storage should have the requested bandwidht entry")] MissingClientBandwidthEntry, + + #[error("unknown version number")] + UnknownVersion, } pub type Result = std::result::Result; diff --git a/service-providers/authenticator/src/mixnet_listener.rs b/service-providers/authenticator/src/mixnet_listener.rs index 8fc5c625824..606e38a2597 100644 --- a/service-providers/authenticator/src/mixnet_listener.rs +++ b/service-providers/authenticator/src/mixnet_listener.rs @@ -11,15 +11,17 @@ use defguard_wireguard_rs::net::IpAddrMask; use defguard_wireguard_rs::{host::Peer, key::Key}; use futures::StreamExt; use nym_authenticator_requests::{ + traits::{ + AuthenticatorRequest, AuthenticatorVersion, FinalMessage, InitMessage, + QueryBandwidthMessage, TopUpMessage, + }, v1, v2, v3::{ self, registration::{ - FinalMessage, GatewayClient, InitMessage, PendingRegistrations, PrivateIPs, - RegistrationData, RegistredData, RemainingBandwidthData, + GatewayClient, PendingRegistrations, PrivateIPs, RegistrationData, + RemainingBandwidthData, }, - request::{AuthenticatorRequest, AuthenticatorRequestData}, - response::AuthenticatorResponse, }, CURRENT_VERSION, }; @@ -32,7 +34,7 @@ use nym_crypto::asymmetric::x25519::KeyPair; use nym_gateway_requests::models::CredentialSpendingRequest; use nym_gateway_storage::Storage; use nym_sdk::mixnet::{InputMessage, MixnetMessageSender, Recipient, TransmissionLane}; -use nym_service_provider_requests_common::ServiceProviderType; +use nym_service_provider_requests_common::{Protocol, ServiceProviderType}; use nym_sphinx::receiver::ReconstructedMessage; use nym_task::TaskHandle; use nym_wireguard::WireguardGatewayData; @@ -43,7 +45,7 @@ use tokio_stream::wrappers::IntervalStream; use crate::{config::Config, error::*}; -type AuthenticatorHandleResult = Result; +type AuthenticatorHandleResult = Result<(Vec, Recipient)>; const DEFAULT_REGISTRATION_TIMEOUT_CHECK: Duration = Duration::from_secs(60); // 1 minute pub(crate) struct RegistredAndFree { @@ -153,22 +155,72 @@ impl MixnetListener { async fn on_initial_request( &mut self, - init_message: InitMessage, + init_message: Box, + protocol: Protocol, request_id: u64, reply_to: Recipient, ) -> AuthenticatorHandleResult { - let remote_public = init_message.pub_key; + let remote_public = init_message.pub_key(); let nonce: u64 = fastrand::u64(..); let mut registred_and_free = self.registred_and_free.write().await; if let Some(registration_data) = registred_and_free .registration_in_progres .get(&remote_public) { - return Ok(AuthenticatorResponse::new_pending_registration_success( - registration_data.clone(), - request_id, - reply_to, - )); + let gateway_data = registration_data.gateway_data.clone(); + let bytes = match AuthenticatorVersion::from(protocol) { + AuthenticatorVersion::V1 => { + v1::response::AuthenticatorResponse::new_pending_registration_success( + v1::registration::RegistrationData { + nonce: registration_data.nonce, + gateway_data: v1::GatewayClient { + pub_key: gateway_data.pub_key, + private_ip: gateway_data.private_ip, + mac: v1::ClientMac::new(gateway_data.mac.to_vec()), + }, + wg_port: registration_data.wg_port, + }, + request_id, + reply_to, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })? + } + AuthenticatorVersion::V2 => { + v2::response::AuthenticatorResponse::new_pending_registration_success( + v2::registration::RegistrationData { + nonce: registration_data.nonce, + gateway_data: registration_data.gateway_data.clone().into(), + wg_port: registration_data.wg_port, + }, + request_id, + reply_to, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })? + } + AuthenticatorVersion::V3 => { + v3::response::AuthenticatorResponse::new_pending_registration_success( + v3::registration::RegistrationData { + nonce: registration_data.nonce, + gateway_data: registration_data.gateway_data.clone(), + wg_port: registration_data.wg_port, + }, + request_id, + reply_to, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })? + } + AuthenticatorVersion::UNKNOWN => return Err(AuthenticatorError::UnknownVersion), + }; + return Ok((bytes, reply_to)); } let peer = self.peer_manager.query_peer(remote_public).await?; @@ -178,15 +230,49 @@ impl MixnetListener { "private ip list should not be empty".to_string(), )); }; - return Ok(AuthenticatorResponse::new_registered( - RegistredData { - pub_key: PeerPublicKey::new(self.keypair().public_key().to_bytes().into()), - private_ip: allowed_ip.ip, - wg_port: self.config.authenticator.announced_port, - }, - reply_to, - request_id, - )); + let bytes = match AuthenticatorVersion::from(protocol) { + AuthenticatorVersion::V1 => v1::response::AuthenticatorResponse::new_registered( + v1::registration::RegistredData { + pub_key: PeerPublicKey::new(self.keypair().public_key().to_bytes().into()), + private_ip: allowed_ip.ip, + wg_port: self.config.authenticator.announced_port, + }, + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })?, + AuthenticatorVersion::V2 => v2::response::AuthenticatorResponse::new_registered( + v2::registration::RegistredData { + pub_key: PeerPublicKey::new(self.keypair().public_key().to_bytes().into()), + private_ip: allowed_ip.ip, + wg_port: self.config.authenticator.announced_port, + }, + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })?, + AuthenticatorVersion::V3 => v3::response::AuthenticatorResponse::new_registered( + v3::registration::RegistredData { + pub_key: PeerPublicKey::new(self.keypair().public_key().to_bytes().into()), + private_ip: allowed_ip.ip, + wg_port: self.config.authenticator.announced_port, + }, + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })?, + AuthenticatorVersion::UNKNOWN => return Err(AuthenticatorError::UnknownVersion), + }; + return Ok((bytes, reply_to)); } let private_ip_ref = registred_and_free @@ -205,51 +291,98 @@ impl MixnetListener { ); let registration_data = RegistrationData { nonce, - gateway_data, + gateway_data: gateway_data.clone(), wg_port: self.config.authenticator.announced_port, }; registred_and_free .registration_in_progres .insert(remote_public, registration_data.clone()); + let bytes = match AuthenticatorVersion::from(protocol) { + AuthenticatorVersion::V1 => { + v1::response::AuthenticatorResponse::new_pending_registration_success( + v1::registration::RegistrationData { + nonce: registration_data.nonce, + gateway_data: v1::GatewayClient { + pub_key: gateway_data.pub_key, + private_ip: gateway_data.private_ip, + mac: v1::ClientMac::new(gateway_data.mac.to_vec()), + }, + wg_port: registration_data.wg_port, + }, + request_id, + reply_to, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })? + } + AuthenticatorVersion::V2 => { + v2::response::AuthenticatorResponse::new_pending_registration_success( + v2::registration::RegistrationData { + nonce: registration_data.nonce, + gateway_data: registration_data.gateway_data.into(), + wg_port: registration_data.wg_port, + }, + request_id, + reply_to, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })? + } + AuthenticatorVersion::V3 => { + v3::response::AuthenticatorResponse::new_pending_registration_success( + v3::registration::RegistrationData { + nonce: registration_data.nonce, + gateway_data: registration_data.gateway_data, + wg_port: registration_data.wg_port, + }, + request_id, + reply_to, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })? + } + AuthenticatorVersion::UNKNOWN => return Err(AuthenticatorError::UnknownVersion), + }; - Ok(AuthenticatorResponse::new_pending_registration_success( - registration_data, - request_id, - reply_to, - )) + Ok((bytes, reply_to)) } async fn on_final_request( &mut self, - final_message: FinalMessage, + final_message: Box, + protocol: Protocol, request_id: u64, reply_to: Recipient, ) -> AuthenticatorHandleResult { let mut registred_and_free = self.registred_and_free.write().await; let registration_data = registred_and_free .registration_in_progres - .get(&final_message.gateway_client.pub_key()) + .get(&final_message.pub_key()) .ok_or(AuthenticatorError::RegistrationNotInProgress)? .clone(); if final_message - .gateway_client .verify(self.keypair().private_key(), registration_data.nonce) .is_err() { return Err(AuthenticatorError::MacVerificationFailure); } - let mut peer = Peer::new(Key::new(final_message.gateway_client.pub_key.to_bytes())); + let mut peer = Peer::new(Key::new(final_message.pub_key().to_bytes())); peer.allowed_ips - .push(IpAddrMask::new(final_message.gateway_client.private_ip, 32)); + .push(IpAddrMask::new(final_message.private_ip(), 32)); // If gateway does ecash verification and client sends a credential, we do the additional // credential verification. Later this will become mandatory. - if let (Some(ecash_verifier), Some(credential)) = ( - self.ecash_verifier.clone(), - final_message.credential.clone(), - ) { + if let (Some(ecash_verifier), Some(credential)) = + (self.ecash_verifier.clone(), final_message.credential()) + { let client_id = ecash_verifier .storage() .insert_wireguard_peer(&peer, true) @@ -279,17 +412,45 @@ impl MixnetListener { } registred_and_free .registration_in_progres - .remove(&final_message.gateway_client.pub_key()); - - Ok(AuthenticatorResponse::new_registered( - RegistredData { - pub_key: registration_data.gateway_data.pub_key, - private_ip: registration_data.gateway_data.private_ip, - wg_port: registration_data.wg_port, - }, - reply_to, - request_id, - )) + .remove(&final_message.pub_key()); + + let bytes = match AuthenticatorVersion::from(protocol) { + AuthenticatorVersion::V1 => v1::response::AuthenticatorResponse::new_registered( + v1::registration::RegistredData { + pub_key: registration_data.gateway_data.pub_key, + private_ip: registration_data.gateway_data.private_ip, + wg_port: registration_data.wg_port, + }, + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| AuthenticatorError::FailedToSerializeResponsePacket { source: err })?, + AuthenticatorVersion::V2 => v2::response::AuthenticatorResponse::new_registered( + v2::registration::RegistredData { + pub_key: registration_data.gateway_data.pub_key, + private_ip: registration_data.gateway_data.private_ip, + wg_port: registration_data.wg_port, + }, + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| AuthenticatorError::FailedToSerializeResponsePacket { source: err })?, + AuthenticatorVersion::V3 => v3::response::AuthenticatorResponse::new_registered( + v3::registration::RegistredData { + pub_key: registration_data.gateway_data.pub_key, + private_ip: registration_data.gateway_data.private_ip, + wg_port: registration_data.wg_port, + }, + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| AuthenticatorError::FailedToSerializeResponsePacket { source: err })?, + AuthenticatorVersion::UNKNOWN => return Err(AuthenticatorError::UnknownVersion), + }; + Ok((bytes, reply_to)) } async fn credential_verification( @@ -325,22 +486,60 @@ impl MixnetListener { async fn on_query_bandwidth_request( &mut self, - peer_public_key: PeerPublicKey, + msg: Box, + protocol: Protocol, request_id: u64, reply_to: Recipient, ) -> AuthenticatorHandleResult { - let bandwidth_data = self.peer_manager.query_bandwidth(peer_public_key).await?; - Ok(AuthenticatorResponse::new_remaining_bandwidth( - bandwidth_data, - reply_to, - request_id, - )) + let bandwidth_data = self.peer_manager.query_bandwidth(msg).await?; + let bytes = match AuthenticatorVersion::from(protocol) { + AuthenticatorVersion::V1 => { + v1::response::AuthenticatorResponse::new_remaining_bandwidth( + bandwidth_data.map(|data| v1::registration::RemainingBandwidthData { + available_bandwidth: data.available_bandwidth as u64, + suspended: false, + }), + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })? + } + AuthenticatorVersion::V2 => { + v2::response::AuthenticatorResponse::new_remaining_bandwidth( + bandwidth_data.map(|data| v2::registration::RemainingBandwidthData { + available_bandwidth: data.available_bandwidth, + }), + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })? + } + AuthenticatorVersion::V3 => { + v3::response::AuthenticatorResponse::new_remaining_bandwidth( + bandwidth_data, + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| { + AuthenticatorError::FailedToSerializeResponsePacket { source: err } + })? + } + AuthenticatorVersion::UNKNOWN => return Err(AuthenticatorError::UnknownVersion), + }; + Ok((bytes, reply_to)) } async fn on_topup_bandwidth_request( &mut self, - peer_public_key: PeerPublicKey, - credential: CredentialSpendingData, + msg: Box, + protocol: Protocol, request_id: u64, reply_to: Recipient, ) -> AuthenticatorHandleResult { @@ -349,7 +548,7 @@ impl MixnetListener { }; let client_id = ecash_verifier .storage() - .get_wireguard_peer(&peer_public_key.to_string()) + .get_wireguard_peer(&msg.pub_key().to_string()) .await? .ok_or(AuthenticatorError::MissingClientBandwidthEntry)? .client_id @@ -364,7 +563,7 @@ impl MixnetListener { let client_bandwidth = ClientBandwidth::new(bandwidth.into()); let mut verifier = CredentialVerifier::new( - CredentialSpendingRequest::new(credential), + CredentialSpendingRequest::new(msg.credential()), ecash_verifier.clone(), BandwidthStorageManager::new( ecash_verifier.storage().clone(), @@ -376,13 +575,22 @@ impl MixnetListener { ); let available_bandwidth = verifier.verify().await?; - Ok(AuthenticatorResponse::new_topup_bandwidth( - RemainingBandwidthData { - available_bandwidth, - }, - reply_to, - request_id, - )) + let bytes = match AuthenticatorVersion::from(protocol) { + AuthenticatorVersion::V3 => v3::response::AuthenticatorResponse::new_topup_bandwidth( + RemainingBandwidthData { + available_bandwidth, + }, + reply_to, + request_id, + ) + .to_bytes() + .map_err(|err| AuthenticatorError::FailedToSerializeResponsePacket { source: err })?, + AuthenticatorVersion::V1 | AuthenticatorVersion::V2 | AuthenticatorVersion::UNKNOWN => { + return Err(AuthenticatorError::UnknownVersion) + } + }; + + Ok((bytes, reply_to)) } async fn on_reconstructed_message( @@ -394,62 +602,52 @@ impl MixnetListener { reconstructed.sender_tag ); - let request = match deserialize_request(&reconstructed) { - Err(AuthenticatorError::InvalidPacketVersion(version)) => { - return self.on_version_mismatch(version, &reconstructed); - } - req => req, - }?; + let request = deserialize_request(&reconstructed)?; - match request.data { - AuthenticatorRequestData::Initial(init_msg) => { - self.on_initial_request(init_msg, request.request_id, request.reply_to) + match request { + AuthenticatorRequest::Initial { + msg, + reply_to, + request_id, + protocol, + } => { + self.on_initial_request(msg, protocol, request_id, reply_to) .await } - AuthenticatorRequestData::Final(final_msg) => { - self.on_final_request(*final_msg, request.request_id, request.reply_to) + AuthenticatorRequest::Final { + msg, + reply_to, + request_id, + protocol, + } => { + self.on_final_request(msg, protocol, request_id, reply_to) .await } - AuthenticatorRequestData::QueryBandwidth(peer_public_key) => { - self.on_query_bandwidth_request( - peer_public_key, - request.request_id, - request.reply_to, - ) - .await + AuthenticatorRequest::QueryBandwidth { + msg, + reply_to, + request_id, + protocol, + } => { + self.on_query_bandwidth_request(msg, protocol, request_id, reply_to) + .await } - AuthenticatorRequestData::TopUpBandwidth(topup_message) => { - self.on_topup_bandwidth_request( - topup_message.pub_key, - topup_message.credential, - request.request_id, - request.reply_to, - ) - .await + AuthenticatorRequest::TopUpBandwidth { + msg, + reply_to, + request_id, + protocol, + } => { + self.on_topup_bandwidth_request(msg, protocol, request_id, reply_to) + .await } } } - fn on_version_mismatch( - &self, - version: u8, - _reconstructed: &ReconstructedMessage, - ) -> AuthenticatorHandleResult { - // If it's possible to parse, do so and return back a response, otherwise just drop - Err(AuthenticatorError::InvalidPacketVersion(version)) - } - // When an incoming mixnet message triggers a response that we send back. - async fn handle_response(&self, response: AuthenticatorResponse) -> Result<()> { - let recipient = response.recipient(); - - let response_packet = response.to_bytes().map_err(|err| { - log::error!("Failed to serialize response packet"); - AuthenticatorError::FailedToSerializeResponsePacket { source: err } - })?; - + async fn handle_response(&self, response: Vec, recipient: Recipient) -> Result<()> { let input_message = - InputMessage::new_regular(recipient, response_packet, TransmissionLane::General, None); + InputMessage::new_regular(recipient, response, TransmissionLane::General, None); self.mixnet_client .send(input_message) .await @@ -473,8 +671,8 @@ impl MixnetListener { msg = self.mixnet_client.next() => { if let Some(msg) = msg { match self.on_reconstructed_message(msg).await { - Ok(response) => { - if let Err(err) = self.handle_response(response).await { + Ok((response, recipient)) => { + if let Err(err) = self.handle_response(response, recipient).await { log::error!("Mixnet listener failed to handle response: {err}"); } } @@ -506,7 +704,6 @@ fn deserialize_request(reconstructed: &ReconstructedMessage) -> Result v1::request::AuthenticatorRequest::from_reconstructed_message(reconstructed) .map_err(|err| AuthenticatorError::FailedToDeserializeTaggedPacket { source: err }) - .map(Into::::into) .map(Into::into), [2, request_type] => { if request_type == ServiceProviderType::Authenticator as u8 { @@ -525,6 +722,7 @@ fn deserialize_request(reconstructed: &ReconstructedMessage) -> Result, ) -> Result> { - let key = Key::new(peer_public_key.to_bytes()); + let key = Key::new(msg.pub_key().to_bytes()); let (response_tx, response_rx) = oneshot::channel(); let msg = PeerControlRequest::QueryBandwidth { key, response_tx }; self.wireguard_gateway_data