Skip to content

Commit

Permalink
manually impl Deserialize for PciPath for validation purposes (#801)
Browse files Browse the repository at this point in the history
This ensures that PciPaths generated via deserialization obey all the same
constraints as PciPaths generated by PciPath::new.
  • Loading branch information
gjcolombo authored Oct 25, 2024
1 parent 9ad7368 commit 7627cff
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 25 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/propolis-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ license = "MPL-2.0"
edition = "2021"

[lib]
test = false
doctest = false

[dependencies]
schemars = { workspace = true, features = [ "uuid1" ] }
serde.workspace = true

[dev-dependencies]
serde_json.workspace = true
serde_test.workspace = true
122 changes: 98 additions & 24 deletions crates/propolis-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,11 @@ const PCI_DEVICES_PER_BUS: u8 = 32;
const PCI_FUNCTIONS_PER_DEVICE: u8 = 8;

/// A PCI bus/device/function tuple.
//
// N.B. Field names here should be kept in sync with the helper struct in the
// Deserialize impl below.
#[derive(
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Debug,
JsonSchema,
Serialize,
Deserialize,
Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, JsonSchema, Serialize,
)]
pub struct PciPath {
bus: u8,
Expand Down Expand Up @@ -119,27 +113,48 @@ impl Display for PciPath {
}
}

impl<'de> Deserialize<'de> for PciPath {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
// N.B. The field names here should be kept in sync with the actual
// PciPath structure above.
#[derive(Deserialize)]
struct Raw {
bus: u8,
device: u8,
function: u8,
}

let raw = Raw::deserialize(deserializer)?;

Self::new(raw.bus, raw.device, raw.function)
.map_err(|e| serde::de::Error::custom(e.to_string()))
}
}

#[cfg(test)]
mod test {
use super::PciPath;
use std::str::FromStr;

const TEST_CASES: &[(&str, Result<PciPath, ()>)] = &[
("0.7.0", Ok(PciPath { bus: 0, device: 7, function: 0 })),
("1.2.3", Ok(PciPath { bus: 1, device: 2, function: 3 })),
("0.40.0", Err(())),
("0.1.9", Err(())),
("255.254.253", Err(())),
("1000.0.0", Err(())),
("4/3/4", Err(())),
("a.b.c", Err(())),
("1.5#4", Err(())),
("", Err(())),
("alas, poor PCI device", Err(())),
];

#[test]
fn pci_path_from_str() {
const TEST_CASES: &[(&str, Result<PciPath, ()>)] = &[
("0.7.0", Ok(PciPath { bus: 0, device: 7, function: 0 })),
("1.2.3", Ok(PciPath { bus: 1, device: 2, function: 3 })),
("0.40.0", Err(())),
("0.1.9", Err(())),
("255.254.253", Err(())),
("1000.0.0", Err(())),
("4/3/4", Err(())),
("a.b.c", Err(())),
("1.5#4", Err(())),
("", Err(())),
("alas, poor PCI device", Err(())),
];

for (input, expected) in TEST_CASES {
match PciPath::from_str(input) {
Ok(path) => assert_eq!(path, expected.unwrap()),
Expand All @@ -151,6 +166,65 @@ mod test {
}
}
}

fn check_pci_path_deserialization<E>(
input: &str,
expected: Result<PciPath, E>,
) {
let actual = serde_json::from_str::<PciPath>(input);
match (actual, expected) {
(Ok(parsed), Ok(expected)) => assert_eq!(parsed, expected),
(Ok(_), Err(_)) => {
panic!("expected to fail to deserialize input: {input}")
}
(Err(e), Ok(_)) => {
panic!("failed to deserialize input {input}: {e}")
}
(Err(_), Err(_)) => {}
}
}

#[test]
fn pci_path_deserialization() {
const TEST_CASES: &[(&str, Result<PciPath, ()>)] = &[
(
r#"{"bus": 0, "device": 7, "function": 0}"#,
Ok(PciPath { bus: 0, device: 7, function: 0 }),
),
(
r#"{"bus": 1, "device": 2, "function": 3}"#,
Ok(PciPath { bus: 1, device: 2, function: 3 }),
),
(r#"{"bus": 0, "device": 40, "function": 0}"#, Err(())),
(r#"{"bus": 0, "device": 1, "function": 9}"#, Err(())),
];

for (input, expected) in TEST_CASES {
check_pci_path_deserialization(input, *expected);
}
}

// This test is expensive, so don't run it by default.
#[test]
#[ignore]
fn pci_path_deserialization_exhaustive() {
for bus in 0..=255 {
for device in 0..=255 {
for function in 0..=255 {
let expected = PciPath::new(bus, device, function);
let json = format!(
"{{\
\"bus\": {bus},\
\"device\": {device},\
\"function\": {function}\
}}"
);

check_pci_path_deserialization(&json, expected);
}
}
}
}
}

/// A CPUID leaf/subleaf (function/index) specifier.
Expand Down

0 comments on commit 7627cff

Please sign in to comment.