Skip to content

Commit

Permalink
migtd: add timeout future
Browse files Browse the repository at this point in the history
Set timeout for network operations.

Signed-off-by: Jiaqi Gao <[email protected]>
  • Loading branch information
gaojiaqi7 committed Oct 29, 2023
1 parent 7c1c789 commit 8852a00
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 15 deletions.
31 changes: 31 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion src/migtd/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ anyhow = { version = "1.0.68", default-features = false }
attestation = { path = "../attestation", default-features = false }
cc-measurement = { path = "../../deps/td-shim/cc-measurement"}
crypto = { path = "../crypto" }
futures-util = { version = "0.3.17", default-features = false, optional = true }
lazy_static = { version = "1.0", features = ["spin_no_std"] }
log = { version = "0.4.13", features = ["release_max_level_off"] }
pci = { path="../devices/pci" }
Expand Down Expand Up @@ -44,7 +45,7 @@ td-benchmark = { path = "../../deps/td-shim/devtools/td-benchmark", default-feat

[features]
default = ["tdx"]
async = ["crypto/async", "vsock/async", "async_io", "async_runtime"]
async = ["crypto/async", "async_io", "async_runtime", "futures-util", "vsock/async"]
cet-shstk = ["td-payload/cet-shstk"]
coverage = ["minicov"]
main = []
Expand Down
2 changes: 2 additions & 0 deletions src/migtd/src/driver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#[cfg(feature = "virtio-serial")]
pub mod serial;
#[cfg(feature = "async")]
pub mod ticks;
pub mod timer;
#[cfg(any(feature = "virtio-vsock", feature = "vmcall-vsock"))]
pub mod vsock;
81 changes: 81 additions & 0 deletions src/migtd/src/driver/ticks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) 2023 Intel Corporation
//
// SPDX-License-Identifier: BSD-2-Clause-Patent

use core::{
future::Future,
pin::Pin,
sync::atomic::{AtomicU64, Ordering},
task::{Context, Poll},
time::Duration,
};

use super::timer::*;

static SYS_TICK: AtomicU64 = AtomicU64::new(0);
const INTERVAL: u64 = 1;

pub struct TimeoutError;

pub fn init_sys_tick() {
init_timer();
set_timer_callback(timer_callback);
schedule_timeout(INTERVAL);
}

fn timer_callback() {
SYS_TICK
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| v.checked_add(1))
.unwrap();
schedule_timeout(INTERVAL);
}

fn now() -> u64 {
SYS_TICK.load(Ordering::SeqCst)
}

/// Runs a given future with a timeout.
#[cfg(feature = "async")]
pub async fn with_timeout<F: Future>(timeout: Duration, fut: F) -> Result<F::Output, TimeoutError> {
use futures_util::{
future::{select, Either},
pin_mut,
};

pin_mut!(fut);
let timeout_fut = Timer::after(timeout);

match select(fut, timeout_fut).await {
Either::Left((r, _)) => Ok(r),
Either::Right(_) => Err(TimeoutError),
}
}

struct Timer {
expires_at: u128,
yielded_once: bool,
}

impl Timer {
/// Expire after specified duration.
pub fn after(duration: Duration) -> Self {
Self {
expires_at: now() as u128 + duration.as_millis(),
yielded_once: false,
}
}
}

impl Unpin for Timer {}

impl Future for Timer {
type Output = ();
fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.yielded_once && self.expires_at <= now() as u128 {
Poll::Ready(())
} else {
self.yielded_once = true;
Poll::Pending
}
}
}
25 changes: 17 additions & 8 deletions src/migtd/src/driver/timer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,17 @@
// SPDX-License-Identifier: BSD-2-Clause-Patent

use core::sync::atomic::{AtomicBool, Ordering};
use spin::Once;
use td_payload::arch::apic::*;
use td_payload::arch::idt::register;
use td_payload::interrupt_handler_template;

/// A simple apic timer notification handler used to handle the
/// time out events

static TIMEOUT_CALLBACK: Once<fn()> = Once::new();
static TIMEOUT_FLAG: AtomicBool = AtomicBool::new(false);

const TIMEOUT_VECTOR: u8 = 33;
const CPUID_TSC_DEADLINE_BIT: u32 = 1 << 24;

interrupt_handler_template!(timer, _stack, {
TIMEOUT_FLAG.store(true, Ordering::SeqCst);
});

