diff --git a/Cargo.lock b/Cargo.lock index 35f64f07..a7a4b1ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -361,12 +361,30 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "futures-core" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" + [[package]] name = "futures-task" version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" +[[package]] +name = "futures-util" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "pin-utils", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -548,6 +566,7 @@ dependencies = [ "bitfield", "cc-measurement", "crypto", + "futures-util", "lazy_static", "log", "minicov", @@ -603,6 +622,18 @@ dependencies = [ "x86", ] +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "policy" version = "0.1.0" diff --git a/src/migtd/Cargo.toml b/src/migtd/Cargo.toml index 56458c41..3f04dcb2 100644 --- a/src/migtd/Cargo.toml +++ b/src/migtd/Cargo.toml @@ -16,6 +16,7 @@ anyhow = { version = "1.0.68", default-features = false } attestation = { path = "../attestation", default-features = false } cc-measurement = { path = "../../deps/td-shim/cc-measurement"} crypto = { path = "../crypto" } +futures-util = { version = "0.3.17", default-features = false, optional = true } lazy_static = { version = "1.0", features = ["spin_no_std"] } log = { version = "0.4.13", features = ["release_max_level_off"] } pci = { path="../devices/pci" } @@ -44,7 +45,7 @@ td-benchmark = { path = "../../deps/td-shim/devtools/td-benchmark", default-feat [features] default = ["tdx"] -async = ["crypto/async", "vsock/async", "async_io", "async_runtime"] +async = ["crypto/async", "async_io", "async_runtime", "futures-util", "vsock/async"] cet-shstk = ["td-payload/cet-shstk"] coverage = ["minicov"] main = [] diff --git a/src/migtd/src/driver/mod.rs b/src/migtd/src/driver/mod.rs index 938e1cc5..e8e10296 100644 --- a/src/migtd/src/driver/mod.rs +++ b/src/migtd/src/driver/mod.rs @@ -4,6 +4,8 @@ #[cfg(feature = "virtio-serial")] pub mod serial; +#[cfg(feature = "async")] +pub mod ticks; pub mod timer; #[cfg(any(feature = "virtio-vsock", feature = "vmcall-vsock"))] pub mod vsock; diff --git a/src/migtd/src/driver/ticks.rs b/src/migtd/src/driver/ticks.rs new file mode 100644 index 00000000..25306d2d --- /dev/null +++ b/src/migtd/src/driver/ticks.rs @@ -0,0 +1,81 @@ +// Copyright (c) 2023 Intel Corporation +// +// SPDX-License-Identifier: BSD-2-Clause-Patent + +use core::{ + future::Future, + pin::Pin, + sync::atomic::{AtomicU64, Ordering}, + task::{Context, Poll}, + time::Duration, +}; + +use super::timer::*; + +static SYS_TICK: AtomicU64 = AtomicU64::new(0); +const INTERVAL: u64 = 1; + +pub struct TimeoutError; + +pub fn init_sys_tick() { + init_timer(); + set_timer_callback(timer_callback); + schedule_timeout(INTERVAL); +} + +fn timer_callback() { + SYS_TICK + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| v.checked_add(1)) + .unwrap(); + schedule_timeout(INTERVAL); +} + +fn now() -> u64 { + SYS_TICK.load(Ordering::SeqCst) +} + +/// Runs a given future with a timeout. +#[cfg(feature = "async")] +pub async fn with_timeout(timeout: Duration, fut: F) -> Result { + use futures_util::{ + future::{select, Either}, + pin_mut, + }; + + pin_mut!(fut); + let timeout_fut = Timer::after(timeout); + + match select(fut, timeout_fut).await { + Either::Left((r, _)) => Ok(r), + Either::Right(_) => Err(TimeoutError), + } +} + +struct Timer { + expires_at: u128, + yielded_once: bool, +} + +impl Timer { + /// Expire after specified duration. + pub fn after(duration: Duration) -> Self { + Self { + expires_at: now() as u128 + duration.as_millis(), + yielded_once: false, + } + } +} + +impl Unpin for Timer {} + +impl Future for Timer { + type Output = (); + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + if self.yielded_once && self.expires_at <= now() as u128 { + Poll::Ready(()) + } else { + self.yielded_once = true; + Poll::Pending + } + } +} diff --git a/src/migtd/src/driver/timer.rs b/src/migtd/src/driver/timer.rs index 0561653e..1d14f8be 100644 --- a/src/migtd/src/driver/timer.rs +++ b/src/migtd/src/driver/timer.rs @@ -3,22 +3,17 @@ // SPDX-License-Identifier: BSD-2-Clause-Patent use core::sync::atomic::{AtomicBool, Ordering}; +use spin::Once; use td_payload::arch::apic::*; use td_payload::arch::idt::register; use td_payload::interrupt_handler_template; -/// A simple apic timer notification handler used to handle the -/// time out events - +static TIMEOUT_CALLBACK: Once = Once::new(); static TIMEOUT_FLAG: AtomicBool = AtomicBool::new(false); const TIMEOUT_VECTOR: u8 = 33; const CPUID_TSC_DEADLINE_BIT: u32 = 1 << 24; -interrupt_handler_template!(timer, _stack, { - TIMEOUT_FLAG.store(true, Ordering::SeqCst); -}); - pub fn init_timer() { let cpuid = unsafe { core::arch::x86_64::__cpuid_count(0x1, 0) }; if cpuid.ecx & CPUID_TSC_DEADLINE_BIT == 0 { @@ -32,7 +27,7 @@ pub fn schedule_timeout(timeout: u64) -> Option { reset_timer(); let cpuid = unsafe { core::arch::x86_64::__cpuid_count(0x15, 0) }; let tsc_frequency = cpuid.ecx * (cpuid.ebx / cpuid.eax); - let timeout = timeout / 1000 * tsc_frequency as u64; + let timeout = tsc_frequency as u64 * timeout / 1000; apic_timer_lvtt_setup(TIMEOUT_VECTOR); one_shot_tsc_deadline_mode(timeout) @@ -47,6 +42,20 @@ pub fn reset_timer() { TIMEOUT_FLAG.store(false, Ordering::SeqCst); } +pub fn set_timer_callback(cb: fn()) { + TIMEOUT_CALLBACK.call_once(|| cb); +} + +interrupt_handler_template!(timer, _stack, { + TIMEOUT_CALLBACK + .get() + .unwrap_or(&(default_callback as fn()))(); +}); + +fn default_callback() { + TIMEOUT_FLAG.store(true, Ordering::SeqCst); +} + fn set_lvtt(val: u32) { unsafe { x86::msr::wrmsr(MSR_LVTT, val as u64) } } diff --git a/src/migtd/src/lib.rs b/src/migtd/src/lib.rs index f41996b2..94f938ce 100644 --- a/src/migtd/src/lib.rs +++ b/src/migtd/src/lib.rs @@ -55,8 +55,12 @@ pub extern "C" fn _start(hob: u64, payload: u64) -> ! { init(payload); // Initialize the timer based on the APIC TSC deadline mode + #[cfg(not(feature = "async"))] driver::timer::init_timer(); + #[cfg(feature = "async")] + driver::ticks::init_sys_tick(); + #[cfg(feature = "virtio-serial")] driver::serial::virtio_serial_device_init(end_of_ram() as u64); diff --git a/src/migtd/src/migration/mod.rs b/src/migtd/src/migration/mod.rs index 64e8f6a9..ca253137 100644 --- a/src/migtd/src/migration/mod.rs +++ b/src/migtd/src/migration/mod.rs @@ -7,6 +7,8 @@ pub mod event; // #[cfg(any(target_os = "none", target_os = "uefi"))] pub mod session; +#[cfg(feature = "async")] +use crate::driver::ticks::TimeoutError; use crate::ratls::RatlsError; use crate::ratls::MIG_POLICY_ERROR; use crate::ratls::MUTUAL_ATTESTATION_ERROR; @@ -214,3 +216,10 @@ impl From for MigrationResult { MigrationResult::TdxModuleError } } + +#[cfg(feature = "async")] +impl From for MigrationResult { + fn from(_: TimeoutError) -> Self { + MigrationResult::NetworkError + } +} diff --git a/src/migtd/src/migration/session.rs b/src/migtd/src/migration/session.rs index ad5eaf63..355190a6 100644 --- a/src/migtd/src/migration/session.rs +++ b/src/migtd/src/migration/session.rs @@ -7,6 +7,8 @@ use alloc::collections::BTreeSet; use alloc::vec::Vec; use core::mem::size_of; #[cfg(feature = "async")] +use core::time::Duration; +#[cfg(feature = "async")] use lazy_static::lazy_static; use scroll::Pread; #[cfg(feature = "async")] @@ -19,11 +21,15 @@ use zerocopy::AsBytes; type Result = core::result::Result; use super::{data::*, *}; +#[cfg(feature = "async")] +use crate::driver::ticks::with_timeout; use crate::ratls; const TDCS_FIELD_MIG_DEC_KEY: u64 = 0x9810_0003_0000_0010; const TDCS_FIELD_MIG_ENC_KEY: u64 = 0x9810_0003_0000_0018; const MSK_SIZE: usize = 32; +#[cfg(feature = "async")] +const TLS_TIMEOUT: Duration = Duration::from_secs(10); // 10 seconds #[cfg(feature = "async")] lazy_static! { @@ -363,9 +369,9 @@ pub async fn trans_msk_async(info: &MigrationInformation) -> Result<()> { ratls::async_client(transport).map_err(|_| MigrationResult::SecureSessionError)?; // MigTD-S send Migration Session Forward key to peer - ratls_client.start().await?; - ratls_client.write(msk.as_bytes()).await?; - let size = ratls_client.read(msk_peer.as_bytes_mut()).await?; + with_timeout(TLS_TIMEOUT, ratls_client.start()).await??; + with_timeout(TLS_TIMEOUT, ratls_client.write(msk.as_bytes())).await??; + let size = with_timeout(TLS_TIMEOUT, ratls_client.read(msk_peer.as_bytes_mut())).await??; if size < MSK_SIZE { return Err(MigrationResult::NetworkError); } @@ -374,9 +380,9 @@ pub async fn trans_msk_async(info: &MigrationInformation) -> Result<()> { let mut ratls_server = ratls::async_server(transport).map_err(|_| MigrationResult::SecureSessionError)?; - ratls_server.start().await?; - ratls_server.write(msk.as_bytes()).await?; - let size = ratls_server.read(msk_peer.as_bytes_mut()).await?; + with_timeout(TLS_TIMEOUT, ratls_server.start()).await??; + with_timeout(TLS_TIMEOUT, ratls_server.write(msk.as_bytes())).await??; + let size = with_timeout(TLS_TIMEOUT, ratls_server.read(msk_peer.as_bytes_mut())).await??; if size < MSK_SIZE { return Err(MigrationResult::NetworkError); }