Skip to content

Commit

Permalink
Accept Value::Str for Enum encoding (#308)
Browse files Browse the repository at this point in the history
Co-authored-by: Aljaž Mur Eržen <[email protected]>
  • Loading branch information
MrFoxPro and aljazerzen authored Apr 9, 2024
1 parent 787a469 commit ca67c54
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 18 deletions.
7 changes: 4 additions & 3 deletions edgedb-protocol/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1364,11 +1364,12 @@ impl Codec for Enum {
-> Result<(), EncodeError>
{
let val = match val {
Value::Enum(val) => val,
Value::Enum(val) => val.0.as_ref(),
Value::Str(val) => val.as_str(),
_ => Err(errors::invalid_value(type_name::<Self>(), val))?,
};
ensure!(self.members.get(&val.0).is_some(), errors::MissingEnumValue);
buf.extend(val.0.as_bytes());
ensure!(self.members.get(val).is_some(), errors::MissingEnumValue);
buf.extend(val.as_bytes());
Ok(())
}
}
32 changes: 17 additions & 15 deletions edgedb-protocol/src/query_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,7 @@ impl QueryArg for Value {
(Uuid(_), BaseScalar(d)) if d.id == codec::STD_UUID => Ok(()),
(Enum(val), Enumeration(EnumerationTypeDescriptor { members, .. })) => {
let val = val.deref();
if members.iter().any(|c| c == val) {
Ok(())
} else {
let members = {
let mut members = members
.iter()
.map(|c| format!("'{c}'"))
.collect::<Vec<_>>();
members.sort_unstable();
members.join(", ")
};
Err(InvalidReferenceError::with_message(format!(
"Expected one of: {members}, while enum value '{val}' was provided"
)))
}
check_enum(val, members)
}
// TODO(tailhook) all types
(_, desc) => Err(ctx.wrong_type(desc, self.kind())),
Expand All @@ -249,6 +235,22 @@ impl QueryArg for Value {
}
}

pub(crate) fn check_enum(variant_name: &str, expected_members: &[String]) -> Result<(), Error> {
if expected_members.iter().any(|c| c == variant_name) {
Ok(())
} else {
let mut members = expected_members
.into_iter()
.map(|c| format!("'{c}'"))
.collect::<Vec<_>>();
members.sort_unstable();
let members = members.join(", ");
Err(InvalidReferenceError::with_message(format!(
"Expected one of: {members}, while enum value '{variant_name}' was provided"
)))
}
}

impl QueryArgs for Value {
fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
let codec = enc.ctx.build_codec()?;
Expand Down
5 changes: 5 additions & 0 deletions edgedb-protocol/src/serialization/decode/raw_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ impl ScalarArg for &'_ str {
fn check_descriptor(ctx: &DescriptorContext, pos: TypePos)
-> Result<(), Error>
{
// special case: &str can express an enum variant
if let Descriptor::Enumeration(_) = ctx.get(pos)? {
return Ok(())
}

check_scalar(ctx, pos, String::uuid(), String::typename())
}
fn to_value(&self) -> Result<Value, Error> {
Expand Down
28 changes: 28 additions & 0 deletions edgedb-tokio/tests/func/client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use edgedb_protocol::value::{EnumValue, Value};
use edgedb_tokio::Client;
use edgedb_errors::NoDataError;
use futures_util::stream::{self, StreamExt};
Expand Down Expand Up @@ -42,6 +43,33 @@ async fn simple() -> anyhow::Result<()> {
client.execute("SELECT 1+1", &()).await?;
client.execute("START MIGRATION TO {}; ABORT MIGRATION", &()).await?;

// basic enum param
let enum_query = "SELECT <str>(<test::State>$0) = 'waiting'";
assert_eq!(
client.query_required_single::<bool, _>(
enum_query, &(Value::Enum(EnumValue::from("waiting")),)
).await.unwrap(),
true
);

// unsupported: enum param as Value::Str
client.query_required_single::<bool, (Value, )>(
enum_query, &(Value::Str("waiting".to_string()), ),
).await.unwrap_err();

// unsupported: enum param as String
client.query_required_single::<bool, (String, )>(
enum_query, &("waiting".to_string(), ),
).await.unwrap_err();

// enum param as &str
assert_eq!(
client.query_required_single::<bool, (&'_ str, )>(
enum_query, &("waiting", ),
).await.unwrap(),
true
);

Ok(())
}

Expand Down
2 changes: 2 additions & 0 deletions edgedb-tokio/tests/func/dbschema/test.esdl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
module test {
scalar type State extending enum<'done', 'waiting', 'blocked'>;

type Counter {
required property name -> str {
constraint std::exclusive;
Expand Down

0 comments on commit ca67c54

Please sign in to comment.