diff --git a/edgedb-protocol/src/codec.rs b/edgedb-protocol/src/codec.rs index 685e417a..8e3bdad0 100644 --- a/edgedb-protocol/src/codec.rs +++ b/edgedb-protocol/src/codec.rs @@ -785,24 +785,28 @@ impl Codec for ArrayAdapter { impl<'a> From<&'a [descriptors::ShapeElement]> for ObjectShape { fn from(shape: &'a [descriptors::ShapeElement]) -> ObjectShape { ObjectShape(Arc::new(ObjectShapeInfo { - elements: shape.iter().map(|e| { - let descriptors::ShapeElement { - flag_implicit, - flag_link_property, - flag_link, - cardinality, - name, - type_pos: _, - } = e; - ShapeElement { - flag_implicit: *flag_implicit, - flag_link_property: *flag_link_property, - flag_link: *flag_link, - cardinality: *cardinality, - name: name.clone(), - } - }).collect(), - })) + elements: shape.iter().map(ShapeElement::from).collect(), + })) + } +} + +impl<'a> From<&'a descriptors::ShapeElement> for ShapeElement { + fn from(e: &'a descriptors::ShapeElement) -> ShapeElement { + let descriptors::ShapeElement { + flag_implicit, + flag_link_property, + flag_link, + cardinality, + name, + type_pos: _, + } = e; + ShapeElement { + flag_implicit: *flag_implicit, + flag_link_property: *flag_link_property, + flag_link: *flag_link, + cardinality: *cardinality, + name: name.clone(), + } } } diff --git a/edgedb-protocol/src/lib.rs b/edgedb-protocol/src/lib.rs index e3e0c0f6..a50a05b6 100644 --- a/edgedb-protocol/src/lib.rs +++ b/edgedb-protocol/src/lib.rs @@ -59,20 +59,21 @@ pub enum Value { mod query_result; // sealed trait should remain non-public -pub mod encoding; +pub mod client_message; +pub mod codec; pub mod common; +pub mod descriptors; +pub mod encoding; +pub mod error_response; +pub mod errors; pub mod features; +pub mod queryable; pub mod serialization; -pub mod client_message; pub mod server_message; -pub mod errors; -pub mod error_response; -pub mod descriptors; pub mod value; -pub mod codec; -pub mod queryable; +#[macro_use] +pub mod value_opt; pub mod query_arg; pub mod model; - pub use query_result::QueryResult; diff --git a/edgedb-protocol/src/query_result.rs b/edgedb-protocol/src/query_result.rs index 91be7964..6fe8cf91 100644 --- a/edgedb-protocol/src/query_result.rs +++ b/edgedb-protocol/src/query_result.rs @@ -11,7 +11,7 @@ use edgedb_errors::{ProtocolEncodingError, DescriptorMismatch}; use crate::codec::Codec; use crate::queryable::{Queryable, Decoder, DescriptorContext}; -use crate::descriptors::{TypePos}; +use crate::descriptors::TypePos; use crate::value::Value; pub trait Sealed: Sized {} diff --git a/edgedb-protocol/src/value_opt.rs b/edgedb-protocol/src/value_opt.rs new file mode 100644 index 00000000..1df03cf9 --- /dev/null +++ b/edgedb-protocol/src/value_opt.rs @@ -0,0 +1,125 @@ +use std::collections::HashMap; + +use edgedb_errors::{ClientEncodingError, Error, ErrorKind}; + +use crate::codec::{ObjectShape, ShapeElement}; +use crate::descriptors::Descriptor; +use crate::query_arg::{Encoder, QueryArgs}; +use crate::value::Value; + +/// An optional [Value] that can be constructed from `impl Into`, +/// `Option>`, `Vec>` or +/// `Option>>`. +/// Used by [eargs!] macro. +pub struct ValueOpt(Option); + +impl> From for ValueOpt { + fn from(value: V) -> Self { + ValueOpt(Some(value.into())) + } +} +impl> From> for ValueOpt +where + Value: From, +{ + fn from(value: Option) -> Self { + ValueOpt(value.map(Value::from)) + } +} +impl> From> for ValueOpt +where + Value: From, +{ + fn from(value: Vec) -> Self { + ValueOpt(Some(Value::Array( + value.into_iter().map(Value::from).collect(), + ))) + } +} +impl> From>> for ValueOpt +where + Value: From, +{ + fn from(value: Option>) -> Self { + let mapped = value.map(|value| Value::Array(value.into_iter().map(Value::from).collect())); + ValueOpt(mapped) + } +} +impl From for Option { + fn from(value: ValueOpt) -> Self { + value.0 + } +} + +impl QueryArgs for HashMap<&str, ValueOpt> { + fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { + if self.len() == 0 && encoder.ctx.root_pos.is_none() { + return Ok(()); + } + + let root_pos = encoder.ctx.root_pos.ok_or_else(|| { + ClientEncodingError::with_message(format!( + "provided {} named arguments, but no arguments were expected by the server", + self.len() + )) + })?; + + let Descriptor::ObjectShape(target_shape) = encoder.ctx.get(root_pos)? else { + return Err(ClientEncodingError::with_message( + "query didn't expect named arguments", + )); + }; + + let mut shape_elements: Vec = Vec::new(); + let mut fields: Vec> = Vec::new(); + + for param_descriptor in target_shape.elements.iter() { + let value = self.get(param_descriptor.name.as_str()); + + let Some(value) = value else { + return Err(ClientEncodingError::with_message(format!( + "argument for ${} missing", + param_descriptor.name + ))); + }; + + shape_elements.push(ShapeElement::from(param_descriptor)); + fields.push(value.0.clone()); + } + + Value::Object { + shape: ObjectShape::new(shape_elements), + fields, + } + .encode(encoder) + } +} + +/// Constructs named query arguments that implement [QueryArgs] so they can be passed +/// into any query method. +/// ```no_run +/// use edgedb_protocol::value::Value; +/// +/// let query = "SELECT ($my_str, $my_int)"; +/// let args = edgedb_protocol::named_args! { +/// "my_str" => "Hello world!".to_string(), +/// "my_int" => Value::Int64(42), +/// }; +/// ``` +/// +/// The value side of an argument must be `impl Into`. +/// The type of the returned object is `HashMap<&str, ValueOpt>`. +#[macro_export] +macro_rules! named_args { + ($($key:expr => $value:expr,)+) => { $crate::named_args!($($key => $value),+) }; + ($($key:expr => $value:expr),*) => { + { + const CAP: usize = <[()]>::len(&[$({ stringify!($key); }),*]); + let mut map = ::std::collections::HashMap::<&str, $crate::value_opt::ValueOpt>::with_capacity(CAP); + $( + map.insert($key, $crate::value_opt::ValueOpt::from($value)); + )* + map + } + }; +} diff --git a/edgedb-tokio/tests/func/client.rs b/edgedb-tokio/tests/func/client.rs index e8816929..532f40ba 100644 --- a/edgedb-tokio/tests/func/client.rs +++ b/edgedb-tokio/tests/func/client.rs @@ -1,3 +1,4 @@ +use edgedb_protocol::named_args; use edgedb_protocol::value::{EnumValue, Value}; use edgedb_tokio::Client; use edgedb_errors::NoDataError; @@ -70,6 +71,22 @@ async fn simple() -> anyhow::Result<()> { true ); + // named args + let value = client.query_required_single::( + "select ( + std::array_join(>$msg1, ' ') + ++ ($question ?? ' the ultimate question of life') + ++ ': ' + ++ $answer + );", + &named_args! { + "msg1" => vec!["the".to_string(), "answer".to_string(), "to".to_string()], + "question" => None::, + "answer" => 42 as i64, + } + ).await.unwrap(); + assert_eq!(value.as_str(), "the answer to the ultimate question of life: 42"); + Ok(()) }