Skip to content

Commit

Permalink
td-payload: change Dma to SharedMemory
Browse files Browse the repository at this point in the history
To avoid misunderstanding.

Signed-off-by: Jiaqi Gao <[email protected]>
  • Loading branch information
gaojiaqi7 authored and jyao1 committed Nov 4, 2023
1 parent a462a60 commit 0001b34
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 40 deletions.
8 changes: 4 additions & 4 deletions td-payload/src/arch/x86_64/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use crate::{
arch::{gdt, idt},
hob::{self, get_hob},
mm::{
dma::init_dma, get_usable, heap::init_heap, init_ram, layout::RuntimeLayout,
page_table::init_pt_frame_allocator,
get_usable, heap::init_heap, init_ram, layout::RuntimeLayout,
page_table::init_pt_frame_allocator, shared::init_shared_memory,
},
};

Expand All @@ -34,8 +34,8 @@ pub fn pre_init(hob: u64, layout: &RuntimeLayout) {
let heap = get_usable(layout.heap_size).expect("Failed to allocate heap");
init_heap(heap, layout.heap_size);

let dma = get_usable(layout.dma_size).expect("Failed to allocate dma");
init_dma(dma, layout.dma_size);
let shared = get_usable(layout.shared_memory_size).expect("Failed to allocate shared memory");
init_shared_memory(shared, layout.shared_memory_size);

// Init Global Descriptor Table and Task State Segment
gdt::init_gdt();
Expand Down
4 changes: 2 additions & 2 deletions td-payload/src/arch/x86_64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

pub mod apic;
pub mod cet;
#[cfg(feature = "tdx")]
pub mod dma;
pub mod gdt;
pub mod guard_page;
pub mod idt;
pub mod init;
pub mod paging;
pub mod serial;
#[cfg(feature = "tdx")]
pub mod shared;
File renamed without changes.
6 changes: 3 additions & 3 deletions td-payload/src/bin/example/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ pub extern "C" fn main() -> ! {
#[cfg(all(feature = "coverage", feature = "tdx"))]
{
const MAX_COVERAGE_DATA_PAGE_COUNT: usize = 0x100;
let mut dma = td_payload::mm::dma::DmaMemory::new(MAX_COVERAGE_DATA_PAGE_COUNT)
.expect("New dma fail.");
let buffer = dma.as_mut_bytes();
let mut shared = td_payload::mm::shared::SharedMemory::new(MAX_COVERAGE_DATA_PAGE_COUNT)
.expect("New shared memory fail.");
let buffer = shared.as_mut_bytes();

let coverage_len = minicov::get_coverage_data_size();
assert!(coverage_len < MAX_COVERAGE_DATA_PAGE_COUNT * td_paging::PAGE_SIZE);
Expand Down
6 changes: 3 additions & 3 deletions td-payload/src/mm/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
pub const DEFAULT_HEAP_SIZE: usize = 0x1000000;
pub const DEFAULT_STACK_SIZE: usize = 0x800000;
pub const DEFAULT_PAGE_TABLE_SIZE: usize = 0x800000;
pub const DEFAULT_DMA_SIZE: usize = 0x100000;
pub const DEFAULT_SHARED_MEMORY_SIZE: usize = 0x100000;
#[cfg(feature = "cet-shstk")]
pub const DEFAULT_SHADOW_STACK_SIZE: usize = 0x10000;

Expand All @@ -14,7 +14,7 @@ pub struct RuntimeLayout {
pub heap_size: usize,
pub stack_size: usize,
pub page_table_size: usize,
pub dma_size: usize,
pub shared_memory_size: usize,
#[cfg(feature = "cet-shstk")]
pub shadow_stack_size: usize,
}
Expand All @@ -25,7 +25,7 @@ impl Default for RuntimeLayout {
heap_size: DEFAULT_HEAP_SIZE,
stack_size: DEFAULT_STACK_SIZE,
page_table_size: DEFAULT_PAGE_TABLE_SIZE,
dma_size: DEFAULT_DMA_SIZE,
shared_memory_size: DEFAULT_SHARED_MEMORY_SIZE,
#[cfg(feature = "cet-shstk")]
shadow_stack_size: DEFAULT_SHADOW_STACK_SIZE,
}
Expand Down
4 changes: 2 additions & 2 deletions td-payload/src/mm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ use zerocopy::FromBytes;

use crate::Error;

#[cfg(feature = "tdx")]
pub mod dma;
#[cfg(any(target_os = "none", target_os = "uefi"))]
pub(crate) mod heap;
#[cfg(feature = "tdx")]
pub mod shared;
#[cfg(not(any(target_os = "none", target_os = "uefi")))]
pub(crate) mod heap {
// A null implementation used by test
Expand Down
42 changes: 21 additions & 21 deletions td-payload/src/mm/dma.rs → td-payload/src/mm/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@ use core::{alloc::Layout, ptr::NonNull};
use linked_list_allocator::LockedHeap;

use super::SIZE_4K;
use crate::arch::dma::decrypt;
use crate::arch::shared::decrypt;

static DMA_ALLOCATOR: LockedHeap = LockedHeap::empty();
static SHARED_MEMORY_ALLOCATOR: LockedHeap = LockedHeap::empty();

pub fn init_dma(start: u64, size: usize) {
// Set the DMA memory region to be shared
pub fn init_shared_memory(start: u64, size: usize) {
// Set the shared memory region to be shared
decrypt(start, size);
// Initialize the DMA allocator
// Initialize the shared memory allocator
unsafe {
DMA_ALLOCATOR.lock().init(start as *mut u8, size);
SHARED_MEMORY_ALLOCATOR.lock().init(start as *mut u8, size);
}
}

pub struct DmaMemory {
pub struct SharedMemory {
addr: usize,
size: usize,
}

impl DmaMemory {
impl SharedMemory {
pub fn new(num_page: usize) -> Option<Self> {
let addr = unsafe { alloc_dma_pages(num_page)? };
let addr = unsafe { alloc_shared_pages(num_page)? };

Some(Self {
addr,
Expand All @@ -43,18 +43,18 @@ impl DmaMemory {
}
}

impl Drop for DmaMemory {
impl Drop for SharedMemory {
fn drop(&mut self) {
unsafe { free_dma_pages(self.addr, self.size / SIZE_4K) }
unsafe { free_shared_pages(self.addr, self.size / SIZE_4K) }
}
}

/// # Safety
/// The caller needs to explicitly call the `free_dma_pages` function after use
pub unsafe fn alloc_dma_pages(num: usize) -> Option<usize> {
/// The caller needs to explicitly call the `free_shared_pages` function after use
pub unsafe fn alloc_shared_pages(num: usize) -> Option<usize> {
let size = SIZE_4K.checked_mul(num)?;

let addr = DMA_ALLOCATOR
let addr = SHARED_MEMORY_ALLOCATOR
.lock()
.allocate_first_fit(Layout::from_size_align(size, SIZE_4K).ok()?)
.map(|ptr| ptr.as_ptr() as usize)
Expand All @@ -66,24 +66,24 @@ pub unsafe fn alloc_dma_pages(num: usize) -> Option<usize> {
}

/// # Safety
/// The caller needs to explicitly call the `free_dma_page` function after use
pub unsafe fn alloc_dma_page() -> Option<usize> {
alloc_dma_pages(1)
/// The caller needs to explicitly call the `free_shared_page` function after use
pub unsafe fn alloc_shared_page() -> Option<usize> {
alloc_shared_pages(1)
}

/// # Safety
/// The caller needs to ensure the correctness of the addr and page num
pub unsafe fn free_dma_pages(addr: usize, num: usize) {
pub unsafe fn free_shared_pages(addr: usize, num: usize) {
let size = SIZE_4K.checked_mul(num).expect("Invalid page num");

DMA_ALLOCATOR.lock().deallocate(
SHARED_MEMORY_ALLOCATOR.lock().deallocate(
NonNull::new(addr as *mut u8).unwrap(),
Layout::from_size_align(size, SIZE_4K).unwrap(),
);
}

/// # Safety
/// The caller needs to ensure the correctness of the addr
pub unsafe fn free_dma_page(addr: usize) {
free_dma_pages(addr, 1)
pub unsafe fn free_shared_page(addr: usize) {
free_shared_pages(addr, 1)
}
10 changes: 5 additions & 5 deletions tests/test-td-payload/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ pub extern "C" fn _start(hob: u64, _payload: u64) -> ! {
heap_size: layout::DEFAULT_HEAP_SIZE,
stack_size: layout::DEFAULT_STACK_SIZE,
page_table_size: PAGE_TABLE_SIZE,
dma_size: layout::DEFAULT_DMA_SIZE,
shared_memory_size: layout::DEFAULT_SHARED_MEMORY_SIZE,
shadow_stack_size: layout::DEFAULT_SHADOW_STACK_SIZE,
};

Expand Down Expand Up @@ -274,13 +274,13 @@ extern "C" fn main() -> ! {
ts.failed_cases
);

// Need to set DEFAULT_DMA_SIZE to 0x200000 before build
// Need to set DEFAULT_SHARED_MEMORY_SIZE to 0x200000 before build
#[cfg(all(feature = "coverage", feature = "tdx"))]
{
const MAX_COVERAGE_DATA_PAGE_COUNT: usize = 0x200;
let mut dma = td_payload::mm::dma::DmaMemory::new(MAX_COVERAGE_DATA_PAGE_COUNT)
.expect("New dma fail.");
let buffer = dma.as_mut_bytes();
let mut shared = td_payload::mm::shared::SharedMemory::new(MAX_COVERAGE_DATA_PAGE_COUNT)
.expect("New shared memory fail.");
let buffer = shared.as_mut_bytes();
let coverage_len = minicov::get_coverage_data_size();
assert!(coverage_len < MAX_COVERAGE_DATA_PAGE_COUNT * td_paging::PAGE_SIZE);
minicov::capture_coverage_to_buffer(&mut buffer[0..coverage_len]);
Expand Down

0 comments on commit 0001b34

Please sign in to comment.