Ensure page tables are properly pinned into memory
When we load the page tables into memory, we need to guarantee that they
stay valid at the same address, as the processor cares little about Rust
and its memory model.

The Rust way to do it is to `Pin` the pointers into memory. This
guarantees that the value will not move. Note that _overwriting_ the
value is safe: instead of `zero()` we now call `Pin::set`, which
overwrites the contents in-place, guaranteeing that the contents of the
memory at that address are always a valid page table.
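
A rough sketch of the pattern (with a simplified, hypothetical table type, not
the actual stage0 code): wrap a statically allocated table in `Pin` via
`Pin::static_mut`, and reset it through `Pin::set` so the write happens in
place at a fixed address.

use core::{pin::Pin, ptr::addr_of_mut};

struct PageTable([u64; 512]);

impl PageTable {
    const fn new() -> Self {
        PageTable([0; 512])
    }
}

static mut TABLE: PageTable = PageTable::new();

fn pinned_table() -> Pin<&'static mut PageTable> {
    // Safety (sketch assumption): called exactly once, so no aliasing mutable
    // references to the static are ever created.
    Pin::static_mut(unsafe { &mut *addr_of_mut!(TABLE) })
}

fn reset(table: &mut Pin<&'static mut PageTable>) {
    // Overwrites the pointee in place: the address never changes, and the
    // memory always holds a valid `PageTable`.
    table.set(PageTable::new());
}

fn main() {
    let mut table = pinned_table();
    reset(&mut table);
}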

We do need to resort to unsafe code to get a mutable reference for
setting the page table entries, though. This is because Rust can't
guarantee we don't call `mem::replace` on it.
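
As a minimal illustration (again with a stand-in table type, made !Unpin here
purely so the sketch actually enforces the pin contract), the mutable access
has to go through the unsafe `Pin::get_unchecked_mut`:

use core::{
    marker::PhantomPinned,
    pin::{pin, Pin},
};

struct PageTable {
    entries: [u64; 512],
    // Opt out of `Unpin` so the safe `Pin::get_mut` escape hatch is not
    // available and the pin contract applies.
    _pin: PhantomPinned,
}

fn set_entry(table: Pin<&mut PageTable>, index: usize, value: u64) {
    // Safety: we only write an entry in place; the `PageTable` is never moved
    // out from behind the pinned reference (no `mem::replace`).
    let inner = unsafe { table.get_unchecked_mut() };
    inner.entries[index] = value;
}

fn main() {
    // `pin!` pins a value to the local stack frame, as in the TDX changes below.
    let table = pin!(PageTable { entries: [0; 512], _pin: PhantomPinned });
    set_entry(table, 1, 0x1000);
}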

This fixes one issue around the page tables; there's also the question
of lifetimes themselves (we don't want a `Box` containing a page table
to go out of scope while it's still referenced by other page tables),
but fixing that will be a separate exercise.

Bug: 377899703
Change-Id: I10559eda2c8a60f77201c3cc3cedabdac864f26f
andrisaar committed Nov 8, 2024
1 parent bf97164 commit be2f057
Showing 3 changed files with 35 additions and 25 deletions.
25 changes: 13 additions & 12 deletions stage0/src/paging.rs
@@ -18,6 +18,7 @@ use alloc::boxed::Box;
use core::{
marker::PhantomData,
ops::{Index, IndexMut},
+pin::Pin,
ptr::addr_of_mut,
};

@@ -49,21 +50,21 @@ pub static mut PD_3: PageTable<page_table_level::PD> = PageTable::new();
pub struct PageTableRefs {
/// The root page-map level 4 table coverting virtual memory ranges
/// 0..128TiB and (16EiB-128TiB)..16EiB.
-pub pml4: &'static mut PageTable<page_table_level::PML4>,
+pub pml4: Pin<&'static mut PageTable<page_table_level::PML4>>,

/// The page-directory pointer table covering virtual memory range
/// 0..512GiB.
-pub pdpt: &'static mut PageTable<page_table_level::PDPT>,
+pub pdpt: Pin<&'static mut PageTable<page_table_level::PDPT>>,

/// The page directory covering virtual memory range 0..1GiB.
-pub pd_0: &'static mut PageTable<page_table_level::PD>,
+pub pd_0: Pin<&'static mut PageTable<page_table_level::PD>>,

/// The page directory covering virtual memory range 3..4GiB.
-pub pd_3: &'static mut PageTable<page_table_level::PD>,
+pub pd_3: Pin<&'static mut PageTable<page_table_level::PD>>,

/// The page table covering virtual memory range 0..2MiB where we want 4KiB
/// pages.
-pub pt_0: Box<PageTable<page_table_level::PT>, &'static BootAllocator>,
+pub pt_0: Pin<Box<PageTable<page_table_level::PT>, &'static BootAllocator>>,
}

/// References to all the pages tables we care about.
@@ -206,10 +207,10 @@ where
/// TDX.
pub fn set_lower_level_table<P: Platform>(
&mut self,
-pdpt: &PageTable<Ln>,
+pt: Pin<&PageTable<Ln>>,
flags: PageTableFlags,
) {
-self.inner.set_addr(PhysAddr::new(pdpt as *const PageTable<Ln> as u64), flags)
+self.inner.set_addr(PhysAddr::new(pt.get_ref() as *const _ as u64), flags)
}
}

@@ -306,16 +307,16 @@ pub fn init_page_table_refs<P: Platform>() {
// Safety: accessing the mutable statics here is safe since we only do it once
// and protect the mutable references with a mutex. This function can only
// be called once, since updating `PAGE_TABLE_REFS` twice will panic.
-let pml4: &mut PageTable<page_table_level::PML4> = unsafe { &mut *addr_of_mut!(PML4) };
-let pdpt: &mut PageTable<page_table_level::PDPT> = unsafe { &mut *addr_of_mut!(PDPT) };
-let pd_0: &mut PageTable<page_table_level::PD> = unsafe { &mut *addr_of_mut!(PD_0) };
-let pd_3: &mut PageTable<page_table_level::PD> = unsafe { &mut *addr_of_mut!(PD_3) };
+let pml4 = Pin::static_mut(unsafe { &mut *addr_of_mut!(PML4) });
+let pdpt = Pin::static_mut(unsafe { &mut *addr_of_mut!(PDPT) });
+let mut pd_0 = Pin::static_mut(unsafe { &mut *addr_of_mut!(PD_0) });
+let pd_3 = Pin::static_mut(unsafe { &mut *addr_of_mut!(PD_3) });

// Set up a new page table that maps the first 2MiB as 4KiB pages (except for
// the lower 4KiB), so that we can share individual 4KiB pages with the
// hypervisor as needed. We are using an identity mapping between virtual
// and physical addresses.
-let mut pt_0 = Box::new_in(PageTable::new(), &BOOT_ALLOC);
+let mut pt_0 = Box::pin_in(PageTable::new(), &BOOT_ALLOC);
// Let entry 1 map to 4KiB, entry 2 to 8KiB, ... , entry 511 to 2MiB-4KiB:
// We leave [0,4K) unmapped to make sure null pointer dereferences crash
// with a page fault.
11 changes: 7 additions & 4 deletions stage0_bin_tdx/src/main.rs
@@ -25,6 +25,7 @@ use core::{
mem::{size_of, MaybeUninit},
ops::{Index, IndexMut},
panic::PanicInfo,
+pin::pin,
ptr::addr_of,
sync::atomic::Ordering,
};
@@ -384,20 +385,22 @@ impl oak_stage0::Platform for Tdx {

info!("starting TDX memory acceptance");
let mut page_tables = paging::PAGE_TABLE_REFS.get().unwrap().lock();
-let accept_pd_pt = oak_stage0::paging::PageTable::new();
+let accept_pd_pt = pin!(oak_stage0::paging::PageTable::new());
if page_tables.pdpt[1].flags().contains(PageTableFlags::PRESENT) {
panic!("PDPT[1] is in use");
}

-page_tables.pdpt[1].set_lower_level_table::<Tdx>(&accept_pd_pt, PageTableFlags::PRESENT);
+page_tables.pdpt[1]
+.set_lower_level_table::<Tdx>(accept_pd_pt.as_ref(), PageTableFlags::PRESENT);
info!("added pdpt[1]");

info!("adding pd_0[1]");
-let accept_pt_pt = oak_stage0::paging::PageTable::new();
+let accept_pt_pt = pin!(oak_stage0::paging::PageTable::new());
if page_tables.pd_0[1].flags().contains(PageTableFlags::PRESENT) {
panic!("PD_0[1] is in use");
}
-page_tables.pd_0[1].set_lower_level_table::<Tdx>(&accept_pt_pt, PageTableFlags::PRESENT);
+page_tables.pd_0[1]
+.set_lower_level_table::<Tdx>(accept_pt_pt.as_ref(), PageTableFlags::PRESENT);
info!("added pd_0[1]");

let min_addr = 0xA0000;
24 changes: 15 additions & 9 deletions stage0_sev/src/platform/accept_memory.rs
@@ -13,7 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-use core::sync::atomic::{AtomicUsize, Ordering};
+use alloc::boxed::Box;
+use core::{
+pin::Pin,
+sync::atomic::{AtomicUsize, Ordering},
+};

use oak_linux_boot_params::{BootE820Entry, E820EntryType};
use oak_sev_guest::{
@@ -72,13 +76,13 @@ impl<S: PageSize + ValidatablePageSize> Validate<S> for Page<S> {
/// mappings for that page.
struct MappedPage<L: Leaf> {
pub page: Page<L::Size>,
-pub page_table: PageTable<L>,
+pub page_table: Pin<Box<PageTable<L>>>,
}

impl<L: Leaf> MappedPage<L> {
pub fn new(vaddr: VirtAddr) -> Result<Self, AddressNotAligned> {
let mapped_page =
-Self { page: Page::from_start_address(vaddr)?, page_table: PageTable::new() };
+Self { page: Page::from_start_address(vaddr)?, page_table: Box::pin(PageTable::new()) };
Ok(mapped_page)
}
}
@@ -107,9 +111,11 @@ where
// directory itself, and for each entry, fetch the next entry in the range.
// If we've covered everything, `range.next()` will return `None`, and
// the count will be zero once we've covered everything.
-memory.page_table.zero();
-while memory
-.page_table
+memory.page_table.set(PageTable::new());
+
+// Safety: the call to `get_unchecked_mut` is safe as we will _not_ move the
+// value out of `Pin`.
+while unsafe { memory.page_table.as_mut().get_unchecked_mut() }
.iter_mut()
.filter_map(|entry| range.next().map(|frame| (entry, frame)))
.map(|(entry, frame)| {
@@ -139,7 +145,7 @@
}

// Clear out the page table, ready for the next iteration.
-memory.page_table.zero();
+memory.page_table.set(PageTable::new());
}

Ok(())
@@ -272,7 +278,7 @@ pub fn validate_memory(e820_table: &[BootE820Entry]) {
panic!("PDPT[1] is in use");
}
page_tables.pdpt[1]
-.set_lower_level_table::<Sev>(&validation_pd.page_table, PageTableFlags::PRESENT);
+.set_lower_level_table::<Sev>(validation_pd.page_table.as_ref(), PageTableFlags::PRESENT);

// Page table, for validation with 4 KiB pages.
let mut validation_pt = MappedPage::new(VirtAddr::new(Size2MiB::SIZE)).unwrap();
@@ -283,7 +289,7 @@
panic!("PD_0[1] is in use");
}
page_tables.pd_0[1]
-.set_lower_level_table::<Sev>(&validation_pt.page_table, PageTableFlags::PRESENT);
+.set_lower_level_table::<Sev>(validation_pt.page_table.as_ref(), PageTableFlags::PRESENT);

// We already pvalidated the memory in the first 640KiB of RAM in the boot
// assembly code. We avoid redoing this as calling pvalidate again on these