From 646f5b9427302891efdb37da6df5672294d5892e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sosth=C3=A8ne=20Gu=C3=A9don?= Date: Fri, 17 Feb 2023 16:56:50 +0100 Subject: [PATCH 1/8] Run cargo fmt --- src/de.rs | 82 +++++++++++++++++++++++----------------------------- src/error.rs | 4 +-- src/lib.rs | 10 ++----- src/ser.rs | 37 ++++++++++-------------- 4 files changed, 55 insertions(+), 78 deletions(-) diff --git a/src/de.rs b/src/de.rs index 78fece31..40600b16 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1,8 +1,6 @@ use serde::Deserialize; -use serde::de::{ - IntoDeserializer, -}; +use serde::de::IntoDeserializer; use super::error::{Error, Result}; @@ -36,11 +34,7 @@ where use core::convert::TryInto; -use serde::de::{ - self, - DeserializeSeed, - Visitor, -}; +use serde::de::{self, DeserializeSeed, Visitor}; /// A structure for deserializing a cbor-smol message. pub struct Deserializer<'de> { @@ -104,24 +98,20 @@ impl<'de> Deserializer<'de> { } // TODO: name something like "one-byte-integer" - fn raw_deserialize_u8(&mut self, major: u8) -> Result - { + fn raw_deserialize_u8(&mut self, major: u8) -> Result { let additional = self.expect_major(major)?; match additional { byte @ 0..=23 => Ok(byte), - 24 => { - match self.try_take_n(1)?[0] { - 0..=23 => Err(Error::DeserializeNonMinimal), - byte => Ok(byte), - } + 24 => match self.try_take_n(1)?[0] { + 0..=23 => Err(Error::DeserializeNonMinimal), + byte => Ok(byte), }, _ => Err(Error::DeserializeBadU8), } } - fn raw_deserialize_u16(&mut self, major: u8) -> Result - { + fn raw_deserialize_u16(&mut self, major: u8) -> Result { let number = self.raw_deserialize_u32(major)?; if number <= u16::max_value() as u32 { Ok(number as u16) @@ -130,38 +120,37 @@ impl<'de> Deserializer<'de> { } } - fn raw_deserialize_u32(&mut self, major: u8) -> Result - { + fn raw_deserialize_u32(&mut self, major: u8) -> Result { let additional = self.expect_major(major)?; match additional { byte @ 0..=23 => Ok(byte as u32), - 24 => { - match self.try_take_n(1)?[0] { - 0..=23 => Err(Error::DeserializeNonMinimal), - byte => Ok(byte as u32), - } + 24 => match self.try_take_n(1)?[0] { + 0..=23 => Err(Error::DeserializeNonMinimal), + byte => Ok(byte as u32), }, 25 => { let unsigned = u16::from_be_bytes( self.try_take_n(2)? - .try_into().map_err(|_| Error::InexistentSliceToArrayError)? + .try_into() + .map_err(|_| Error::InexistentSliceToArrayError)?, ); match unsigned { 0..=255 => Err(Error::DeserializeNonMinimal), unsigned => Ok(unsigned as u32), } - }, + } 26 => { let unsigned = u32::from_be_bytes( self.try_take_n(4)? - .try_into().map_err(|_| Error::InexistentSliceToArrayError)? + .try_into() + .map_err(|_| Error::InexistentSliceToArrayError)?, ); match unsigned { 0..=65535 => Err(Error::DeserializeNonMinimal), unsigned => Ok(unsigned as u32), } - }, + } _ => Err(Error::DeserializeBadU32), } } @@ -195,7 +184,7 @@ impl<'a, 'b: 'a> serde::de::SeqAccess<'b> for SeqAccess<'a, 'b> { fn next_element_seed(&mut self, seed: V) -> Result> where - V: DeserializeSeed<'b> + V: DeserializeSeed<'b>, { if self.len > 0 { self.len -= 1; @@ -220,7 +209,7 @@ impl<'a, 'b: 'a> serde::de::MapAccess<'b> for MapAccess<'a, 'b> { fn next_key_seed(&mut self, seed: V) -> Result> where - V: DeserializeSeed<'b> + V: DeserializeSeed<'b>, { if self.len > 0 { self.len -= 1; @@ -317,7 +306,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } else { Err(Error::DeserializeBadI8) } - }, + } 1 => { let raw_u8 = self.raw_deserialize_u8(1)?; // if raw_u8 <= 1 + i8::max_value() as u8 { @@ -326,7 +315,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } else { Err(Error::DeserializeBadI8) } - }, + } _ => Err(Error::DeserializeBadI8), } } @@ -343,7 +332,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } else { Err(Error::DeserializeBadI16) } - }, + } 1 => { let raw = self.raw_deserialize_u16(1)?; if raw <= i16::max_value() as u16 { @@ -351,7 +340,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } else { Err(Error::DeserializeBadI16) } - }, + } _ => Err(Error::DeserializeBadI16), } } @@ -373,7 +362,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } else { Err(Error::DeserializeBadI32) } - }, + } _ => Err(Error::DeserializeBadI16), } } @@ -503,7 +492,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { self.consume()?; visitor.visit_unit() } - _ => Err(Error::DeserializeExpectedNull) + _ => Err(Error::DeserializeExpectedNull), } } @@ -596,7 +585,6 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { // todo!("implement `deserialize_enum`"); // } - // fn parse_enum(&mut self, mut len: usize, visitor: V) -> Result // where // V: de::Visitor<'de>, @@ -791,7 +779,7 @@ mod tests { // use crate::serde::{cbor_serialize, cbor_serialize2, cbor_deserialize}; // use crate::serde::{cbor_serialize, cbor_serialize_old, cbor_deserialize}; - use crate::{cbor_serialize, cbor_deserialize}; + use crate::{cbor_deserialize, cbor_serialize}; #[test] fn de_bool() { @@ -829,7 +817,6 @@ mod tests { } } - #[test] fn de_u16() { let mut buf = [0u8; 64]; @@ -858,7 +845,7 @@ mod tests { fn de_u32() { let mut buf = [0u8; 64]; - for number in 0..=3*(u16::max_value() as u32) { + for number in 0..=3 * (u16::max_value() as u32) { println!("testing {}", number); let _n = cbor_serialize(&number, &mut buf).unwrap(); let de: u32 = from_bytes(&buf).unwrap(); @@ -883,7 +870,7 @@ mod tests { let de: i32 = from_bytes(ser).unwrap(); assert_eq!(de, number); - for number in (3*i16::min_value() as i32)..=3*(i16::max_value() as i32) { + for number in (3 * i16::min_value() as i32)..=3 * (i16::max_value() as i32) { println!("testing {}", number); let ser = cbor_serialize(&number, &mut buf).unwrap(); let de: i32 = from_bytes(ser).unwrap(); @@ -913,7 +900,7 @@ mod tests { let bytes = crate::Bytes::<64>::from_slice(slice).unwrap(); let ser = cbor_serialize(&bytes, &mut buf).unwrap(); println!("serialized bytes = {:?}", ser); - let de: crate::Bytes::<64> = from_bytes(&buf).unwrap(); + let de: crate::Bytes<64> = from_bytes(&buf).unwrap(); println!("deserialized bytes = {:?}", &de); assert_eq!(&de, slice); } @@ -976,7 +963,6 @@ mod tests { #[test] fn de_enum() { - let mut buf = [0u8; 64]; let e = Some(3); let ser = cbor_serialize(&e, &mut buf).unwrap(); @@ -984,7 +970,11 @@ mod tests { let de: Option = cbor_deserialize(ser).unwrap(); assert_eq!(de, e); let e: Option = None; - println!("ser({:?}) = {:x?}", &e, cbor_serialize(&e, &mut buf).unwrap()); + println!( + "ser({:?}) = {:x?}", + &e, + cbor_serialize(&e, &mut buf).unwrap() + ); // let mut buf = [0u8; 64]; // let _n = cbor_serialize(&None, &mut buf).unwrap(); @@ -992,7 +982,7 @@ mod tests { // use serde_indexed::{DeserializeIndexed, SerializeIndexed}; use serde::{Deserialize, Serialize}; - #[derive(Clone,Debug,Eq,PartialEq,Serialize,Deserialize)] + #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub enum Enum { Alpha(u8), // Beta((i32, u32)), @@ -1008,7 +998,7 @@ mod tests { let de: Enum = cbor_deserialize(ser).unwrap(); assert_eq!(de, e); - #[derive(Clone,Debug,Eq,PartialEq,Serialize,Deserialize)] + #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub enum SimpleEnum { // Alpha(u8), Alpha(u8), diff --git a/src/error.rs b/src/error.rs index a57f61cc..6fe8b575 100644 --- a/src/error.rs +++ b/src/error.rs @@ -98,7 +98,7 @@ impl Display for Error { DeserializeNonMinimal => "Value may be valid, but not encoded in minimal way", SerdeSerCustom => "Serde Serialization Error", SerdeDeCustom => "Serde Deserialization Error", - SerdeMissingField => "Serde Missing Required Field" + SerdeMissingField => "Serde Missing Required Field", } ) } @@ -132,7 +132,7 @@ impl serde::de::Error for Error { // // `invalid length 297, expected a sequence` // - info_now!("deser error: {}",&msg); + info_now!("deser error: {}", &msg); Error::SerdeDeCustom } fn missing_field(field: &'static str) -> Self { diff --git a/src/lib.rs b/src/lib.rs index 39820ed6..f8dbe00e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,8 +7,8 @@ generate_macros!(); pub use heapless_bytes::Bytes; pub mod de; -pub mod ser; pub mod error; +pub mod ser; pub use error::{Error, Result}; @@ -31,7 +31,6 @@ pub fn cbor_serialize<'a, 'b, T: serde::Serialize>( Ok(&buffer[..size]) } - /// Append serialization of object to existing bytes, returning length of serialized object. pub fn cbor_serialize_extending_bytes<'a, 'b, T: serde::Serialize, const N: usize>( object: &'a T, @@ -45,7 +44,6 @@ pub fn cbor_serialize_extending_bytes<'a, 'b, T: serde::Serialize, const N: usiz Ok(ser.into_inner().len() - len_before) } - /// Serialize object into newly allocated Bytes. pub fn cbor_serialize_bytes(object: &T) -> Result> { let mut data = Bytes::::new(); @@ -53,11 +51,7 @@ pub fn cbor_serialize_bytes(object: &T) -> Ok(data) } - -pub fn cbor_deserialize<'de, T: serde::Deserialize<'de>>( - buffer: &'de [u8], -) -> Result { +pub fn cbor_deserialize<'de, T: serde::Deserialize<'de>>(buffer: &'de [u8]) -> Result { // cortex_m_semihosting::hprintln!("deserializing {:?}", buffer).ok(); de::from_bytes(buffer) } - diff --git a/src/ser.rs b/src/ser.rs index ce7ab461..0eb36f76 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -1,6 +1,6 @@ -use serde::Serialize; -use serde::ser; use super::error::{Error, Result}; +use serde::ser; +use serde::Serialize; // pub fn to_slice<'a, 'b, T>(value: &'a T, buf: &'b mut [u8]) -> Result<&'b mut [u8]> // where @@ -55,13 +55,12 @@ impl<'a> Writer for SliceWriter<'a> { } } -impl<'a, const N: usize> Writer for &'a mut crate::Bytes -{ +impl<'a, const N: usize> Writer for &'a mut crate::Bytes { type Error = Error; fn write_all(&mut self, buf: &[u8]) -> Result<()> { - self.extend_from_slice(buf).map_err( - |_| Error::SerializeBufferFull(buf.len())) + self.extend_from_slice(buf) + .map_err(|_| Error::SerializeBufferFull(buf.len())) } } @@ -73,7 +72,6 @@ pub struct Serializer } impl Serializer { - #[inline] pub fn new(writer: W) -> Self { Serializer { @@ -181,8 +179,7 @@ where type SerializeTupleVariant = &'a mut Serializer; type SerializeMap = CollectionSerializer<'a, W>; type SerializeStruct = &'a mut Serializer; - type SerializeStructVariant= &'a mut Serializer; - + type SerializeStructVariant = &'a mut Serializer; #[inline] fn serialize_bool(self, value: bool) -> Result<()> { @@ -302,18 +299,14 @@ where _variant: &'static str, ) -> Result<()> { // if self.packed { - self.serialize_u32(variant_index) + self.serialize_u32(variant_index) // } else { // self.serialize_str(variant) // } } #[inline] - fn serialize_newtype_struct( - self, - _name: &'static str, - value: &T, - ) -> Result<()> + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> where T: ?Sized + ser::Serialize, { @@ -340,8 +333,8 @@ where // self.write_u64(5, 1u64)?; // variant.serialize(&mut *self)?; // } else { - self.writer.write_all(&[4 << 5 | 2]).map_err(|e| e.into())?; - self.serialize_unit_variant(name, variant_index, variant)?; + self.writer.write_all(&[4 << 5 | 2]).map_err(|e| e.into())?; + self.serialize_unit_variant(name, variant_index, variant)?; // } value.serialize(self) } @@ -379,9 +372,9 @@ where // variant.serialize(&mut *self)?; // self.serialize_tuple(len) // } else { - self.write_u64(4, (len + 1) as u64)?; - self.serialize_unit_variant(name, variant_index, variant)?; - Ok(self) + self.write_u64(4, (len + 1) as u64)?; + self.serialize_unit_variant(name, variant_index, variant)?; + Ok(self) // } } @@ -420,7 +413,7 @@ where // if self.enum_as_map { // self.write_u64(5, 1u64)?; // } else { - self.writer.write_all(&[4 << 5 | 2]).map_err(|e| e.into())?; + self.writer.write_all(&[4 << 5 | 2]).map_err(|e| e.into())?; // } self.serialize_unit_variant(name, variant_index, variant)?; self.serialize_struct(name, len) @@ -551,7 +544,7 @@ where } } -impl<'a, W> ser::SerializeStructVariant for &'a mut Serializer +impl<'a, W> ser::SerializeStructVariant for &'a mut Serializer where W: Writer, { From 0b2dae7f786a5e8b6e27ae7169b78eccf616ae3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sosth=C3=A8ne=20Gu=C3=A9don?= Date: Fri, 17 Feb 2023 16:57:08 +0100 Subject: [PATCH 2/8] Fix enum struct variant deserialization bug --- src/de.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/de.rs b/src/de.rs index 40600b16..5ee35224 100644 --- a/src/de.rs +++ b/src/de.rs @@ -248,10 +248,10 @@ impl<'de, 'a> serde::de::VariantAccess<'de> for &'a mut Deserializer<'de> { fn struct_variant>( self, - fields: &'static [&'static str], + _fields: &'static [&'static str], visitor: V, ) -> Result { - serde::de::Deserializer::deserialize_tuple(self, fields.len(), visitor) + serde::de::Deserializer::deserialize_map(self, visitor) } } From 9de18c3268bd8894490529b9cf14d443d6a55da1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sosth=C3=A8ne=20Gu=C3=A9don?= Date: Fri, 17 Feb 2023 17:01:31 +0100 Subject: [PATCH 3/8] Add basic fuzzing --- fuzz/.gitignore | 4 ++ fuzz/Cargo.toml | 31 +++++++++++ fuzz/fuzz_targets/fuzz_target_1.rs | 86 ++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+) create mode 100644 fuzz/.gitignore create mode 100644 fuzz/Cargo.toml create mode 100644 fuzz/fuzz_targets/fuzz_target_1.rs diff --git a/fuzz/.gitignore b/fuzz/.gitignore new file mode 100644 index 00000000..1a45eee7 --- /dev/null +++ b/fuzz/.gitignore @@ -0,0 +1,4 @@ +target +corpus +artifacts +coverage diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml new file mode 100644 index 00000000..914c8f9b --- /dev/null +++ b/fuzz/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "cbor-smol-fuzz" +version = "0.0.0" +publish = false +edition = "2018" + +[package.metadata] +cargo-fuzz = true + +[dependencies] +libfuzzer-sys = "0.4" +arbitrary = { version = "1.2.3", features = ["derive"] } +serde = { version = "1.0.152", features = ["derive"] } +serde_bytes = "0.11.9" +serde_cbor = "0.11.2" + +[dependencies.cbor-smol] +path = ".." + +# Prevent this from interfering with workspaces +[workspace] +members = ["."] + +[profile.release] +debug = 1 + +[[bin]] +name = "fuzz_target_1" +path = "fuzz_targets/fuzz_target_1.rs" +test = false +doc = false diff --git a/fuzz/fuzz_targets/fuzz_target_1.rs b/fuzz/fuzz_targets/fuzz_target_1.rs new file mode 100644 index 00000000..8f61ffec --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_target_1.rs @@ -0,0 +1,86 @@ +#![no_main] + +use arbitrary::{Arbitrary, Unstructured}; +use libfuzzer_sys::fuzz_target; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, PartialEq, Arbitrary, Serialize, Deserialize)] +enum AllEnums { + I8(i8), + U8(u8), + I16(i16), + U16(u16), + I32(i32), + U32(u32), + // Not implemented + // I64(i64), + U64(u64), + Struct(Struct), + Array([Struct; 4]), + Option(Option), + Vec(Vec), + Bytes(#[serde(with = "serde_bytes")] Vec), + String(String), + Tuple((Struct, Struct)), + TupleVariant(Struct, Struct), + TupleVariantBytes(Struct, Struct, #[serde(with = "serde_bytes")] Vec), + StructVariant { + x: Struct, + y: Struct, + }, + StructVariantBytes { + x: Struct, + y: Struct, + #[serde(with = "serde_bytes")] + z: Vec, + }, +} + +#[derive(Debug, PartialEq, Arbitrary, Serialize, Deserialize)] +struct Struct { + a: Box, + b: Box, +} + +/// Workaround https://github.com/rust-fuzz/arbitrary/issues/144 +#[derive(Debug)] +struct Input<'i>(AllEnums, &'i [u8]); + +impl<'i> Arbitrary<'i> for Input<'i> { + fn arbitrary(u: &mut Unstructured<'i>) -> Result { + Ok(Self(AllEnums::arbitrary(u)?, Arbitrary::arbitrary(u)?)) + } + + fn arbitrary_take_rest(mut u: Unstructured<'i>) -> Result { + Ok(Self( + AllEnums::arbitrary(&mut u)?, + Arbitrary::arbitrary_take_rest(u)?, + )) + } + fn size_hint(_depth: usize) -> (usize, Option) { + (0, None) + } +} + +fuzz_target!(|data: Input<'_>| { + let bytes = data.1; + let data = data.0; + let _res: Option = cbor_smol::cbor_deserialize(&bytes).ok(); + let mut buffer = vec![0; 1024 * 20]; + let res = cbor_smol::cbor_serialize(&data, &mut buffer).unwrap(); + cbor_smol::cbor_deserialize(&res) + .map(|b: AllEnums| { + assert_eq!(data, b); + }) + .map_err(|err| { + let v: Result = serde_cbor::from_slice(&res); + panic!( + "Failed to deserialize: {:?}\n\ + input: {:#?}\n\ + data: {:02x?}\n\ + serde_cbor gives: {:#?}\n", + err, data, res, v + ); + }) + .ok(); +}); From a3f7ece7d1c367c0920a17f3ec4ff75ea31a4386 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sosth=C3=A8ne=20Gu=C3=A9don?= Date: Mon, 20 Feb 2023 09:27:17 +0100 Subject: [PATCH 4/8] Fix clippy warnings --- src/de.rs | 8 ++++---- src/lib.rs | 14 +++++++------- src/ser.rs | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/de.rs b/src/de.rs index 5ee35224..d43fba0e 100644 --- a/src/de.rs +++ b/src/de.rs @@ -62,7 +62,7 @@ impl<'de> Deserializer<'de> { } fn peek_major(&mut self) -> Result { - if self.input.len() != 0 { + if !self.input.is_empty() { let byte = self.input[0]; Ok(byte >> 5) } else { @@ -71,7 +71,7 @@ impl<'de> Deserializer<'de> { } fn peek(&mut self) -> Result { - if self.input.len() != 0 { + if !self.input.is_empty() { Ok(self.input[0]) } else { Err(Error::DeserializeUnexpectedEnd) @@ -79,7 +79,7 @@ impl<'de> Deserializer<'de> { } fn consume(&mut self) -> Result<()> { - if self.input.len() != 0 { + if !self.input.is_empty() { self.input = &self.input[1..]; Ok(()) } else { @@ -148,7 +148,7 @@ impl<'de> Deserializer<'de> { ); match unsigned { 0..=65535 => Err(Error::DeserializeNonMinimal), - unsigned => Ok(unsigned as u32), + unsigned => Ok(unsigned), } } _ => Err(Error::DeserializeBadU32), diff --git a/src/lib.rs b/src/lib.rs index f8dbe00e..a9c23e23 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,10 +16,10 @@ pub use error::{Error, Result}; // pub use de::take_from_bytes; // kudos to postcard, this is much nicer than returning size -pub fn cbor_serialize<'a, 'b, T: serde::Serialize>( - object: &'a T, - buffer: &'b mut [u8], -) -> Result<&'b [u8]> { +pub fn cbor_serialize<'a, T: serde::Serialize>( + object: &T, + buffer: &'a mut [u8], +) -> Result<&'a [u8]> { let writer = ser::SliceWriter::new(buffer); let mut ser = ser::Serializer::new(writer); @@ -32,9 +32,9 @@ pub fn cbor_serialize<'a, 'b, T: serde::Serialize>( } /// Append serialization of object to existing bytes, returning length of serialized object. -pub fn cbor_serialize_extending_bytes<'a, 'b, T: serde::Serialize, const N: usize>( - object: &'a T, - bytes: &'b mut Bytes, +pub fn cbor_serialize_extending_bytes( + object: &T, + bytes: &mut Bytes, ) -> Result { let len_before = bytes.len(); let mut ser = ser::Serializer::new(bytes); diff --git a/src/ser.rs b/src/ser.rs index 0eb36f76..cf132be3 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -132,11 +132,11 @@ impl Serializer { } #[inline] - fn serialize_collection<'a>( - &'a mut self, + fn serialize_collection( + &mut self, major: u8, len: Option, - ) -> Result> { + ) -> Result> { let needs_eof = match len { Some(len) => { self.write_u64(major, len as u64)?; From dc72f59ff7fefcdccda2899cfa3ac36b80b12120 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sosth=C3=A8ne=20Gu=C3=A9don?= Date: Mon, 20 Feb 2023 10:24:45 +0100 Subject: [PATCH 5/8] Migrate to edition 2021 --- Cargo.toml | 2 +- fuzz/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 97e990c2..8cd3376b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "cbor-smol" version = "0.4.0" authors = ["Nicolas Stalder "] -edition = "2018" +edition = "2021" description = "Streamlined serde serializer/deserializer for CBOR" repository = "https://github.com/nickray/cbor-smol" readme = "README.md" diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 914c8f9b..f71d5750 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -2,7 +2,7 @@ name = "cbor-smol-fuzz" version = "0.0.0" publish = false -edition = "2018" +edition = "2021" [package.metadata] cargo-fuzz = true From a3e2a27629fbdc4ce6052b48123efc040ae70ba6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sosth=C3=A8ne=20Gu=C3=A9don?= Date: Mon, 20 Feb 2023 11:40:58 +0100 Subject: [PATCH 6/8] Improve readability with consts --- src/consts.rs | 19 +++++++++++++++++++ src/de.rs | 47 +++++++++++++++++++++++------------------------ src/lib.rs | 1 + src/ser.rs | 44 +++++++++++++++++++++++--------------------- 4 files changed, 66 insertions(+), 45 deletions(-) create mode 100644 src/consts.rs diff --git a/src/consts.rs b/src/consts.rs new file mode 100644 index 00000000..f24bfeb7 --- /dev/null +++ b/src/consts.rs @@ -0,0 +1,19 @@ +pub const MAJOR_OFFSET: u8 = 5; + +pub const MAJOR_POSINT: u8 = 0; +pub const MAJOR_NEGINT: u8 = 1; +pub const MAJOR_BYTES: u8 = 2; +pub const MAJOR_STR: u8 = 3; +pub const MAJOR_ARRAY: u8 = 4; +pub const MAJOR_MAP: u8 = 5; +pub const MAJOR_SIMPLE: u8 = 7; + +pub const SIMPLE_FALSE: u8 = 20; +pub const SIMPLE_TRUE: u8 = 21; +pub const SIMPLE_NULL: u8 = 22; +// pub const SIMPLE_UNDEFINED: u8 = 23; + +pub const VALUE_FALSE: u8 = (MAJOR_SIMPLE << MAJOR_OFFSET) | SIMPLE_FALSE; +pub const VALUE_TRUE: u8 = (MAJOR_SIMPLE << MAJOR_OFFSET) | SIMPLE_TRUE; +pub const VALUE_NULL: u8 = (MAJOR_SIMPLE << MAJOR_OFFSET) | SIMPLE_NULL; +// pub const VALUE_UNDEFINED: u8 = (MAJOR_SIMPLE << MAJOR_LEN) | SIMPLE_UNDEFINED; diff --git a/src/de.rs b/src/de.rs index d43fba0e..c26334c1 100644 --- a/src/de.rs +++ b/src/de.rs @@ -3,6 +3,7 @@ use serde::Deserialize; use serde::de::IntoDeserializer; use super::error::{Error, Result}; +use crate::consts::*; /// Deserialize a message of type `T` from a byte slice. The unused portion (if any) /// of the byte slice is returned for further usage @@ -64,7 +65,7 @@ impl<'de> Deserializer<'de> { fn peek_major(&mut self) -> Result { if !self.input.is_empty() { let byte = self.input[0]; - Ok(byte >> 5) + Ok(byte >> MAJOR_OFFSET) } else { Err(Error::DeserializeUnexpectedEnd) } @@ -89,12 +90,12 @@ impl<'de> Deserializer<'de> { fn expect_major(&mut self, major: u8) -> Result { let byte = self.try_take_n(1)?[0]; - if major != (byte >> 5) { + if major != (byte >> MAJOR_OFFSET) { // logging::info_now!("expecting {}, got {} in byte {}", major, byte >> 5, byte).ok(); // logging::info_now!("remaining data: {:?}", &self.input).ok(); return Err(Error::DeserializeBadMajor); } - Ok(byte & ((1 << 5) - 1)) + Ok(byte & ((1 << MAJOR_OFFSET) - 1)) } // TODO: name something like "one-byte-integer" @@ -287,8 +288,8 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { let val = match self.try_take_n(1)?[0] { - 0xf4 => false, - 0xf5 => true, + VALUE_FALSE => false, + VALUE_TRUE => true, _ => return Err(Error::DeserializeBadBool), }; visitor.visit_bool(val) @@ -299,7 +300,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { match self.peek_major()? { - 0 => { + MAJOR_POSINT => { let raw_u8 = self.raw_deserialize_u8(0)?; if raw_u8 <= i8::max_value() as u8 { visitor.visit_i8(raw_u8 as i8) @@ -307,7 +308,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { Err(Error::DeserializeBadI8) } } - 1 => { + MAJOR_NEGINT => { let raw_u8 = self.raw_deserialize_u8(1)?; // if raw_u8 <= 1 + i8::max_value() as u8 { if raw_u8 <= 128 { @@ -325,7 +326,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { match self.peek_major()? { - 0 => { + MAJOR_POSINT => { let raw = self.raw_deserialize_u16(0)?; if raw <= i16::max_value() as u16 { visitor.visit_i16(raw as i16) @@ -333,7 +334,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { Err(Error::DeserializeBadI16) } } - 1 => { + MAJOR_NEGINT => { let raw = self.raw_deserialize_u16(1)?; if raw <= i16::max_value() as u16 { visitor.visit_i16(-1 - (raw as i16)) @@ -354,7 +355,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { major @ 0..=1 => { let raw = self.raw_deserialize_u32(major)?; if raw <= i32::max_value() as u32 { - if major == 0 { + if major == MAJOR_POSINT { visitor.visit_i32(raw as i32) } else { visitor.visit_i32(-1 - (raw as i32)) @@ -378,7 +379,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - let raw = self.raw_deserialize_u8(0)?; + let raw = self.raw_deserialize_u8(MAJOR_POSINT)?; visitor.visit_u8(raw) } @@ -386,7 +387,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - let raw = self.raw_deserialize_u16(0)?; + let raw = self.raw_deserialize_u16(MAJOR_POSINT)?; visitor.visit_u16(raw) } @@ -394,7 +395,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - let raw = self.raw_deserialize_u32(0)?; + let raw = self.raw_deserialize_u32(MAJOR_POSINT)?; visitor.visit_u32(raw) } @@ -402,7 +403,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - let raw = self.raw_deserialize_u32(0)?; + let raw = self.raw_deserialize_u32(MAJOR_POSINT)?; visitor.visit_u64(raw as u64) } @@ -439,7 +440,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { // major type 2: "byte string" - let length = self.raw_deserialize_u32(2)? as usize; + let length = self.raw_deserialize_u32(MAJOR_BYTES)? as usize; let bytes: &'de [u8] = self.try_take_n(length)?; visitor.visit_borrowed_bytes(bytes) } @@ -456,7 +457,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { // major type 3: "text string" - let length = self.raw_deserialize_u32(3)? as usize; + let length = self.raw_deserialize_u32(MAJOR_STR)? as usize; let bytes: &'de [u8] = self.try_take_n(length)?; let string_slice = core::str::from_utf8(bytes).map_err(|_| Error::DeserializeBadUtf8)?; visitor.visit_borrowed_str(string_slice) @@ -488,7 +489,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { match self.peek()? { - 0xf6 => { + VALUE_NULL => { self.consume()?; visitor.visit_unit() } @@ -515,8 +516,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - // major type 4: "array" - let len = self.raw_deserialize_u32(4)? as usize; + let len = self.raw_deserialize_u32(MAJOR_ARRAY)? as usize; visitor.visit_seq(SeqAccess { deserializer: self, @@ -528,8 +528,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - // major type 4: "array" - let len = self.raw_deserialize_u32(4)? as usize; + let len = self.raw_deserialize_u32(MAJOR_ARRAY)? as usize; visitor.visit_seq(SeqAccess { deserializer: self, len, @@ -552,8 +551,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - // major type 5: "map" - let len = self.raw_deserialize_u32(5)? as usize; + let len = self.raw_deserialize_u32(MAJOR_MAP)? as usize; visitor.visit_map(MapAccess { deserializer: self, @@ -611,8 +609,9 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { + const ARRAY_LEN_2: u8 = MAJOR_ARRAY << MAJOR_OFFSET | 2; match self.peek()? { - 0x82 => { + ARRAY_LEN_2 => { self.consume()?; visitor.visit_enum(self) // // self.parse_enum(2, visitor) diff --git a/src/lib.rs b/src/lib.rs index a9c23e23..48cce453 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ generate_macros!(); pub use heapless_bytes::Bytes; +pub(crate) mod consts; pub mod de; pub mod error; pub mod ser; diff --git a/src/ser.rs b/src/ser.rs index cf132be3..1411c163 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -2,6 +2,8 @@ use super::error::{Error, Result}; use serde::ser; use serde::Serialize; +use crate::consts::*; + // pub fn to_slice<'a, 'b, T>(value: &'a T, buf: &'b mut [u8]) -> Result<&'b mut [u8]> // where // T: Serialize + ?Sized, @@ -90,9 +92,9 @@ impl Serializer { #[inline] fn write_u8(&mut self, major: u8, value: u8) -> Result<()> { if value <= 0x17 { - self.writer.write_all(&[major << 5 | value]) + self.writer.write_all(&[major << MAJOR_OFFSET | value]) } else { - let buf = [major << 5 | 24, value]; + let buf = [major << MAJOR_OFFSET | 24, value]; self.writer.write_all(&buf) } .map_err(|e| e.into()) @@ -103,7 +105,7 @@ impl Serializer { if value <= u16::from(u8::max_value()) { self.write_u8(major, value as u8) } else { - let mut buf = [major << 5 | 25, 0, 0]; + let mut buf = [major << MAJOR_OFFSET | 25, 0, 0]; buf[1..].copy_from_slice(&value.to_be_bytes()); self.writer.write_all(&buf).map_err(|e| e.into()) } @@ -114,7 +116,7 @@ impl Serializer { if value <= u32::from(u16::max_value()) { self.write_u16(major, value as u16) } else { - let mut buf = [major << 5 | 26, 0, 0, 0, 0]; + let mut buf = [major << MAJOR_OFFSET | 26, 0, 0, 0, 0]; buf[1..].copy_from_slice(&value.to_be_bytes()); self.writer.write_all(&buf).map_err(|e| e.into()) } @@ -125,7 +127,7 @@ impl Serializer { if value <= u64::from(u32::max_value()) { self.write_u32(major, value as u32) } else { - let mut buf = [major << 5 | 27, 0, 0, 0, 0, 0, 0, 0, 0]; + let mut buf = [major << MAJOR_OFFSET | 27, 0, 0, 0, 0, 0, 0, 0, 0]; buf[1..].copy_from_slice(&value.to_be_bytes()); self.writer.write_all(&buf).map_err(|e| e.into()) } @@ -144,7 +146,7 @@ impl Serializer { } None => { self.writer - .write_all(&[major << 5 | 31]) + .write_all(&[major << MAJOR_OFFSET | 31]) .map_err(|e| e.into())?; true } @@ -183,7 +185,7 @@ where #[inline] fn serialize_bool(self, value: bool) -> Result<()> { - let value = if value { 0xf5 } else { 0xf4 }; + let value = if value { VALUE_TRUE } else { VALUE_FALSE }; self.writer.write_all(&[value]).map_err(|e| e.into()) } @@ -221,22 +223,22 @@ where #[inline] fn serialize_u8(self, value: u8) -> Result<()> { - self.write_u8(0, value) + self.write_u8(MAJOR_POSINT, value) } #[inline] fn serialize_u16(self, value: u16) -> Result<()> { - self.write_u16(0, value) + self.write_u16(MAJOR_POSINT, value) } #[inline] fn serialize_u32(self, value: u32) -> Result<()> { - self.write_u32(0, value) + self.write_u32(MAJOR_POSINT, value) } #[inline] fn serialize_u64(self, value: u64) -> Result<()> { - self.write_u64(0, value) + self.write_u64(MAJOR_POSINT, value) } fn serialize_f32(self, _v: f32) -> Result<()> { @@ -256,7 +258,7 @@ where #[inline] fn serialize_str(self, value: &str) -> Result<()> { - self.write_u64(3, value.len() as u64)?; + self.write_u64(MAJOR_STR, value.len() as u64)?; self.writer .write_all(value.as_bytes()) .map_err(|e| e.into()) @@ -264,13 +266,13 @@ where #[inline] fn serialize_bytes(self, value: &[u8]) -> Result<()> { - self.write_u64(2, value.len() as u64)?; + self.write_u64(MAJOR_BYTES, value.len() as u64)?; self.writer.write_all(value).map_err(|e| e.into()) } #[inline] fn serialize_none(self) -> Result<()> { - self.writer.write_all(&[0xf6]).map_err(|e| e.into()) + self.writer.write_all(&[VALUE_NULL]).map_err(|e| e.into()) } #[inline] @@ -333,7 +335,7 @@ where // self.write_u64(5, 1u64)?; // variant.serialize(&mut *self)?; // } else { - self.writer.write_all(&[4 << 5 | 2]).map_err(|e| e.into())?; + self.write_u64(MAJOR_ARRAY, 2)?; self.serialize_unit_variant(name, variant_index, variant)?; // } value.serialize(self) @@ -341,12 +343,12 @@ where #[inline] fn serialize_seq(self, len: Option) -> Result> { - self.serialize_collection(4, len) + self.serialize_collection(MAJOR_ARRAY, len) } #[inline] fn serialize_tuple(self, len: usize) -> Result<&'a mut Serializer> { - self.write_u64(4, len as u64)?; + self.write_u64(MAJOR_ARRAY, len as u64)?; Ok(self) } @@ -372,7 +374,7 @@ where // variant.serialize(&mut *self)?; // self.serialize_tuple(len) // } else { - self.write_u64(4, (len + 1) as u64)?; + self.write_u64(MAJOR_ARRAY, (len + 1) as u64)?; self.serialize_unit_variant(name, variant_index, variant)?; Ok(self) // } @@ -380,7 +382,7 @@ where #[inline] fn serialize_map(self, len: Option) -> Result> { - self.serialize_collection(5, len) + self.serialize_collection(MAJOR_MAP, len) } // #[cfg(not(feature = "std"))] @@ -398,7 +400,7 @@ where #[inline] fn serialize_struct(self, _name: &'static str, len: usize) -> Result { - self.write_u64(5, len as u64)?; + self.write_u64(MAJOR_MAP, len as u64)?; Ok(self) } @@ -413,7 +415,7 @@ where // if self.enum_as_map { // self.write_u64(5, 1u64)?; // } else { - self.writer.write_all(&[4 << 5 | 2]).map_err(|e| e.into())?; + self.write_u64(MAJOR_ARRAY, 2)?; // } self.serialize_unit_variant(name, variant_index, variant)?; self.serialize_struct(name, len) From 7b51b1e8dc49a955e51b8f63cb2e53052f456fb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sosth=C3=A8ne=20Gu=C3=A9don?= Date: Mon, 20 Feb 2023 12:33:11 +0100 Subject: [PATCH 7/8] Fix enum struct variant deserialization --- src/de.rs | 130 ++++++++++++++++++++++-------------------------------- 1 file changed, 53 insertions(+), 77 deletions(-) diff --git a/src/de.rs b/src/de.rs index c26334c1..9640e197 100644 --- a/src/de.rs +++ b/src/de.rs @@ -79,10 +79,11 @@ impl<'de> Deserializer<'de> { } } - fn consume(&mut self) -> Result<()> { + fn consume(&mut self) -> Result { if !self.input.is_empty() { + let ret = self.input[0]; self.input = &self.input[1..]; - Ok(()) + Ok(ret) } else { Err(Error::DeserializeUnexpectedEnd) } @@ -232,19 +233,37 @@ impl<'a, 'b: 'a> serde::de::MapAccess<'b> for MapAccess<'a, 'b> { } } -impl<'de, 'a> serde::de::VariantAccess<'de> for &'a mut Deserializer<'de> { +struct EnumAccess<'a, 'b: 'a> { + deserializer: &'a mut Deserializer<'b>, + variant_len: usize, +} + +impl<'de, 'a> serde::de::VariantAccess<'de> for EnumAccess<'a, 'de> { type Error = Error; fn unit_variant(self) -> Result<()> { + if self.variant_len != 0 { + return Err(Error::DeserializeBadEnum); + } Ok(()) } fn newtype_variant_seed>(self, seed: V) -> Result { - DeserializeSeed::deserialize(seed, self) + if 2 != self.variant_len { + return Err(Error::DeserializeBadEnum); + } + DeserializeSeed::deserialize(seed, self.deserializer) } fn tuple_variant>(self, len: usize, visitor: V) -> Result { - serde::de::Deserializer::deserialize_tuple(self, len, visitor) + if len + 1 != self.variant_len { + return Err(Error::DeserializeBadEnum); + } + + visitor.visit_seq(SeqAccess { + deserializer: self.deserializer, + len, + }) } fn struct_variant>( @@ -252,16 +271,19 @@ impl<'de, 'a> serde::de::VariantAccess<'de> for &'a mut Deserializer<'de> { _fields: &'static [&'static str], visitor: V, ) -> Result { - serde::de::Deserializer::deserialize_map(self, visitor) + if 2 != self.variant_len { + return Err(Error::DeserializeBadEnum); + } + serde::de::Deserializer::deserialize_map(self.deserializer, visitor) } } -impl<'de, 'a> serde::de::EnumAccess<'de> for &'a mut Deserializer<'de> { +impl<'de, 'a> serde::de::EnumAccess<'de> for EnumAccess<'a, 'de> { type Error = Error; type Variant = Self; fn variant_seed>(self, seed: V) -> Result<(V::Value, Self)> { - let discriminant = self.raw_deserialize_u32(0)?; + let discriminant = self.deserializer.raw_deserialize_u32(MAJOR_POSINT)?; // if discriminant > 0xFFFF_FFFF { // return Err(Error::DeserializeBadEnum); // } @@ -609,72 +631,22 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - const ARRAY_LEN_2: u8 = MAJOR_ARRAY << MAJOR_OFFSET | 2; - match self.peek()? { - ARRAY_LEN_2 => { - self.consume()?; - visitor.visit_enum(self) - // // self.parse_enum(2, visitor) - // let value = visitor.visit_enum(VariantAccess { - // seq: SeqAccess { self, len: &mut 2 }, - // })?; - - // if len != 0 { - // Err(de.error(ErrorCode::TrailingData)) - // } else { - // Ok(value) - // } + match self.peek_major()? { + // Data variant + MAJOR_ARRAY => { + let len = self.raw_deserialize_u32(MAJOR_ARRAY)?; + visitor.visit_enum(EnumAccess { + deserializer: self, + variant_len: len as usize, + }) } - // _ => Err(Error::DeserializeBadEnum), - _ => visitor.visit_enum(self), + // Unit variant + MAJOR_POSINT => visitor.visit_enum(EnumAccess { + deserializer: self, + variant_len: 0, + }), + _ => Err(Error::DeserializeBadMajor), } - - // Some(byte @ 0x80..=0x9f) => { - // if !self.accept_legacy_enums { - // return Err(self.error(ErrorCode::WrongEnumFormat)); - // } - // self.consume(); - // match byte { - // 0x80..=0x97 => self.parse_enum(byte as usize - 0x80, visitor), - // 0x98 => { - // let len = self.parse_u8()?; - // self.parse_enum(len as usize, visitor) - // } - // 0x99 => { - // let len = self.parse_u16()?; - // self.parse_enum(len as usize, visitor) - // } - // 0x9a => { - // let len = self.parse_u32()?; - // self.parse_enum(len as usize, visitor) - // } - // 0x9b => { - // let len = self.parse_u64()?; - // if len > usize::max_value() as u64 { - // return Err(self.error(ErrorCode::LengthOutOfRange)); - // } - // self.parse_enum(len as usize, visitor) - // } - // _ => Err(Error::DeserializeBadEnum), - // // 0x9c..=0x9e => Err(self.error(ErrorCode::UnassignedCode)), - // // 0x9f => self.parse_indefinite_enum(visitor), - - // // _ => unreachable!(), - // } - // } - // _ => Err(Error::DeserializeBadEnum), - // // Some(0xa1) => { - // // if !self.accept_standard_enums { - // // return Err(self.error(ErrorCode::WrongEnumFormat)); - // // } - // // self.consume(); - // // self.parse_enum_map(visitor) - // // } - // } - // println!("visiting enum"); - // let ret = visitor.visit_enum(self); - // println!("visited enum"); - // ret } fn deserialize_identifier(self, visitor: V) -> Result @@ -984,14 +956,19 @@ mod tests { #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub enum Enum { Alpha(u8), - // Beta((i32, u32)), - Beta(i32), + Beta((i32, u32)), + Gamma { a: i32, b: u32 }, } let mut buf = [0u8; 64]; - // let e = Enum::Beta((-42, 7)); - let e = Enum::Beta(-42); + let e = Enum::Beta((-42, 7)); + let ser = cbor_serialize(&e, &mut buf).unwrap(); + println!("ser({:?}) = {:?}", &e, ser); + let de: Enum = cbor_deserialize(ser).unwrap(); + assert_eq!(de, e); + + let e = Enum::Gamma { a: -42, b: 7 }; let ser = cbor_serialize(&e, &mut buf).unwrap(); println!("ser({:?}) = {:?}", &e, ser); let de: Enum = cbor_deserialize(ser).unwrap(); @@ -999,7 +976,6 @@ mod tests { #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub enum SimpleEnum { - // Alpha(u8), Alpha(u8), Beta, } From 6b8dc58c4935b9f31c2c7dd5e52dd49e36ecc210 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sosth=C3=A8ne=20Gu=C3=A9don?= Date: Mon, 20 Feb 2023 12:45:49 +0100 Subject: [PATCH 8/8] Fix deserialization of u64 --- src/de.rs | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++- src/error.rs | 3 +++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/src/de.rs b/src/de.rs index 9640e197..b995d961 100644 --- a/src/de.rs +++ b/src/de.rs @@ -156,6 +156,51 @@ impl<'de> Deserializer<'de> { _ => Err(Error::DeserializeBadU32), } } + fn raw_deserialize_u64(&mut self, major: u8) -> Result { + let additional = self.expect_major(major)?; + + match additional { + byte @ 0..=23 => Ok(byte as u64), + 24 => match self.try_take_n(1)?[0] { + 0..=23 => Err(Error::DeserializeNonMinimal), + byte => Ok(byte as u64), + }, + 25 => { + let unsigned = u16::from_be_bytes( + self.try_take_n(2)? + .try_into() + .map_err(|_| Error::InexistentSliceToArrayError)?, + ); + match unsigned { + 0..=255 => Err(Error::DeserializeNonMinimal), + unsigned => Ok(unsigned as u64), + } + } + 26 => { + let unsigned = u32::from_be_bytes( + self.try_take_n(4)? + .try_into() + .map_err(|_| Error::InexistentSliceToArrayError)?, + ); + match unsigned { + 0..=65535 => Err(Error::DeserializeNonMinimal), + unsigned => Ok(unsigned as u64), + } + } + 27 => { + let unsigned = u64::from_be_bytes( + self.try_take_n(8)? + .try_into() + .map_err(|_| Error::InexistentSliceToArrayError)?, + ); + match unsigned { + 0..=0xFFFFFFFF => Err(Error::DeserializeNonMinimal), + unsigned => Ok(unsigned), + } + } + _ => Err(Error::DeserializeBadU64), + } + } // fn try_take_varint(&mut self) -> Result { // for i in 0..VarintUsize::varint_usize_max() { @@ -425,7 +470,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - let raw = self.raw_deserialize_u32(MAJOR_POSINT)?; + let raw = self.raw_deserialize_u64(MAJOR_POSINT)?; visitor.visit_u64(raw as u64) } @@ -831,6 +876,31 @@ mod tests { } } + #[test] + fn de_u64() { + let mut buf = [0u8; 64]; + + let numbers = [ + 0, + 1, + 2, + 3, + u16::MAX as u64, + u16::MAX as u64 + 1, + u32::MAX as u64, + u32::MAX as u64 + 1, + u64::MAX - 1, + u64::MAX, + ]; + + for number in numbers { + println!("testing {}", number); + let _n = cbor_serialize(&number, &mut buf).unwrap(); + let de: u64 = from_bytes(&buf).unwrap(); + assert_eq!(de, number); + } + } + #[test] fn de_i32() { let mut buf = [0u8; 64]; diff --git a/src/error.rs b/src/error.rs index 6fe8b575..0fc83044 100644 --- a/src/error.rs +++ b/src/error.rs @@ -48,6 +48,8 @@ pub enum Error { DeserializeBadU16, /// Expected a u32 DeserializeBadU32, + /// Expected a u64 + DeserializeBadU64, /// Expected a NULL marker DeserializeExpectedNull, /// Inexistent slice-to-array cast error. Used here to avoid calling unwrap. @@ -93,6 +95,7 @@ impl Display for Error { DeserializeBadU8 => "Expected a u8", DeserializeBadU16 => "Expected a u16", DeserializeBadU32 => "Expected a u32", + DeserializeBadU64 => "Expected a u64", DeserializeExpectedNull => "Expected 0xf6", InexistentSliceToArrayError => "", DeserializeNonMinimal => "Value may be valid, but not encoded in minimal way",