Skip to content

Commit

Permalink
Merge branch 'origin/master' into feat/named-args
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Apr 9, 2024
2 parents d4d70d7 + ca67c54 commit 671f98b
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 80 deletions.
3 changes: 0 additions & 3 deletions .vscode/settings.json

This file was deleted.

7 changes: 4 additions & 3 deletions edgedb-protocol/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1364,11 +1364,12 @@ impl Codec for Enum {
-> Result<(), EncodeError>
{
let val = match val {
Value::Enum(val) => val,
Value::Enum(val) => val.0.as_ref(),
Value::Str(val) => val.as_str(),
_ => Err(errors::invalid_value(type_name::<Self>(), val))?,
};
ensure!(self.members.get(&val.0).is_some(), errors::MissingEnumValue);
buf.extend(val.0.as_bytes());
ensure!(self.members.get(val).is_some(), errors::MissingEnumValue);
buf.extend(val.as_bytes());
Ok(())
}
}
32 changes: 17 additions & 15 deletions edgedb-protocol/src/query_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,7 @@ impl QueryArg for Value {
(Uuid(_), BaseScalar(d)) if d.id == codec::STD_UUID => Ok(()),
(Enum(val), Enumeration(EnumerationTypeDescriptor { members, .. })) => {
let val = val.deref();
if members.iter().any(|c| c == val) {
Ok(())
} else {
let members = {
let mut members = members
.iter()
.map(|c| format!("'{c}'"))
.collect::<Vec<_>>();
members.sort_unstable();
members.join(", ")
};
Err(InvalidReferenceError::with_message(format!(
"Expected one of: {members}, while enum value '{val}' was provided"
)))
}
check_enum(val, members)
}
// TODO(tailhook) all types
(_, desc) => Err(ctx.wrong_type(desc, self.kind())),
Expand All @@ -249,6 +235,22 @@ impl QueryArg for Value {
}
}

pub(crate) fn check_enum(variant_name: &str, expected_members: &[String]) -> Result<(), Error> {
if expected_members.iter().any(|c| c == variant_name) {
Ok(())
} else {
let mut members = expected_members
.into_iter()
.map(|c| format!("'{c}'"))
.collect::<Vec<_>>();
members.sort_unstable();
let members = members.join(", ");
Err(InvalidReferenceError::with_message(format!(
"Expected one of: {members}, while enum value '{variant_name}' was provided"
)))
}
}

