Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bytes deserialization: accept arrays #7

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ serde = { version = "1", default-features = false }
[dev-dependencies]
heapless = { version = "0.7", features = ["serde"] }
serde = { version = "1", default-features = false, features = ["derive"] }
serde_bytes = "0.11.12"

[features]
bytes-from-array = []
log-all = []
log-none = []
log-info = []
Expand Down
3 changes: 1 addition & 2 deletions fuzz/fuzz_targets/fuzz_target_1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ enum AllEnums {
U16(u16),
I32(i32),
U32(u32),
// Not implemented
// I64(i64),
I64(i64),
U64(u64),
Struct(Struct),
Array([Struct; 4]),
Expand Down
1 change: 1 addition & 0 deletions src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub const MAJOR_STR: u8 = 3;
pub const MAJOR_ARRAY: u8 = 4;
pub const MAJOR_MAP: u8 = 5;
pub const MAJOR_SIMPLE: u8 = 7;
pub const MAJOR_FLOAT: u8 = 7;

pub const SIMPLE_FALSE: u8 = 20;
pub const SIMPLE_TRUE: u8 = 21;
Expand Down
213 changes: 207 additions & 6 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,82 @@ impl<'de> Deserializer<'de> {
}
}

fn ignore_int(&mut self, major: u8) -> Result<()> {
let additional = self.expect_major(major)?;
match additional {
0..=23 => {}
24 => {
self.try_take_n(1)?;
}
25 => {
self.try_take_n(2)?;
}
26 => {
self.try_take_n(4)?;
}
27 => {
self.try_take_n(8)?;
}
_ => return Err(Error::DeserializeBadU16),
};
Ok(())
}

fn ignore_bytes(&mut self, major: u8) -> Result<()> {
let length = self.raw_deserialize_u32(major)? as usize;
self.try_take_n(length)?;
Ok(())
}

fn ignore_array(&mut self, major: u8, mult: usize) -> Result<()> {
let length = self.raw_deserialize_u32(major)? as usize;
let Some(real_length) = length.checked_mul(mult) else {
return Err(Error::InexistentSliceToArrayError);
};
for _ in 0..real_length {
self.ignore()?;
}
Ok(())
}

fn ignore_float(&mut self) -> Result<()> {
let additional = self.expect_major(7)?;
match additional {
0..=23 => {}
24 => {
self.try_take_n(1)?;
}
25 => {
self.try_take_n(2)?;
}
26 => {
self.try_take_n(4)?;
}
27 => {
self.try_take_n(8)?;
}
_ => return Err(Error::DeserializeBadMajor),
};
Ok(())
}

fn ignore(&mut self) -> Result<()> {
let major = self.peek_major()?;
match major {
MAJOR_POSINT | MAJOR_NEGINT => self.ignore_int(major)?,
MAJOR_BYTES | MAJOR_STR => self.ignore_bytes(major)?,
MAJOR_ARRAY => self.ignore_array(MAJOR_ARRAY, 1)?,
MAJOR_MAP => self.ignore_array(MAJOR_MAP, 2)?,
6 => {
self.ignore_int(6)?;
self.ignore()?;
}
MAJOR_FLOAT => self.ignore_float()?,
_ => return Err(Error::DeserializeBadMajor),
}
Ok(())
}

// fn try_take_varint(&mut self) -> Result<usize> {
// for i in 0..VarintUsize::varint_usize_max() {
// let val = self.input.get(i).ok_or(Error::DeserializeUnexpectedEnd)?;
Expand Down Expand Up @@ -435,11 +511,26 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
}
}

fn deserialize_i64<V>(self, _visitor: V) -> Result<V::Value>
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
Err(Error::NotYetImplemented)
match self.peek_major()? {
// TODO: figure out if this is BAAAAD for size or speed
major @ 0..=1 => {
let raw = self.raw_deserialize_u64(major)?;
if raw <= i64::max_value() as u64 {
if major == MAJOR_POSINT {
visitor.visit_i64(raw as i64)
} else {
visitor.visit_i64(-1 - (raw as i64))
}
} else {
Err(Error::DeserializeBadI64)
}
}
_ => Err(Error::DeserializeBadI16),
}
}

fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
Expand Down Expand Up @@ -506,10 +597,24 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
where
V: Visitor<'de>,
{
// major type 2: "byte string"
let length = self.raw_deserialize_u32(MAJOR_BYTES)? as usize;
let bytes: &'de [u8] = self.try_take_n(length)?;
visitor.visit_borrowed_bytes(bytes)
let major = self.peek_major()?;
match major {
#[cfg(feature = "bytes-from-array")]
MAJOR_ARRAY => {
let len = self.raw_deserialize_u32(MAJOR_ARRAY)?;
visitor.visit_seq(SeqAccess {
deserializer: self,
len: len as usize,
})
}
MAJOR_BYTES => {
// major type 2: "byte string"
let length = self.raw_deserialize_u32(MAJOR_BYTES)? as usize;
let bytes: &'de [u8] = self.try_take_n(length)?;
visitor.visit_borrowed_bytes(bytes)
}
_ => Err(Error::DeserializeBadMajor),
}
}

fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
Expand Down Expand Up @@ -706,6 +811,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
V: Visitor<'de>,
{
// Ignore extra fields/options
self.ignore()?;
visitor.visit_none()
}
}
Expand Down Expand Up @@ -790,6 +896,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {

#[cfg(test)]
mod tests {

// use super::*;
use super::from_bytes;

Expand Down Expand Up @@ -1063,6 +1170,100 @@ mod tests {
assert_eq!(de, e);
}

#[test]
fn de_ignored_any() {
use serde::de::IgnoredAny;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
pub struct ValOuter<'a> {
inner: Outer<'a>,
val: &'a str,
}
#[derive(Deserialize)]
pub struct ValIgnored<'a> {
#[allow(unused)]
inner: IgnoredAny,
val: &'a str,
}

#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
pub struct Inner<'a> {
u8: u8,
u16: u16,
u32: u32,
u64: u64,
i8: i8,
i16: i16,
i32: i32,
i64: i64,
str: &'a str,
option: Option<&'a str>,
#[serde(with = "serde_bytes")]
bytes: &'a [u8],
unit: (),
}
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
pub struct Outer<'a> {
u8: u8,
u16: u16,
u32: u32,
u64: u64,
i8: i8,
i16: i16,
i32: i32,
i64: i64,
str: &'a str,
#[serde(with = "serde_bytes")]
bytes: &'a [u8],
unit: (),
option: Option<&'a str>,
nested: Inner<'a>,
}

let mut buf = [0; 1024];
let val = Outer {
u8: u8::MAX,
u16: u16::MAX,
u32: u32::MAX,
u64: u64::MAX,
i8: i8::MIN,
i16: i16::MIN,
i32: i32::MIN,
i64: i64::MIN,
str: "string",
bytes: b"bytes",
unit: (),
option: Some("option"),
nested: Inner {
u8: 0,
u16: 0,
u32: 0,
u64: 0,
i8: i8::MIN,
i16: i16::MIN,
i32: i32::MIN,
i64: i64::MIN,
str: "",
option: None,
bytes: b"",
unit: (),
},
};
let ser = cbor_serialize(&val, &mut buf).unwrap();
let _: IgnoredAny = cbor_deserialize(ser).unwrap();

let val = ValOuter {
inner: val,
val: "value",
};
let ser = cbor_serialize(&val, &mut buf).unwrap();
let de: ValOuter = cbor_deserialize(ser).unwrap();
assert_eq!(val, de);
let de: ValIgnored = cbor_deserialize(ser).unwrap();
assert_eq!(de.val, "value");
}

// #[test]
// fn fuzzer_things() {
// let data: [u8; 2] = [160, 96];
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ pub enum Error {
DeserializeBadI16,
/// Expected a i32, was too large
DeserializeBadI32,
/// Expected a i64, was too large
DeserializeBadI64,
/// Expected a u8
DeserializeBadU8,
/// Expected a u16
Expand Down Expand Up @@ -91,6 +93,7 @@ impl Display for Error {
DeserializeBadI8 => "Expected a i8",
DeserializeBadI16 => "Expected a i16",
DeserializeBadI32 => "Expected a i32",
DeserializeBadI64 => "Expected a i64",
DeserializeBadMajor => "Expected a different major type",
DeserializeBadU8 => "Expected a u8",
DeserializeBadU16 => "Expected a u16",
Expand Down