diff --git a/Cargo.lock b/Cargo.lock index 04fd2b4c..ee24cf2f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4469,6 +4469,7 @@ version = "0.0.0" dependencies = [ "schemars", "serde", + "serde_json", "serde_test", ] diff --git a/crates/propolis-types/Cargo.toml b/crates/propolis-types/Cargo.toml index a963a5bb..3a24d8ad 100644 --- a/crates/propolis-types/Cargo.toml +++ b/crates/propolis-types/Cargo.toml @@ -5,7 +5,6 @@ license = "MPL-2.0" edition = "2021" [lib] -test = false doctest = false [dependencies] @@ -13,4 +12,5 @@ schemars = { workspace = true, features = [ "uuid1" ] } serde.workspace = true [dev-dependencies] +serde_json.workspace = true serde_test.workspace = true diff --git a/crates/propolis-types/src/lib.rs b/crates/propolis-types/src/lib.rs index 76831255..5325d565 100644 --- a/crates/propolis-types/src/lib.rs +++ b/crates/propolis-types/src/lib.rs @@ -20,17 +20,11 @@ const PCI_DEVICES_PER_BUS: u8 = 32; const PCI_FUNCTIONS_PER_DEVICE: u8 = 8; /// A PCI bus/device/function tuple. +// +// N.B. Field names here should be kept in sync with the helper struct in the +// Deserialize impl below. #[derive( - Clone, - Copy, - PartialEq, - Eq, - PartialOrd, - Ord, - Debug, - JsonSchema, - Serialize, - Deserialize, + Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, JsonSchema, Serialize, )] pub struct PciPath { bus: u8, @@ -119,27 +113,48 @@ impl Display for PciPath { } } +impl<'de> Deserialize<'de> for PciPath { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + // N.B. The field names here should be kept in sync with the actual + // PciPath structure above. + #[derive(Deserialize)] + struct Raw { + bus: u8, + device: u8, + function: u8, + } + + let raw = Raw::deserialize(deserializer)?; + + Self::new(raw.bus, raw.device, raw.function) + .map_err(|e| serde::de::Error::custom(e.to_string())) + } +} + #[cfg(test)] mod test { use super::PciPath; use std::str::FromStr; - const TEST_CASES: &[(&str, Result)] = &[ - ("0.7.0", Ok(PciPath { bus: 0, device: 7, function: 0 })), - ("1.2.3", Ok(PciPath { bus: 1, device: 2, function: 3 })), - ("0.40.0", Err(())), - ("0.1.9", Err(())), - ("255.254.253", Err(())), - ("1000.0.0", Err(())), - ("4/3/4", Err(())), - ("a.b.c", Err(())), - ("1.5#4", Err(())), - ("", Err(())), - ("alas, poor PCI device", Err(())), - ]; - #[test] fn pci_path_from_str() { + const TEST_CASES: &[(&str, Result)] = &[ + ("0.7.0", Ok(PciPath { bus: 0, device: 7, function: 0 })), + ("1.2.3", Ok(PciPath { bus: 1, device: 2, function: 3 })), + ("0.40.0", Err(())), + ("0.1.9", Err(())), + ("255.254.253", Err(())), + ("1000.0.0", Err(())), + ("4/3/4", Err(())), + ("a.b.c", Err(())), + ("1.5#4", Err(())), + ("", Err(())), + ("alas, poor PCI device", Err(())), + ]; + for (input, expected) in TEST_CASES { match PciPath::from_str(input) { Ok(path) => assert_eq!(path, expected.unwrap()), @@ -151,6 +166,65 @@ mod test { } } } + + fn check_pci_path_deserialization( + input: &str, + expected: Result, + ) { + let actual = serde_json::from_str::(input); + match (actual, expected) { + (Ok(parsed), Ok(expected)) => assert_eq!(parsed, expected), + (Ok(_), Err(_)) => { + panic!("expected to fail to deserialize input: {input}") + } + (Err(e), Ok(_)) => { + panic!("failed to deserialize input {input}: {e}") + } + (Err(_), Err(_)) => {} + } + } + + #[test] + fn pci_path_deserialization() { + const TEST_CASES: &[(&str, Result)] = &[ + ( + r#"{"bus": 0, "device": 7, "function": 0}"#, + Ok(PciPath { bus: 0, device: 7, function: 0 }), + ), + ( + r#"{"bus": 1, "device": 2, "function": 3}"#, + Ok(PciPath { bus: 1, device: 2, function: 3 }), + ), + (r#"{"bus": 0, "device": 40, "function": 0}"#, Err(())), + (r#"{"bus": 0, "device": 1, "function": 9}"#, Err(())), + ]; + + for (input, expected) in TEST_CASES { + check_pci_path_deserialization(input, *expected); + } + } + + // This test is expensive, so don't run it by default. + #[test] + #[ignore] + fn pci_path_deserialization_exhaustive() { + for bus in 0..=255 { + for device in 0..=255 { + for function in 0..=255 { + let expected = PciPath::new(bus, device, function); + let json = format!( + "{{\ + \"bus\": {bus},\ + \"device\": {device},\ + \"function\": {function}\ + }}" + ); + + check_pci_path_deserialization(&json, expected); + } + } + } + } } /// A CPUID leaf/subleaf (function/index) specifier.