Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support named arguments #304

Merged
merged 16 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions edgedb-protocol/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -785,24 +785,28 @@ impl Codec for ArrayAdapter {
impl<'a> From<&'a [descriptors::ShapeElement]> for ObjectShape {
fn from(shape: &'a [descriptors::ShapeElement]) -> ObjectShape {
ObjectShape(Arc::new(ObjectShapeInfo {
elements: shape.iter().map(|e| {
let descriptors::ShapeElement {
flag_implicit,
flag_link_property,
flag_link,
cardinality,
name,
type_pos: _,
} = e;
ShapeElement {
flag_implicit: *flag_implicit,
flag_link_property: *flag_link_property,
flag_link: *flag_link,
cardinality: *cardinality,
name: name.clone(),
}
}).collect(),
}))
elements: shape.iter().map(ShapeElement::from).collect(),
}))
}
}

impl<'a> From<&'a descriptors::ShapeElement> for ShapeElement {
fn from(e: &'a descriptors::ShapeElement) -> ShapeElement {
let descriptors::ShapeElement {
flag_implicit,
flag_link_property,
flag_link,
cardinality,
name,
type_pos: _,
} = e;
ShapeElement {
flag_implicit: *flag_implicit,
flag_link_property: *flag_link_property,
flag_link: *flag_link,
cardinality: *cardinality,
name: name.clone(),
}
}
}

Expand Down
17 changes: 9 additions & 8 deletions edgedb-protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,21 @@ pub enum Value {

mod query_result; // sealed trait should remain non-public

pub mod encoding;
pub mod client_message;
pub mod codec;
pub mod common;
pub mod descriptors;
pub mod encoding;
pub mod error_response;
pub mod errors;
pub mod features;
pub mod queryable;
pub mod serialization;
pub mod client_message;
pub mod server_message;
pub mod errors;
pub mod error_response;
pub mod descriptors;
pub mod value;
pub mod codec;
pub mod queryable;
#[macro_use]
pub mod value_opt;
pub mod query_arg;
pub mod model;


pub use query_result::QueryResult;
2 changes: 1 addition & 1 deletion edgedb-protocol/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use edgedb_errors::{ProtocolEncodingError, DescriptorMismatch};

use crate::codec::Codec;
use crate::queryable::{Queryable, Decoder, DescriptorContext};
use crate::descriptors::{TypePos};
use crate::descriptors::TypePos;
use crate::value::Value;

pub trait Sealed: Sized {}
Expand Down
125 changes: 125 additions & 0 deletions edgedb-protocol/src/value_opt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use std::collections::HashMap;

use edgedb_errors::{ClientEncodingError, Error, ErrorKind};

use crate::codec::{ObjectShape, ShapeElement};
use crate::descriptors::Descriptor;
use crate::query_arg::{Encoder, QueryArgs};
use crate::value::Value;

/// An optional [Value] that can be constructed from `impl Into<Value>`,
/// `Option<impl Into<Value>>`, `Vec<impl Into<Value>>` or
/// `Option<Vec<impl Into<Value>>>`.
/// Used by [eargs!] macro.
pub struct ValueOpt(Option<Value>);

impl<V: Into<Value>> From<V> for ValueOpt {
fn from(value: V) -> Self {
ValueOpt(Some(value.into()))
}
}
impl<V: Into<Value>> From<Option<V>> for ValueOpt
where
Value: From<V>,
{
fn from(value: Option<V>) -> Self {
ValueOpt(value.map(Value::from))
}
}
impl<V: Into<Value>> From<Vec<V>> for ValueOpt
where
Value: From<V>,
{
fn from(value: Vec<V>) -> Self {
ValueOpt(Some(Value::Array(
value.into_iter().map(Value::from).collect(),
)))
}
}
impl<V: Into<Value>> From<Option<Vec<V>>> for ValueOpt
where
Value: From<V>,
{
fn from(value: Option<Vec<V>>) -> Self {
let mapped = value.map(|value| Value::Array(value.into_iter().map(Value::from).collect()));
ValueOpt(mapped)
}
}
impl From<ValueOpt> for Option<Value> {
fn from(value: ValueOpt) -> Self {
value.0
}
}

impl QueryArgs for HashMap<&str, ValueOpt> {
fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> {
if self.len() == 0 && encoder.ctx.root_pos.is_none() {
return Ok(());
}

let root_pos = encoder.ctx.root_pos.ok_or_else(|| {
ClientEncodingError::with_message(format!(
"provided {} named arguments, but no arguments were expected by the server",
self.len()
))
})?;

let Descriptor::ObjectShape(target_shape) = encoder.ctx.get(root_pos)? else {
return Err(ClientEncodingError::with_message(
"query didn't expect named arguments",
));
};

let mut shape_elements: Vec<ShapeElement> = Vec::new();
let mut fields: Vec<Option<Value>> = Vec::new();

for param_descriptor in target_shape.elements.iter() {
let value = self.get(param_descriptor.name.as_str());

let Some(value) = value else {
return Err(ClientEncodingError::with_message(format!(
"argument for ${} missing",
param_descriptor.name
)));
};

shape_elements.push(ShapeElement::from(param_descriptor));
fields.push(value.0.clone());
}

Value::Object {
shape: ObjectShape::new(shape_elements),
fields,
}
.encode(encoder)
}
}

/// Constructs named query arguments that implement [QueryArgs] so they can be passed
/// into any query method.
/// ```no_run
/// use edgedb_protocol::value::Value;
///
/// let query = "SELECT (<str>$my_str, <int64>$my_int)";
/// let args = edgedb_protocol::named_args! {
/// "my_str" => "Hello world!".to_string(),
/// "my_int" => Value::Int64(42),
/// };
/// ```
///
/// The value side of an argument must be `impl Into<ValueOpt>`.
/// The type of the returned object is `HashMap<&str, ValueOpt>`.
#[macro_export]
macro_rules! named_args {
($($key:expr => $value:expr,)+) => { $crate::named_args!($($key => $value),+) };
($($key:expr => $value:expr),*) => {
{
const CAP: usize = <[()]>::len(&[$({ stringify!($key); }),*]);
let mut map = ::std::collections::HashMap::<&str, $crate::value_opt::ValueOpt>::with_capacity(CAP);
$(
map.insert($key, $crate::value_opt::ValueOpt::from($value));
)*
map
}
};
}
17 changes: 17 additions & 0 deletions edgedb-tokio/tests/func/client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use edgedb_protocol::named_args;
use edgedb_protocol::value::{EnumValue, Value};
use edgedb_tokio::Client;
use edgedb_errors::NoDataError;
Expand Down Expand Up @@ -70,6 +71,22 @@ async fn simple() -> anyhow::Result<()> {
true
);

// named args
let value = client.query_required_single::<String, _>(
"select (
std::array_join(<array<str>>$msg1, ' ')
++ (<optional str>$question ?? ' the ultimate question of life')
++ ': '
++ <str><int64>$answer
);",
&named_args! {
"msg1" => vec!["the".to_string(), "answer".to_string(), "to".to_string()],
"question" => None::<String>,
"answer" => 42 as i64,
}
).await.unwrap();
assert_eq!(value.as_str(), "the answer to the ultimate question of life: 42");

Ok(())
}

Expand Down