impl QueryArgs for Value {
fn encode(&self, enc: &mut Encoder) -> Result<(), Error> {
let codec = enc.ctx.build_codec()?;
Expand Down
5 changes: 5 additions & 0 deletions edgedb-protocol/src/serialization/decode/raw_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ impl ScalarArg for &'_ str {
fn check_descriptor(ctx: &DescriptorContext, pos: TypePos)
-> Result<(), Error>
{
// special case: &str can express an enum variant
if let Descriptor::Enumeration(_) = ctx.get(pos)? {
return Ok(())
}

check_scalar(ctx, pos, String::uuid(), String::typename())
}
fn to_value(&self) -> Result<Value, Error> {
Expand Down
121 changes: 63 additions & 58 deletions edgedb-tokio/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl Config {
}
}

#[derive(Clone)]
#[derive(Debug, Clone)]
pub(crate) struct ConfigInner {
pub address: Address,
pub admin: bool,
Expand Down Expand Up @@ -556,6 +556,7 @@ impl<'a> DsnHelper<'a> {
});
self.retrieve_value("branch", v, |s| {
let s = s.strip_prefix('/').unwrap_or(&s);
dbg!("here");
validate_branch(&s)?;
Ok(s.to_owned())
}).await
Expand Down Expand Up @@ -702,6 +703,7 @@ impl Builder {

/// Set the branch name.
pub fn branch(&mut self, branch: &str) -> Result<&mut Self, Error> {
dbg!("here");
validate_branch(branch)?;
self.branch = Some(branch.into());
Ok(self)
Expand Down Expand Up @@ -965,10 +967,10 @@ impl Builder {
let full_path = resolve_unix(unix_path, port, self.admin);
cfg.address = Address::Unix(full_path);
}
if let Some(database) = &self.database {
if let Some(branch) = &self.branch {
if let Some((d, b)) = &self.database.as_ref().zip(self.branch.as_ref()) {
if d != b {
errors.push(InvalidArgumentError::with_message(format!(
"database {} conflicts with branch {}", database, branch
"database {d} conflicts with branch {b}"
)))
}
}
Expand Down Expand Up @@ -1123,24 +1125,26 @@ impl Builder {
async fn granular_env(&self, cfg: &mut ConfigInner,
errors: &mut Vec<Error>)
{
let database = self.database.clone().or_else(|| {
get_env("EDGEDB_DATABASE")
.and_then(|v| v.map(validate_database).transpose())
.map_err(|e| errors.push(e)).ok().flatten()
});
// TODO(tailhook) check if not empty
if let Some(database) = database {
cfg.database = database;
}

let branch = self.branch.clone().or_else(|| {
get_env("EDGEDB_BRANCH")
.and_then(|v| v.map(validate_branch).transpose())
.map_err(|e| errors.push(e)).ok().flatten()
});
let database_branch = self.database.as_ref().or(self.branch.as_ref())
.cloned()
.or_else(|| {
let database = get_env("EDGEDB_DATABASE")
.map_err(|e| errors.push(e)).ok()?;
let branch = get_env("EDGEDB_BRANCH")
.map_err(|e| errors.push(e)).ok()?;

if database.is_some() && branch.is_some() {
errors.push(InvalidArgumentError::with_message(
"Invalid environment: variables `EDGEDB_DATABASE` and `EDGEDB_BRANCH` are mutually exclusive",
));
return None;
}

if let Some(branch) = branch {
cfg.branch = branch;
database.or(branch)
});
if let Some(name) = database_branch {
cfg.database = name.clone();
cfg.branch = name;
}

let user = self.user.clone().or_else(|| {
Expand Down Expand Up @@ -1253,43 +1257,31 @@ impl Builder {
dsn.ignore_value("password");
}

let has_branch_option = dsn.query.contains_key("branch") || dsn.query.contains_key("branch_env") || dsn.query.contains_key("branch_file");
let has_database_option = dsn.query.contains_key("database") || dsn.query.contains_key("database_env") || dsn.query.contains_key("database_file");

if has_branch_option {
if has_database_option {
errors.push(InvalidArgumentError::with_message(
"Invalid DSN: `database` and `branch` cannot be present at the same time"
));
} else if self.database.is_some() {
errors.push(InvalidArgumentError::with_message(
"`branch` in DSN and `database` are mutually exclusive"
));
} else {
match dsn.retrieve_branch().await {
Ok(Some(value)) => cfg.branch = value,
Ok(None) => {},
Err(e) => errors.push(e)
}
}
} else if self.branch.is_some() {
if has_database_option {
errors.push(InvalidArgumentError::with_message(
"`database` in DSN and `branch` are mutually exclusive"
));
let has_query_branch = dsn.query.contains_key("branch") || dsn.query.contains_key("branch_env") || dsn.query.contains_key("branch_file");
let has_query_database = dsn.query.contains_key("database") || dsn.query.contains_key("database_env") || dsn.query.contains_key("database_file");
if has_query_branch && has_query_database {
errors.push(InvalidArgumentError::with_message(
"Invalid DSN: `database` and `branch` are mutually exclusive",
));
}
if self.branch.is_none() && self.database.is_none() {
let database_or_branch = if has_query_database {
dsn.retrieve_database().await
} else {
match dsn.retrieve_branch().await {
Ok(Some(value)) => cfg.branch = value,
Ok(None) => {},
Err(e) => errors.push(e)
}
dsn.retrieve_branch().await
};

match database_or_branch {
Ok(Some(name)) => {
cfg.branch = name.clone();
cfg.database = name;
},
Ok(None) => {}
Err(e) => errors.push(e),
}
} else {
match dsn.retrieve_database().await {
Ok(Some(value)) => cfg.database = value,
Ok(None) => {},
Err(e) => errors.push(e)
}
dsn.ignore_value("branch");
dsn.ignore_value("database");
}

match dsn.retrieve_secret_key().await {
Expand Down Expand Up @@ -1394,6 +1386,7 @@ impl Builder {
project_dir, path)
})?
.to_owned();
cfg.branch = cfg.database.clone();
}
Err(e) if e.kind() == io::ErrorKind::NotFound => {}
Err(e) => {
Expand Down Expand Up @@ -1645,8 +1638,17 @@ fn set_credentials(cfg: &mut ConfigInner, creds: &Credentials)
));
cfg.user = creds.user.clone();
cfg.password = creds.password.clone();
cfg.database = creds.database.clone().unwrap_or_else(|| "edgedb".into());
cfg.branch = creds.database.clone().unwrap_or_else(|| "__default__".into());

if let Some((b, d)) = creds.branch.as_ref().zip(creds.database.as_ref()) {
if b != d {
return Err(ClientError::with_message(
"branch and database are mutually exclusive")
);
}
}
let db_branch = creds.branch.as_ref().or(creds.database.as_ref());
cfg.database = db_branch.cloned().unwrap_or_else(|| "edgedb".into());
cfg.branch = db_branch.cloned().unwrap_or_else(|| "__default__".into());
cfg.tls_security = creds.tls_security;
cfg.creds_file_outdated = creds.file_outdated;
Ok(())
Expand Down Expand Up @@ -2072,7 +2074,7 @@ async fn from_dsn() {
));
assert_eq!(&cfg.0.user, "user1");
assert_eq!(&cfg.0.database, "db2");
assert_eq!(&cfg.0.branch, "__default__");
assert_eq!(&cfg.0.branch, "db2");
assert_eq!(cfg.0.password, Some("EiPhohl7".into()));

let cfg = Builder::new()
Expand All @@ -2087,6 +2089,7 @@ async fn from_dsn() {
));
assert_eq!(&cfg.0.user, "user2");
assert_eq!(&cfg.0.database, "db2");
assert_eq!(&cfg.0.branch, "db2");
assert_eq!(cfg.0.password, None);

// Tests overriding
Expand All @@ -2102,6 +2105,7 @@ async fn from_dsn() {
));
assert_eq!(&cfg.0.user, "edgedb");
assert_eq!(&cfg.0.database, "edgedb");
assert_eq!(&cfg.0.branch, "__default__");
assert_eq!(cfg.0.password, None);

let cfg = Builder::new()
Expand All @@ -2113,6 +2117,7 @@ async fn from_dsn() {
));
assert_eq!(&cfg.0.user, "user3");
assert_eq!(&cfg.0.database, "abcdef");
assert_eq!(&cfg.0.branch, "abcdef");
assert_eq!(cfg.0.password, Some("123123".into()));
}

Expand Down
31 changes: 30 additions & 1 deletion edgedb-tokio/tests/func/client.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use edgedb_protocol::eargs;
use edgedb_protocol::value::{EnumValue, Value};
use edgedb_tokio::Client;
use edgedb_errors::NoDataError;
use futures_util::stream::{self, StreamExt};

use crate::server::SERVER;

#[tokio::test]
async fn simple() -> anyhow::Result<()> {
async fn simple() {
let client = Client::new(&SERVER.config);
client.ensure_connected().await?;

Expand Down Expand Up @@ -43,6 +44,34 @@ async fn simple() -> anyhow::Result<()> {
client.execute("SELECT 1+1", &()).await?;
client.execute("START MIGRATION TO {}; ABORT MIGRATION", &()).await?;

// basic enum param
let enum_query = "SELECT <str>(<test::State>$0) = 'waiting'";
assert_eq!(
client.query_required_single::<bool, _>(
enum_query, &(Value::Enum(EnumValue::from("waiting")),)
).await.unwrap(),
true
);

// unsupported: enum param as Value::Str
client.query_required_single::<bool, (Value, )>(
enum_query, &(Value::Str("waiting".to_string()), ),
).await.unwrap_err();

// unsupported: enum param as String
client.query_required_single::<bool, (String, )>(
enum_query, &("waiting".to_string(), ),
).await.unwrap_err();

// enum param as &str
assert_eq!(
client.query_required_single::<bool, (&'_ str, )>(
enum_query, &("waiting", ),
).await.unwrap(),
true
);

// params as macro
let value = client.query_required_single::<String, _>(
"select (
std::array_join(<array<str>>$msg1, ' ')
Expand Down
2 changes: 2 additions & 0 deletions edgedb-tokio/tests/func/dbschema/test.esdl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
module test {
scalar type State extending enum<'done', 'waiting', 'blocked'>;

type Counter {
required property name -> str {
constraint std::exclusive;
Expand Down

0 comments on commit 671f98b

Please sign in to comment.