diff --git a/src/types.rs b/src/types.rs index ba9749c3..861719a1 100644 --- a/src/types.rs +++ b/src/types.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::fmt; +use std::hash::{Hash, Hasher}; + use crate::traits::*; pub type NewRootContext = fn(context_id: u32) -> Box; @@ -77,3 +80,261 @@ pub enum PeerType { } pub type Bytes = Vec; + +/// Represents an HTTP header value that is not necessarily a UTF-8 encoded string. +#[derive(Eq)] +pub struct HeaderValue { + inner: Result, +} + +impl HeaderValue { + fn new(inner: Result) -> Self { + HeaderValue { inner } + } + + pub fn into_vec(self) -> Vec { + match self.inner { + Ok(string) => string.into_bytes(), + Err(bytes) => bytes, + } + } + + pub fn into_string_or_vec(self) -> Result> { + self.inner + } +} + +impl From> for HeaderValue { + #[inline] + fn from(data: Vec) -> Self { + Self::new(match String::from_utf8(data) { + Ok(string) => Ok(string), + Err(err) => Err(err.into_bytes()), + }) + } +} + +impl From<&[u8]> for HeaderValue { + fn from(data: &[u8]) -> Self { + data.to_owned().into() + } +} + +impl From for HeaderValue { + #[inline] + fn from(string: String) -> Self { + Self::new(Ok(string)) + } +} + +impl From<&str> for HeaderValue { + fn from(data: &str) -> Self { + data.to_owned().into() + } +} + +impl fmt::Display for HeaderValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.inner { + Ok(ref string) => fmt::Display::fmt(string, f), + Err(ref bytes) => fmt::Debug::fmt(bytes, f), + } + } +} + +impl fmt::Debug for HeaderValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.inner { + Ok(ref string) => fmt::Debug::fmt(string, f), + Err(ref bytes) => fmt::Debug::fmt(bytes, f), + } + } +} + +impl AsRef<[u8]> for HeaderValue { + #[inline] + fn as_ref(&self) -> &[u8] { + match self.inner { + Ok(ref string) => string.as_bytes(), + Err(ref bytes) => bytes.as_slice(), + } + } +} + +impl PartialEq for HeaderValue { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + self.inner == other.inner + } +} + +impl PartialEq for HeaderValue { + #[inline] + fn eq(&self, other: &String) -> bool { + self.as_ref() == other.as_bytes() + } +} + +impl PartialEq> for HeaderValue { + #[inline] + fn eq(&self, other: &Vec) -> bool { + self.as_ref() == other.as_slice() + } +} + +impl PartialEq<&str> for HeaderValue { + #[inline] + fn eq(&self, other: &&str) -> bool { + self.as_ref() == other.as_bytes() + } +} + +impl PartialEq<&[u8]> for HeaderValue { + #[inline] + fn eq(&self, other: &&[u8]) -> bool { + self.as_ref() == *other + } +} + +impl PartialEq for String { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + *other == *self + } +} + +impl PartialEq for Vec { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + *other == *self + } +} + +impl PartialEq for &str { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + *other == *self + } +} + +impl PartialEq for &[u8] { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + *other == *self + } +} + +impl Hash for HeaderValue { + #[inline] + fn hash(&self, state: &mut H) { + match self.inner { + Ok(ref string) => Hash::hash(string, state), + Err(ref bytes) => Hash::hash(bytes, state), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + #[test] + fn test_header_value_display_string() { + let string: HeaderValue = String::from("utf-8 encoded string").into(); + + assert_eq!(format!("{}", string), "utf-8 encoded string"); + } + + #[test] + fn test_header_value_display_bytes() { + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + + assert_eq!(format!("{}", bytes), "[144, 145, 146]"); + } + + #[test] + fn test_header_value_debug_string() { + let string: HeaderValue = String::from("utf-8 encoded string").into(); + + assert_eq!(format!("{:?}", string), "\"utf-8 encoded string\""); + } + + #[test] + fn test_header_value_debug_bytes() { + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + + assert_eq!(format!("{:?}", bytes), "[144, 145, 146]"); + } + + #[test] + fn test_header_value_as_ref() { + fn receive(value: T) + where + T: AsRef<[u8]>, + { + value.as_ref(); + } + + let string: HeaderValue = String::from("utf-8 encoded string").into(); + receive(string); + + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + receive(bytes); + } + + #[test] + fn test_header_value_eq_string() { + let string: HeaderValue = String::from("utf-8 encoded string").into(); + + assert_eq!(string, "utf-8 encoded string"); + assert_eq!(string, b"utf-8 encoded string" as &[u8]); + + assert_eq!("utf-8 encoded string", string); + assert_eq!(b"utf-8 encoded string" as &[u8], string); + + assert_eq!(string, string); + } + + #[test] + fn test_header_value_eq_bytes() { + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + + assert_eq!(bytes, vec![144u8, 145u8, 146u8]); + assert_eq!(bytes, b"\x90\x91\x92" as &[u8]); + + assert_eq!(vec![144u8, 145u8, 146u8], bytes); + assert_eq!(b"\x90\x91\x92" as &[u8], bytes); + + assert_eq!(bytes, bytes); + } + + fn hash(t: &T) -> u64 { + let mut h = DefaultHasher::new(); + t.hash(&mut h); + h.finish() + } + + #[test] + fn test_header_value_hash_string() { + let string: HeaderValue = String::from("utf-8 encoded string").into(); + + assert_eq!(hash(&string), hash(&"utf-8 encoded string")); + assert_ne!(hash(&string), hash(&b"utf-8 encoded string")); + + assert_eq!(hash(&"utf-8 encoded string"), hash(&string)); + assert_ne!(hash(&b"utf-8 encoded string"), hash(&string)); + } + + #[test] + fn test_header_value_hash_bytes() { + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + + assert_eq!(hash(&bytes), hash(&vec![144u8, 145u8, 146u8])); + assert_eq!(hash(&bytes), hash(&[144u8, 145u8, 146u8])); + + assert_eq!(hash(&vec![144u8, 145u8, 146u8]), hash(&bytes)); + assert_eq!(hash(&[144u8, 145u8, 146u8]), hash(&bytes)); + } +}