pub fn init_timer() {
let cpuid = unsafe { core::arch::x86_64::__cpuid_count(0x1, 0) };
if cpuid.ecx & CPUID_TSC_DEADLINE_BIT == 0 {
Expand All @@ -32,7 +27,7 @@ pub fn schedule_timeout(timeout: u64) -> Option<u64> {
reset_timer();
let cpuid = unsafe { core::arch::x86_64::__cpuid_count(0x15, 0) };
let tsc_frequency = cpuid.ecx * (cpuid.ebx / cpuid.eax);
let timeout = timeout / 1000 * tsc_frequency as u64;
let timeout = tsc_frequency as u64 * timeout / 1000;

apic_timer_lvtt_setup(TIMEOUT_VECTOR);
one_shot_tsc_deadline_mode(timeout)
Expand All @@ -47,6 +42,20 @@ pub fn reset_timer() {
TIMEOUT_FLAG.store(false, Ordering::SeqCst);
}

pub fn set_timer_callback(cb: fn()) {
TIMEOUT_CALLBACK.call_once(|| cb);
}

interrupt_handler_template!(timer, _stack, {
TIMEOUT_CALLBACK
.get()
.unwrap_or(&(default_callback as fn()))();
});

fn default_callback() {
TIMEOUT_FLAG.store(true, Ordering::SeqCst);
}

fn set_lvtt(val: u32) {
unsafe { x86::msr::wrmsr(MSR_LVTT, val as u64) }
}
Expand Down
4 changes: 4 additions & 0 deletions src/migtd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,12 @@ pub extern "C" fn _start(hob: u64, payload: u64) -> ! {
init(payload);

// Initialize the timer based on the APIC TSC deadline mode
#[cfg(not(feature = "async"))]
driver::timer::init_timer();

#[cfg(feature = "async")]
driver::ticks::init_sys_tick();

#[cfg(feature = "virtio-serial")]
driver::serial::virtio_serial_device_init(end_of_ram() as u64);

Expand Down
9 changes: 9 additions & 0 deletions src/migtd/src/migration/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pub mod event;
// #[cfg(any(target_os = "none", target_os = "uefi"))]
pub mod session;

#[cfg(feature = "async")]
use crate::driver::ticks::TimeoutError;
use crate::ratls::RatlsError;
use crate::ratls::MIG_POLICY_ERROR;
use crate::ratls::MUTUAL_ATTESTATION_ERROR;
Expand Down Expand Up @@ -214,3 +216,10 @@ impl From<TdCallError> for MigrationResult {
MigrationResult::TdxModuleError
}
}

#[cfg(feature = "async")]
impl From<TimeoutError> for MigrationResult {
fn from(_: TimeoutError) -> Self {
MigrationResult::NetworkError
}
}
18 changes: 12 additions & 6 deletions src/migtd/src/migration/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use alloc::collections::BTreeSet;
use alloc::vec::Vec;
use core::mem::size_of;
#[cfg(feature = "async")]
use core::time::Duration;
#[cfg(feature = "async")]
use lazy_static::lazy_static;
use scroll::Pread;
#[cfg(feature = "async")]
Expand All @@ -19,11 +21,15 @@ use zerocopy::AsBytes;
type Result<T> = core::result::Result<T, MigrationResult>;

use super::{data::*, *};
#[cfg(feature = "async")]
use crate::driver::ticks::with_timeout;
use crate::ratls;

const TDCS_FIELD_MIG_DEC_KEY: u64 = 0x9810_0003_0000_0010;
const TDCS_FIELD_MIG_ENC_KEY: u64 = 0x9810_0003_0000_0018;
const MSK_SIZE: usize = 32;
#[cfg(feature = "async")]
const TLS_TIMEOUT: Duration = Duration::from_secs(10); // 10 seconds

#[cfg(feature = "async")]
lazy_static! {
Expand Down Expand Up @@ -363,9 +369,9 @@ pub async fn trans_msk_async(info: &MigrationInformation) -> Result<()> {
ratls::async_client(transport).map_err(|_| MigrationResult::SecureSessionError)?;

// MigTD-S send Migration Session Forward key to peer
ratls_client.start().await?;
ratls_client.write(msk.as_bytes()).await?;
let size = ratls_client.read(msk_peer.as_bytes_mut()).await?;
with_timeout(TLS_TIMEOUT, ratls_client.start()).await??;
with_timeout(TLS_TIMEOUT, ratls_client.write(msk.as_bytes())).await??;
let size = with_timeout(TLS_TIMEOUT, ratls_client.read(msk_peer.as_bytes_mut())).await??;
if size < MSK_SIZE {
return Err(MigrationResult::NetworkError);
}
Expand All @@ -374,9 +380,9 @@ pub async fn trans_msk_async(info: &MigrationInformation) -> Result<()> {
let mut ratls_server =
ratls::async_server(transport).map_err(|_| MigrationResult::SecureSessionError)?;

ratls_server.start().await?;
ratls_server.write(msk.as_bytes()).await?;
let size = ratls_server.read(msk_peer.as_bytes_mut()).await?;
with_timeout(TLS_TIMEOUT, ratls_server.start()).await??;
with_timeout(TLS_TIMEOUT, ratls_server.write(msk.as_bytes())).await??;
let size = with_timeout(TLS_TIMEOUT, ratls_server.read(msk_peer.as_bytes_mut())).await??;
if size < MSK_SIZE {
return Err(MigrationResult::NetworkError);
}
Expand Down

0 comments on commit 8852a00

Please sign in to comment.