From 787a469d28e19b176a08397ee9d5a562236bee3e Mon Sep 17 00:00:00 2001 From: Quin Lynch <49576606+quinchs@users.noreply.github.com> Date: Tue, 9 Apr 2024 12:54:58 -0400 Subject: [PATCH 1/2] Tweak handling of branches and database params (#307) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Aljaž Mur Eržen --- .vscode/settings.json | 3 - edgedb-tokio/src/builder.rs | 121 +++++++++++++++++++----------------- 2 files changed, 63 insertions(+), 61 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index bfd35bb2..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "rust-analyzer.cargo.features": ["unstable", "chrono"] -} \ No newline at end of file diff --git a/edgedb-tokio/src/builder.rs b/edgedb-tokio/src/builder.rs index 6975aa2a..8755b683 100644 --- a/edgedb-tokio/src/builder.rs +++ b/edgedb-tokio/src/builder.rs @@ -101,7 +101,7 @@ impl Config { } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub(crate) struct ConfigInner { pub address: Address, pub admin: bool, @@ -556,6 +556,7 @@ impl<'a> DsnHelper<'a> { }); self.retrieve_value("branch", v, |s| { let s = s.strip_prefix('/').unwrap_or(&s); + dbg!("here"); validate_branch(&s)?; Ok(s.to_owned()) }).await @@ -702,6 +703,7 @@ impl Builder { /// Set the branch name. pub fn branch(&mut self, branch: &str) -> Result<&mut Self, Error> { + dbg!("here"); validate_branch(branch)?; self.branch = Some(branch.into()); Ok(self) @@ -965,10 +967,10 @@ impl Builder { let full_path = resolve_unix(unix_path, port, self.admin); cfg.address = Address::Unix(full_path); } - if let Some(database) = &self.database { - if let Some(branch) = &self.branch { + if let Some((d, b)) = &self.database.as_ref().zip(self.branch.as_ref()) { + if d != b { errors.push(InvalidArgumentError::with_message(format!( - "database {} conflicts with branch {}", database, branch + "database {d} conflicts with branch {b}" ))) } } @@ -1123,24 +1125,26 @@ impl Builder { async fn granular_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) { - let database = self.database.clone().or_else(|| { - get_env("EDGEDB_DATABASE") - .and_then(|v| v.map(validate_database).transpose()) - .map_err(|e| errors.push(e)).ok().flatten() - }); - // TODO(tailhook) check if not empty - if let Some(database) = database { - cfg.database = database; - } - - let branch = self.branch.clone().or_else(|| { - get_env("EDGEDB_BRANCH") - .and_then(|v| v.map(validate_branch).transpose()) - .map_err(|e| errors.push(e)).ok().flatten() - }); + let database_branch = self.database.as_ref().or(self.branch.as_ref()) + .cloned() + .or_else(|| { + let database = get_env("EDGEDB_DATABASE") + .map_err(|e| errors.push(e)).ok()?; + let branch = get_env("EDGEDB_BRANCH") + .map_err(|e| errors.push(e)).ok()?; + + if database.is_some() && branch.is_some() { + errors.push(InvalidArgumentError::with_message( + "Invalid environment: variables `EDGEDB_DATABASE` and `EDGEDB_BRANCH` are mutually exclusive", + )); + return None; + } - if let Some(branch) = branch { - cfg.branch = branch; + database.or(branch) + }); + if let Some(name) = database_branch { + cfg.database = name.clone(); + cfg.branch = name; } let user = self.user.clone().or_else(|| { @@ -1253,43 +1257,31 @@ impl Builder { dsn.ignore_value("password"); } - let has_branch_option = dsn.query.contains_key("branch") || dsn.query.contains_key("branch_env") || dsn.query.contains_key("branch_file"); - let has_database_option = dsn.query.contains_key("database") || dsn.query.contains_key("database_env") || dsn.query.contains_key("database_file"); - - if has_branch_option { - if has_database_option { - errors.push(InvalidArgumentError::with_message( - "Invalid DSN: `database` and `branch` cannot be present at the same time" - )); - } else if self.database.is_some() { - errors.push(InvalidArgumentError::with_message( - "`branch` in DSN and `database` are mutually exclusive" - )); - } else { - match dsn.retrieve_branch().await { - Ok(Some(value)) => cfg.branch = value, - Ok(None) => {}, - Err(e) => errors.push(e) - } - } - } else if self.branch.is_some() { - if has_database_option { - errors.push(InvalidArgumentError::with_message( - "`database` in DSN and `branch` are mutually exclusive" - )); + let has_query_branch = dsn.query.contains_key("branch") || dsn.query.contains_key("branch_env") || dsn.query.contains_key("branch_file"); + let has_query_database = dsn.query.contains_key("database") || dsn.query.contains_key("database_env") || dsn.query.contains_key("database_file"); + if has_query_branch && has_query_database { + errors.push(InvalidArgumentError::with_message( + "Invalid DSN: `database` and `branch` are mutually exclusive", + )); + } + if self.branch.is_none() && self.database.is_none() { + let database_or_branch = if has_query_database { + dsn.retrieve_database().await } else { - match dsn.retrieve_branch().await { - Ok(Some(value)) => cfg.branch = value, - Ok(None) => {}, - Err(e) => errors.push(e) - } + dsn.retrieve_branch().await + }; + + match database_or_branch { + Ok(Some(name)) => { + cfg.branch = name.clone(); + cfg.database = name; + }, + Ok(None) => {} + Err(e) => errors.push(e), } } else { - match dsn.retrieve_database().await { - Ok(Some(value)) => cfg.database = value, - Ok(None) => {}, - Err(e) => errors.push(e) - } + dsn.ignore_value("branch"); + dsn.ignore_value("database"); } match dsn.retrieve_secret_key().await { @@ -1394,6 +1386,7 @@ impl Builder { project_dir, path) })? .to_owned(); + cfg.branch = cfg.database.clone(); } Err(e) if e.kind() == io::ErrorKind::NotFound => {} Err(e) => { @@ -1645,8 +1638,17 @@ fn set_credentials(cfg: &mut ConfigInner, creds: &Credentials) )); cfg.user = creds.user.clone(); cfg.password = creds.password.clone(); - cfg.database = creds.database.clone().unwrap_or_else(|| "edgedb".into()); - cfg.branch = creds.database.clone().unwrap_or_else(|| "__default__".into()); + + if let Some((b, d)) = creds.branch.as_ref().zip(creds.database.as_ref()) { + if b != d { + return Err(ClientError::with_message( + "branch and database are mutually exclusive") + ); + } + } + let db_branch = creds.branch.as_ref().or(creds.database.as_ref()); + cfg.database = db_branch.cloned().unwrap_or_else(|| "edgedb".into()); + cfg.branch = db_branch.cloned().unwrap_or_else(|| "__default__".into()); cfg.tls_security = creds.tls_security; cfg.creds_file_outdated = creds.file_outdated; Ok(()) @@ -2072,7 +2074,7 @@ async fn from_dsn() { )); assert_eq!(&cfg.0.user, "user1"); assert_eq!(&cfg.0.database, "db2"); - assert_eq!(&cfg.0.branch, "__default__"); + assert_eq!(&cfg.0.branch, "db2"); assert_eq!(cfg.0.password, Some("EiPhohl7".into())); let cfg = Builder::new() @@ -2087,6 +2089,7 @@ async fn from_dsn() { )); assert_eq!(&cfg.0.user, "user2"); assert_eq!(&cfg.0.database, "db2"); + assert_eq!(&cfg.0.branch, "db2"); assert_eq!(cfg.0.password, None); // Tests overriding @@ -2102,6 +2105,7 @@ async fn from_dsn() { )); assert_eq!(&cfg.0.user, "edgedb"); assert_eq!(&cfg.0.database, "edgedb"); + assert_eq!(&cfg.0.branch, "__default__"); assert_eq!(cfg.0.password, None); let cfg = Builder::new() @@ -2113,6 +2117,7 @@ async fn from_dsn() { )); assert_eq!(&cfg.0.user, "user3"); assert_eq!(&cfg.0.database, "abcdef"); + assert_eq!(&cfg.0.branch, "abcdef"); assert_eq!(cfg.0.password, Some("123123".into())); } From ca67c5480800c58a538327aaa96aff15023f560f Mon Sep 17 00:00:00 2001 From: MrFoxPro Date: Tue, 9 Apr 2024 22:02:41 +0500 Subject: [PATCH 2/2] Accept `Value::Str` for Enum encoding (#308) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Aljaž Mur Eržen --- edgedb-protocol/src/codec.rs | 7 ++-- edgedb-protocol/src/query_arg.rs | 32 ++++++++++--------- .../src/serialization/decode/raw_scalar.rs | 5 +++ edgedb-tokio/tests/func/client.rs | 28 ++++++++++++++++ edgedb-tokio/tests/func/dbschema/test.esdl | 2 ++ 5 files changed, 56 insertions(+), 18 deletions(-) diff --git a/edgedb-protocol/src/codec.rs b/edgedb-protocol/src/codec.rs index 473fd305..685e417a 100644 --- a/edgedb-protocol/src/codec.rs +++ b/edgedb-protocol/src/codec.rs @@ -1364,11 +1364,12 @@ impl Codec for Enum { -> Result<(), EncodeError> { let val = match val { - Value::Enum(val) => val, + Value::Enum(val) => val.0.as_ref(), + Value::Str(val) => val.as_str(), _ => Err(errors::invalid_value(type_name::(), val))?, }; - ensure!(self.members.get(&val.0).is_some(), errors::MissingEnumValue); - buf.extend(val.0.as_bytes()); + ensure!(self.members.get(val).is_some(), errors::MissingEnumValue); + buf.extend(val.as_bytes()); Ok(()) } } diff --git a/edgedb-protocol/src/query_arg.rs b/edgedb-protocol/src/query_arg.rs index 2b618423..91eb4017 100644 --- a/edgedb-protocol/src/query_arg.rs +++ b/edgedb-protocol/src/query_arg.rs @@ -224,21 +224,7 @@ impl QueryArg for Value { (Uuid(_), BaseScalar(d)) if d.id == codec::STD_UUID => Ok(()), (Enum(val), Enumeration(EnumerationTypeDescriptor { members, .. })) => { let val = val.deref(); - if members.iter().any(|c| c == val) { - Ok(()) - } else { - let members = { - let mut members = members - .iter() - .map(|c| format!("'{c}'")) - .collect::>(); - members.sort_unstable(); - members.join(", ") - }; - Err(InvalidReferenceError::with_message(format!( - "Expected one of: {members}, while enum value '{val}' was provided" - ))) - } + check_enum(val, members) } // TODO(tailhook) all types (_, desc) => Err(ctx.wrong_type(desc, self.kind())), @@ -249,6 +235,22 @@ impl QueryArg for Value { } } +pub(crate) fn check_enum(variant_name: &str, expected_members: &[String]) -> Result<(), Error> { + if expected_members.iter().any(|c| c == variant_name) { + Ok(()) + } else { + let mut members = expected_members + .into_iter() + .map(|c| format!("'{c}'")) + .collect::>(); + members.sort_unstable(); + let members = members.join(", "); + Err(InvalidReferenceError::with_message(format!( + "Expected one of: {members}, while enum value '{variant_name}' was provided" + ))) + } +} + impl QueryArgs for Value { fn encode(&self, enc: &mut Encoder) -> Result<(), Error> { let codec = enc.ctx.build_codec()?; diff --git a/edgedb-protocol/src/serialization/decode/raw_scalar.rs b/edgedb-protocol/src/serialization/decode/raw_scalar.rs index efbd37e8..6700551b 100644 --- a/edgedb-protocol/src/serialization/decode/raw_scalar.rs +++ b/edgedb-protocol/src/serialization/decode/raw_scalar.rs @@ -87,6 +87,11 @@ impl ScalarArg for &'_ str { fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { + // special case: &str can express an enum variant + if let Descriptor::Enumeration(_) = ctx.get(pos)? { + return Ok(()) + } + check_scalar(ctx, pos, String::uuid(), String::typename()) } fn to_value(&self) -> Result { diff --git a/edgedb-tokio/tests/func/client.rs b/edgedb-tokio/tests/func/client.rs index 64df8524..e8816929 100644 --- a/edgedb-tokio/tests/func/client.rs +++ b/edgedb-tokio/tests/func/client.rs @@ -1,3 +1,4 @@ +use edgedb_protocol::value::{EnumValue, Value}; use edgedb_tokio::Client; use edgedb_errors::NoDataError; use futures_util::stream::{self, StreamExt}; @@ -42,6 +43,33 @@ async fn simple() -> anyhow::Result<()> { client.execute("SELECT 1+1", &()).await?; client.execute("START MIGRATION TO {}; ABORT MIGRATION", &()).await?; + // basic enum param + let enum_query = "SELECT ($0) = 'waiting'"; + assert_eq!( + client.query_required_single::( + enum_query, &(Value::Enum(EnumValue::from("waiting")),) + ).await.unwrap(), + true + ); + + // unsupported: enum param as Value::Str + client.query_required_single::( + enum_query, &(Value::Str("waiting".to_string()), ), + ).await.unwrap_err(); + + // unsupported: enum param as String + client.query_required_single::( + enum_query, &("waiting".to_string(), ), + ).await.unwrap_err(); + + // enum param as &str + assert_eq!( + client.query_required_single::( + enum_query, &("waiting", ), + ).await.unwrap(), + true + ); + Ok(()) } diff --git a/edgedb-tokio/tests/func/dbschema/test.esdl b/edgedb-tokio/tests/func/dbschema/test.esdl index ff355a43..9e8cebc1 100644 --- a/edgedb-tokio/tests/func/dbschema/test.esdl +++ b/edgedb-tokio/tests/func/dbschema/test.esdl @@ -1,4 +1,6 @@ module test { + scalar type State extending enum<'done', 'waiting', 'blocked'>; + type Counter { required property name -> str { constraint std::exclusive;