Skip to content

Commit

Permalink
Properly manage state in Session
Browse files Browse the repository at this point in the history
That's to make sure we get rid of all private key once it's consumed

Change-Id: Ife6163e611d075c421501477339a7282de64b235
  • Loading branch information
k-naliuka committed Oct 25, 2024
1 parent 18dda3d commit 361bd00
Show file tree
Hide file tree
Showing 4 changed files with 354 additions and 190 deletions.
35 changes: 23 additions & 12 deletions oak_session/src/attestation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ impl From<Error> for AttestationFailure {
}
}

impl From<AttestationFailure> for Error {
fn from(value: AttestationFailure) -> Self {
anyhow!(value.to_string())
}
}

// TODO: b/371139436 - Use definition from `oak_attestation` crate, once DICE
// logic has been moved to a separate crate.
#[cfg_attr(test, automock)]
Expand All @@ -84,7 +90,7 @@ pub trait Endorser: Send {
}

#[cfg_attr(test, automock)]
// Verifier for the particular type of the attestation.
/// Verifier for the particular type of the attestation.
pub trait AttestationVerifier: Send {
fn verify(
&self,
Expand All @@ -110,17 +116,19 @@ pub enum AttestationType {
Unattested,
}

// Provider for the particular type of the attestation.
/// Provider for the particular type of the attestation.
pub trait AttestationProvider: Send {
// Consume the attestation results when they're ready. Returns None if the
// attestation still is still pending the incoming peer's data.
// attestation still is still pending the incoming peer's data. The result is
// taken rather than copied since the results returned might be heavy and
// contain cryptographic material.
fn take_attestation_result(&mut self)
-> Option<Result<AttestationSuccess, AttestationFailure>>;
}

// Aggregates the attestation result from multiple verifiers. Implementations of
// this trait define the logic of when the overall attestation step succeeds or
// fails.
/// Aggregates the attestation result from multiple verifiers. Implementations
/// of this trait define the logic of when the overall attestation step succeeds
/// or fails.
pub trait AttestationAggregator: Send {
fn aggregate_attestation_results(
&self,
Expand Down Expand Up @@ -234,7 +242,10 @@ impl ProtocolEngine<AttestResponse, AttestRequest> for ClientAttestationProvider
)?,
))
}
AttestationType::SelfUnidirectional => None,
AttestationType::SelfUnidirectional => Some(
Ok(AttestationSuccess { attestation_results: BTreeMap::new() })
.map_err(AttestationFailure::from),
),
AttestationType::Unattested => return Err(anyhow!("no attestation message expected'")),
};
Ok(Some(()))
Expand Down Expand Up @@ -276,10 +287,7 @@ impl ServerAttestationProvider {
.collect::<Result<BTreeMap<String, EndorsedEvidence>, Error>>()?,
})
}
AttestationType::PeerUnidirectional => {
Some(AttestResponse { endorsed_evidence: BTreeMap::new() })
}
AttestationType::Unattested => None,
AttestationType::PeerUnidirectional | AttestationType::Unattested => None,
},
config,
attestation_result: None,
Expand Down Expand Up @@ -317,7 +325,10 @@ impl ProtocolEngine<AttestRequest, AttestResponse> for ServerAttestationProvider
)?,
))
}
AttestationType::SelfUnidirectional => None,
AttestationType::SelfUnidirectional => Some(
Ok(AttestationSuccess { attestation_results: BTreeMap::new() })
.map_err(AttestationFailure::from),
),
AttestationType::Unattested => return Err(anyhow!("no attestation message expected'")),
};
Ok(Some(()))
Expand Down
84 changes: 77 additions & 7 deletions oak_session/src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,60 @@ pub enum HandshakeType {
NoiseNN,
}

/// Struct that represents the data extracted from a successfully executed Noise
/// handshake.
pub struct HandshakeResult {
/// Keys to use with the established encrypted channel.
pub session_keys: SessionKeys,
/// The hash of the data exchanged in the handshake.
pub handshake_hash: Vec<u8>,
/// Bindings fo
pub session_bindings: BTreeMap<String, SessionBinding>,
}

/// Trait that allows building a handshaker without passing any more data to it.
/// It encapsulates any parameters necessary to create a handshaker object
/// (i.e., the configuration)
pub trait HandshakerBuilder<T: Handshaker>: Send {
fn build(self: Box<Self>) -> Result<T, Error>;
}

pub struct ClientHandshakerBuilder {
pub config: HandshakerConfig,
}

impl HandshakerBuilder<ClientHandshaker> for ClientHandshakerBuilder {
fn build(self: Box<Self>) -> Result<ClientHandshaker, Error> {
ClientHandshaker::create(self.config)
}
}
pub struct ServerHandshakerBuilder {
pub config: HandshakerConfig,
pub client_binding_expected: bool,
}

impl HandshakerBuilder<ServerHandshaker> for ServerHandshakerBuilder {
fn build(self: Box<Self>) -> Result<ServerHandshaker, Error> {
Ok(ServerHandshaker::new(self.config, self.client_binding_expected))
}
}

/// Trait that performs a Noise handshake between the parties following the
/// pattern specified in the configuration.
pub trait Handshaker: Send {
// Consume the handshake result when it's ready. Returns None if the handshake
// is still in progress or its results have already been consumed.
fn take_handshake_result(&mut self) -> Option<HandshakeResult>;
/// Consume the session keys produced by the handshake. Returns error if the
/// keys are not ready. Can only be called once.
fn take_session_keys(self) -> Result<SessionKeys, Error>;

/// Gets the hash of the completed handshake without consuming the stored
/// handshake results. This allows using the hash for binding independently
/// from creating the encrypted channel. Returns an error if the
/// handshake is not yet complete.
fn get_handshake_hash(&self) -> Result<Vec<u8>, Error>;

// Allows checking whether the handshake is complete without consuming the
// produced results.
fn is_handshake_complete(&self) -> bool;
}

/// Client-side Handshaker that initiates the crypto handshake with the server.
Expand Down Expand Up @@ -118,8 +162,21 @@ impl ClientHandshaker {
}

impl Handshaker for ClientHandshaker {
fn take_handshake_result(&mut self) -> Option<HandshakeResult> {
self.handshake_result.take()
fn take_session_keys(mut self) -> Result<SessionKeys, Error> {
Ok(self.handshake_result.take().ok_or(anyhow!("handshake is not complete"))?.session_keys)
}

fn get_handshake_hash(&self) -> Result<Vec<u8>, Error> {
Ok(self
.handshake_result
.as_ref()
.ok_or(anyhow!("handshake is not complete"))?
.handshake_hash
.clone())
}

fn is_handshake_complete(&self) -> bool {
self.handshake_result.is_some() && self.followup_message.is_none()
}
}

Expand Down Expand Up @@ -208,8 +265,21 @@ impl ServerHandshaker {
}

impl Handshaker for ServerHandshaker {
fn take_handshake_result(&mut self) -> Option<HandshakeResult> {
self.handshake_result.take()
fn take_session_keys(mut self) -> Result<SessionKeys, Error> {
Ok(self.handshake_result.take().ok_or(anyhow!("handshake is not complete"))?.session_keys)
}

fn get_handshake_hash(&self) -> Result<Vec<u8>, Error> {
Ok(self
.handshake_result
.as_ref()
.ok_or(anyhow!("handshake is not complete"))?
.handshake_hash
.clone())
}

fn is_handshake_complete(&self) -> bool {
self.handshake_result.is_some()
}
}

Expand Down
Loading

0 comments on commit 361bd00

Please sign in to comment.