diff --git a/edgedb-protocol/src/query_arg.rs b/edgedb-protocol/src/query_arg.rs index 04e8c2c9..2cf317c5 100644 --- a/edgedb-protocol/src/query_arg.rs +++ b/edgedb-protocol/src/query_arg.rs @@ -3,24 +3,24 @@ Contains the [QueryArg](crate::query_arg::QueryArg) and [QueryArgs](crate::query */ use std::convert::{TryFrom, TryInto}; +use std::ops::Deref; use std::sync::Arc; -use bytes::{BytesMut, BufMut}; +use bytes::{BufMut, BytesMut}; use snafu::OptionExt; use uuid::Uuid; -use edgedb_errors::{Error, ErrorKind}; -use edgedb_errors::{ClientEncodingError, ProtocolError, DescriptorMismatch}; -use edgedb_errors::{ParameterTypeMismatchError}; +use edgedb_errors::ParameterTypeMismatchError; +use edgedb_errors::{ClientEncodingError, DescriptorMismatch, ProtocolError}; +use edgedb_errors::{Error, ErrorKind, InvalidReferenceError}; -use crate::codec::{self, Codec, build_codec}; -use crate::descriptors::Descriptor; +use crate::codec::{self, build_codec, Codec}; use crate::descriptors::TypePos; +use crate::descriptors::{Descriptor, EnumerationTypeDescriptor}; use crate::errors; use crate::features::ProtocolVersion; -use crate::value::Value; use crate::model::range; - +use crate::value::Value; pub struct Encoder<'a> { pub ctx: &'a DescriptorContext<'a>, @@ -29,18 +29,14 @@ pub struct Encoder<'a> { /// A single argument for a query. pub trait QueryArg: Send + Sync + Sized { - fn encode_slot(&self, encoder: &mut Encoder) - -> Result<(), Error>; - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) - -> Result<(), Error>; + fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error>; + fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>; fn to_value(&self) -> Result; } pub trait ScalarArg: Send + Sync + Sized { - fn encode(&self, encoder: &mut Encoder) - -> Result<(), Error>; - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) - -> Result<(), Error>; + fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>; + fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>; fn to_value(&self) -> Result; } @@ -50,8 +46,7 @@ pub trait ScalarArg: Send + Sync + Sized { /// it for a structure in this case it's treated as a named tuple (i.e. query /// should include named arguments rather than numeric ones). pub trait QueryArgs: Send + Sync { - fn encode(&self, encoder: &mut Encoder) - -> Result<(), Error>; + fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>; } pub struct DescriptorContext<'a> { @@ -62,51 +57,47 @@ pub struct DescriptorContext<'a> { } impl<'a> Encoder<'a> { - pub fn new(ctx: &'a DescriptorContext<'a>, buf: &'a mut BytesMut) - -> Encoder<'a> - { + pub fn new(ctx: &'a DescriptorContext<'a>, buf: &'a mut BytesMut) -> Encoder<'a> { Encoder { ctx, buf } } - pub fn length_prefixed(&mut self, - f: impl FnOnce(&mut Encoder) -> Result<(), Error>) - -> Result<(), Error> - { + pub fn length_prefixed( + &mut self, + f: impl FnOnce(&mut Encoder) -> Result<(), Error>, + ) -> Result<(), Error> { self.buf.reserve(4); let pos = self.buf.len(); - self.buf.put_u32(0); // replaced after serializing a value - // + self.buf.put_u32(0); // replaced after serializing a value + // f(self)?; - let len = self.buf.len()-pos-4; - self.buf[pos..pos+4].copy_from_slice(&u32::try_from(len) - .map_err(|_| ClientEncodingError::with_message( - "alias is too long"))? - .to_be_bytes()); + let len = self.buf.len() - pos - 4; + self.buf[pos..pos + 4].copy_from_slice( + &u32::try_from(len) + .map_err(|_| ClientEncodingError::with_message("alias is too long"))? + .to_be_bytes(), + ); Ok(()) } } impl DescriptorContext<'_> { - pub fn get(&self, type_pos: TypePos) - -> Result<&Descriptor, Error> - { - self.descriptors.get(type_pos.0 as usize) - .ok_or_else(|| ProtocolError::with_message( - "invalid type descriptor")) + pub fn get(&self, type_pos: TypePos) -> Result<&Descriptor, Error> { + self.descriptors + .get(type_pos.0 as usize) + .ok_or_else(|| ProtocolError::with_message("invalid type descriptor")) } pub fn build_codec(&self) -> Result, Error> { build_codec(self.root_pos, self.descriptors) .map_err(|e| ProtocolError::with_source(e) .context("error decoding input codec")) } - pub fn wrong_type(&self, descriptor: &Descriptor, expected: &str) -> Error - { - DescriptorMismatch::with_message(format!("\nEdgeDB returned unexpected type {descriptor:?}\nClient expected {expected}")) + pub fn wrong_type(&self, descriptor: &Descriptor, expected: &str) -> Error { + DescriptorMismatch::with_message(format!( + "\nEdgeDB returned unexpected type {descriptor:?}\nClient expected {expected}" + )) } - pub fn field_number(&self, expected: usize, unexpected: usize) - -> Error - { + pub fn field_number(&self, expected: usize, unexpected: usize) -> Error { DescriptorMismatch::with_message(format!( "expected {} fields, got {}", expected, unexpected)) @@ -114,15 +105,11 @@ impl DescriptorContext<'_> { } impl ScalarArg for &T { - fn encode(&self, encoder: &mut Encoder) - -> Result<(), Error> - { + fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { (*self).encode(encoder) } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) - -> Result<(), Error> - { + fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { T::check_descriptor(ctx, pos) } @@ -132,23 +119,23 @@ impl ScalarArg for &T { } impl QueryArgs for () { - fn encode(&self, enc: &mut Encoder) - -> Result<(), Error> - { + fn encode(&self, enc: &mut Encoder) -> Result<(), Error> { if enc.ctx.root_pos.is_some() { if enc.ctx.proto.is_at_most(0, 11) { let root = enc.ctx.root_pos.and_then(|p| enc.ctx.get(p).ok()); match root { Some(Descriptor::Tuple(t)) - if t.id == Uuid::from_u128(0xFF) - && t.element_types.is_empty() - => {} - _ => return Err(ParameterTypeMismatchError::with_message( - "query arguments expected")), + if t.id == Uuid::from_u128(0xFF) && t.element_types.is_empty() => {} + _ => { + return Err(ParameterTypeMismatchError::with_message( + "query arguments expected", + )) + } }; } else { return Err(ParameterTypeMismatchError::with_message( - "query arguments expected")); + "query arguments expected", + )); } } if enc.ctx.proto.is_at_most(0, 11) { @@ -160,9 +147,7 @@ impl QueryArgs for () { } impl QueryArg for Value { - fn encode_slot(&self, enc: &mut Encoder) - -> Result<(), Error> - { + fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> { use Value::*; match self { Nothing => { @@ -207,9 +192,7 @@ impl QueryArg for Value { Ok(()) } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) - -> Result<(), Error> - { + fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { use Descriptor::*; use Value::*; let mut desc = ctx.get(pos)?; @@ -217,7 +200,7 @@ impl QueryArg for Value { desc = ctx.get(d.base_type_pos)?; } match (self, desc) { - (Nothing, _) => Ok(()), // any descriptor works + (Nothing, _) => Ok(()), // any descriptor works (_, Scalar(_)) => unreachable!("scalar dereference to a non-base type"), (BigInt(_), BaseScalar(d)) if d.id == codec::STD_BIGINT => Ok(()), (Bool(_), BaseScalar(d)) if d.id == codec::STD_BOOL => Ok(()), @@ -237,24 +220,40 @@ impl QueryArg for Value { (LocalDatetime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATETIME => Ok(()), (LocalTime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_TIME => Ok(()), (RelativeDuration(_), BaseScalar(d)) if d.id == codec::CAL_RELATIVE_DURATION => Ok(()), - (Str(_), BaseScalar(d)) if d.id == codec::STD_STR => Ok(()), + (Str(_), BaseScalar(d)) if d.id == codec::STD_STR => Ok(()), (Uuid(_), BaseScalar(d)) if d.id == codec::STD_UUID => Ok(()), + (Enum(val), Enumeration(EnumerationTypeDescriptor { members, .. })) => { + let val = val.deref(); + if members.iter().any(|c| c == val) { + Ok(()) + } else { + let members = { + let mut members = members + .into_iter() + .map(|c| format!("'{c}'")) + .collect::>(); + members.sort_unstable(); + members.join(", ") + }; + Err(InvalidReferenceError::with_message(format!( + "Expected one of: {members}, while enum value '{val}' was provided" + ))) + } + } // TODO(tailhook) all types (_, desc) => Err(ctx.wrong_type(desc, self.kind())), } } - fn to_value(&self) -> Result - { + fn to_value(&self) -> Result { Ok(self.clone()) } } impl QueryArgs for Value { - fn encode(&self, enc: &mut Encoder) - -> Result<(), Error> - { + fn encode(&self, enc: &mut Encoder) -> Result<(), Error> { let codec = enc.ctx.build_codec()?; - codec.encode(&mut enc.buf, self) + codec + .encode(&mut enc.buf, self) .map_err(ClientEncodingError::with_source) } } @@ -265,16 +264,17 @@ impl QueryArg for T { let pos = enc.buf.len(); enc.buf.put_u32(0); // will fill after encoding ScalarArg::encode(self, enc)?; - let len = enc.buf.len()-pos-4; - enc.buf[pos..pos+4].copy_from_slice(&i32::try_from(len) - .ok().context(errors::ElementTooLong) + let len = enc.buf.len() - pos - 4; + enc.buf[pos..pos + 4].copy_from_slice( + &i32::try_from(len) + .ok() + .context(errors::ElementTooLong) .map_err(ClientEncodingError::with_source)? - .to_be_bytes()); + .to_be_bytes(), + ); Ok(()) } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) - -> Result<(), Error> - { + fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { T::check_descriptor(ctx, pos) } fn to_value(&self) -> Result { @@ -292,9 +292,7 @@ impl QueryArg for Option { Ok(()) } } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) - -> Result<(), Error> - { + fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { T::check_descriptor(ctx, pos) } fn to_value(&self) -> Result { @@ -311,28 +309,28 @@ impl QueryArg for Vec { enc.length_prefixed(|enc| { if self.is_empty() { enc.buf.reserve(12); - enc.buf.put_u32(0); // ndims - enc.buf.put_u32(0); // reserved0 - enc.buf.put_u32(0); // reserved1 + enc.buf.put_u32(0); // ndims + enc.buf.put_u32(0); // reserved0 + enc.buf.put_u32(0); // reserved1 return Ok(()); } enc.buf.reserve(20); - enc.buf.put_u32(1); // ndims - enc.buf.put_u32(0); // reserved0 - enc.buf.put_u32(0); // reserved1 - enc.buf.put_u32(self.len().try_into() - .map_err(|_| ClientEncodingError::with_message( - "array is too long"))?); - enc.buf.put_u32(1); // lower + enc.buf.put_u32(1); // ndims + enc.buf.put_u32(0); // reserved0 + enc.buf.put_u32(0); // reserved1 + enc.buf.put_u32( + self.len() + .try_into() + .map_err(|_| ClientEncodingError::with_message("array is too long"))?, + ); + enc.buf.put_u32(1); // lower for item in self { enc.length_prefixed(|enc| item.encode(enc))?; } Ok(()) }) } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) - -> Result<(), Error> - { + fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { let desc = ctx.get(pos)?; if let Descriptor::Array(arr) = desc { T::check_descriptor(ctx, arr.type_pos) @@ -341,8 +339,11 @@ impl QueryArg for Vec { } } fn to_value(&self) -> Result { - Ok(Value::Array(self.iter().map(|v| v.to_value()) - .collect::>()?)) + Ok(Value::Array( + self.iter() + .map(|v| v.to_value()) + .collect::>()?, + )) } } @@ -352,28 +353,28 @@ impl QueryArg for Vec { enc.length_prefixed(|enc| { if self.is_empty() { enc.buf.reserve(12); - enc.buf.put_u32(0); // ndims - enc.buf.put_u32(0); // reserved0 - enc.buf.put_u32(0); // reserved1 + enc.buf.put_u32(0); // ndims + enc.buf.put_u32(0); // reserved0 + enc.buf.put_u32(0); // reserved1 return Ok(()); } enc.buf.reserve(20); - enc.buf.put_u32(1); // ndims - enc.buf.put_u32(0); // reserved0 - enc.buf.put_u32(0); // reserved1 - enc.buf.put_u32(self.len().try_into() - .map_err(|_| ClientEncodingError::with_message( - "array is too long"))?); - enc.buf.put_u32(1); // lower + enc.buf.put_u32(1); // ndims + enc.buf.put_u32(0); // reserved0 + enc.buf.put_u32(0); // reserved1 + enc.buf.put_u32( + self.len() + .try_into() + .map_err(|_| ClientEncodingError::with_message("array is too long"))?, + ); + enc.buf.put_u32(1); // lower for item in self { enc.length_prefixed(|enc| item.encode(enc))?; } Ok(()) }) } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) - -> Result<(), Error> - { + fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { let desc = ctx.get(pos)?; if let Descriptor::Array(arr) = desc { for val in self { @@ -385,49 +386,55 @@ impl QueryArg for Vec { } } fn to_value(&self) -> Result { - Ok(Value::Array(self.iter().map(|v| v.to_value()) - .collect::>()?)) + Ok(Value::Array( + self.iter() + .map(|v| v.to_value()) + .collect::>()?, + )) } } impl QueryArg for range::Range> { - fn encode_slot(&self, encoder: &mut Encoder) - -> Result<(), Error> - { + fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error> { encoder.length_prefixed(|encoder| { - let flags = - if self.empty { range::EMPTY } else { - (if self.inc_lower { range::LB_INC } else { 0 }) | - (if self.inc_upper { range::UB_INC } else { 0 }) | - (if self.lower.is_none() { range::LB_INF } else { 0 }) | - (if self.upper.is_none() { range::UB_INF } else { 0 }) - }; + let flags = if self.empty { + range::EMPTY + } else { + (if self.inc_lower { range::LB_INC } else { 0 }) + | (if self.inc_upper { range::UB_INC } else { 0 }) + | (if self.lower.is_none() { + range::LB_INF + } else { + 0 + }) + | (if self.upper.is_none() { + range::UB_INF + } else { + 0 + }) + }; encoder.buf.reserve(1); encoder.buf.put_u8(flags as u8); if let Some(lower) = &self.lower { - encoder.length_prefixed(|encoder| { - lower.encode(encoder) - })? + encoder.length_prefixed(|encoder| lower.encode(encoder))? } if let Some(upper) = &self.upper { - encoder.length_prefixed(|encoder| { - upper.encode(encoder) - })?; + encoder.length_prefixed(|encoder| upper.encode(encoder))?; } Ok(()) }) } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) - -> Result<(), Error> - { + fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { let desc = ctx.get(pos)?; if let Descriptor::Range(rng) = desc { - self.lower.as_ref() + self.lower + .as_ref() .map(|v| v.check_descriptor(ctx, rng.type_pos)) .transpose()?; - self.upper.as_ref() + self.upper + .as_ref() .map(|v| v.check_descriptor(ctx, rng.type_pos)) .transpose()?; Ok(()) @@ -440,7 +447,6 @@ impl QueryArg for range::Range> { } } - macro_rules! implement_tuple { ( $count:expr, $($name:ident,)+ ) => { impl<$($name:QueryArg),+> QueryArgs for ($($name,)+) { @@ -508,15 +514,15 @@ macro_rules! implement_tuple { } } -implement_tuple!{1, T0, } -implement_tuple!{2, T0, T1, } -implement_tuple!{3, T0, T1, T2, } -implement_tuple!{4, T0, T1, T2, T3, } -implement_tuple!{5, T0, T1, T2, T3, T4, } -implement_tuple!{6, T0, T1, T2, T3, T4, T5, } -implement_tuple!{7, T0, T1, T2, T3, T4, T5, T6, } -implement_tuple!{8, T0, T1, T2, T3, T4, T5, T6, T7, } -implement_tuple!{9, T0, T1, T2, T3, T4, T5, T6, T7, T8, } -implement_tuple!{10, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, } -implement_tuple!{11, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, } -implement_tuple!{12, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, } +implement_tuple! {1, T0, } +implement_tuple! {2, T0, T1, } +implement_tuple! {3, T0, T1, T2, } +implement_tuple! {4, T0, T1, T2, T3, } +implement_tuple! {5, T0, T1, T2, T3, T4, } +implement_tuple! {6, T0, T1, T2, T3, T4, T5, } +implement_tuple! {7, T0, T1, T2, T3, T4, T5, T6, } +implement_tuple! {8, T0, T1, T2, T3, T4, T5, T6, T7, } +implement_tuple! {9, T0, T1, T2, T3, T4, T5, T6, T7, T8, } +implement_tuple! {10, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, } +implement_tuple! {11, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, } +implement_tuple! {12, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, }