diff --git a/.gitignore b/.gitignore index b7e20833..9b9c5aa2 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ /edgeql_python.cpython-*.so __pycache__ /Cargo.lock +/.idea diff --git a/edgedb-tokio/src/builder.rs b/edgedb-tokio/src/builder.rs index 9e7ec62d..1aad8ec9 100644 --- a/edgedb-tokio/src/builder.rs +++ b/edgedb-tokio/src/builder.rs @@ -72,6 +72,7 @@ pub struct Builder { unix_path: Option, user: Option, database: Option, + branch: Option, password: Option, tls_ca_file: Option, tls_security: Option, @@ -109,6 +110,7 @@ pub(crate) struct ConfigInner { pub secret_key: Option, pub cloud_profile: Option, pub database: String, + pub branch: String, pub verifier: Verifier, pub wait: Duration, pub connect_timeout: Duration, @@ -538,6 +540,21 @@ impl<'a> DsnHelper<'a> { }).await } + async fn retrieve_branch(&mut self) -> Result, Error> { + let v = self.url.path().strip_prefix("/").and_then(|s| { + if s.is_empty() { + None + } else { + Some(s.to_owned()) + } + }); + self.retrieve_value("branch", v, |s| { + let s = s.strip_prefix("/").unwrap_or(&s); + validate_branch(&s)?; + Ok(s.to_owned()) + }).await + } + async fn retrieve_secret_key(&mut self) -> Result, Error> { self.retrieve_value("secret_key", None, |s| Ok(s)).await } @@ -678,6 +695,13 @@ impl Builder { Ok(self) } + /// Set the branch name. + pub fn branch(&mut self, branch: &str) -> Result<&mut Self, Error> { + validate_branch(branch)?; + self.branch = Some(branch.into()); + Ok(self) + } + /// Set certificate authority for TLS from file /// /// Note: file is not read immediately but is read when configuration is @@ -817,6 +841,9 @@ impl Builder { database: self.database.clone() .or_else(|| creds.map(|c| c.database.clone()).flatten()) .unwrap_or_else(|| "edgedb".into()), + branch: self.branch.clone() + .or_else(|| creds.map(|c| c.branch.clone()).flatten()) + .unwrap_or_else(|| "__default__".into()), instance_name: None, wait: self.wait_until_available.unwrap_or(DEFAULT_WAIT), connect_timeout: self.connect_timeout @@ -934,6 +961,13 @@ impl Builder { let full_path = resolve_unix(unix_path, port, self.admin); cfg.address = Address::Unix(full_path); } + if let Some(database) = &self.database { + if let Some(branch) = &self.branch { + errors.push(InvalidArgumentError::with_message(format!( + "database {} conflicts with branch {}", database, branch + ))) + } + } } async fn granular_owned(&self, cfg: &mut ConfigInner, @@ -943,6 +977,10 @@ impl Builder { cfg.database = database.clone(); } + if let Some(branch) = &self.branch { + cfg.branch = branch.clone(); + } + if let Some(user) = &self.user { cfg.user = user.clone(); } @@ -1093,6 +1131,16 @@ impl Builder { cfg.database = database; } + let branch = self.branch.clone().or_else(|| { + get_env("EDGEDB_BRANCH") + .and_then(|v| v.map(validate_branch).transpose()) + .map_err(|e| errors.push(e)).ok().flatten() + }); + + if let Some(branch) = branch { + cfg.branch = branch; + } + let user = self.user.clone().or_else(|| { get_env("EDGEDB_USER") .and_then(|v| v.map(validate_user).transpose()) @@ -1202,15 +1250,46 @@ impl Builder { } else { dsn.ignore_value("password"); } - if self.database.is_none() { + + let has_branch_option = dsn.query.contains_key("branch") || dsn.query.contains_key("branch_env") || dsn.query.contains_key("branch_file"); + let has_database_option = dsn.query.contains_key("database") || dsn.query.contains_key("database_env") || dsn.query.contains_key("database_file"); + + if has_branch_option { + if has_database_option { + errors.push(InvalidArgumentError::with_message( + "Invalid DSN: `database` and `branch` cannot be present at the same time" + )); + } else if self.database.is_some() { + errors.push(InvalidArgumentError::with_message( + "`branch` in DSN and `database` are mutually exclusive" + )); + } else { + match dsn.retrieve_branch().await { + Ok(Some(value)) => cfg.branch = value, + Ok(None) => {}, + Err(e) => errors.push(e) + } + } + } else if self.branch.is_some() { + if has_database_option { + errors.push(InvalidArgumentError::with_message( + "`database` in DSN and `branch` are mutually exclusive" + )); + } else { + match dsn.retrieve_branch().await { + Ok(Some(value)) => cfg.branch = value, + Ok(None) => {}, + Err(e) => errors.push(e) + } + } + } else { match dsn.retrieve_database().await { Ok(Some(value)) => cfg.database = value, Ok(None) => {}, - Err(e) => errors.push(e), + Err(e) => errors.push(e) } - } else { - dsn.ignore_value("database"); } + match dsn.retrieve_secret_key().await { Ok(Some(value)) => cfg.secret_key = Some(value), Ok(None) => {}, @@ -1351,6 +1430,7 @@ impl Builder { cloud_profile: None, cloud_certs: None, database: "edgedb".into(), + branch: "__default__".into(), instance_name: None, wait: self.wait_until_available.unwrap_or(DEFAULT_WAIT), connect_timeout: self.connect_timeout @@ -1566,6 +1646,7 @@ fn set_credentials(cfg: &mut ConfigInner, creds: &Credentials) cfg.user = creds.user.clone(); cfg.password = creds.password.clone(); cfg.database = creds.database.clone().unwrap_or_else(|| "edgedb".into()); + cfg.branch = creds.database.clone().unwrap_or_else(|| "__default__".into()); cfg.tls_security = creds.tls_security; cfg.creds_file_outdated = creds.file_outdated; Ok(()) @@ -1602,6 +1683,15 @@ fn validate_port(port: u16) -> Result { Ok(port) } +fn validate_branch>(branch: T) -> Result { + if branch.as_ref().is_empty() { + return Err(InvalidArgumentError::with_message( + "invalid branch: empty string" + )); + } + Ok(branch) +} + fn validate_database>(database: T) -> Result { if database.as_ref().is_empty() { return Err(InvalidArgumentError::with_message( @@ -1658,7 +1748,8 @@ impl Config { port: *port, user: self.0.user.clone(), password: self.0.password.clone(), - database: Some( self.0.database.clone()), + database: if self.0.branch == "__default__" { Some(self.0.database.clone()) } else { None }, + branch: if self.0.branch == "__default__" { None } else { Some(self.0.branch.clone()) }, tls_ca: self.0.pem_certificates.clone(), tls_security: self.0.tls_security, file_outdated: false, @@ -1674,6 +1765,7 @@ impl Config { Address::Unix(path) => serde_json::json!(path.to_str().unwrap()), }, "database": self.0.database, + "branch": self.0.branch, "user": self.0.user, "password": self.0.password, "secretKey": self.0.secret_key, @@ -1756,6 +1848,16 @@ impl Config { Ok(self) } + pub fn with_branch(mut self, branch: &str) -> Result { + if branch.is_empty() { + return Err(InvalidArgumentError::with_message( + "invalid branch: empty string" + )); + } + Arc::make_mut(&mut self.0).branch = branch.to_owned(); + Ok(self) + } + /// Return the same config with changed wait until available timeout #[cfg(any(feature="unstable", feature="test"))] pub fn with_wait_until_available(mut self, wait: Duration) -> Config { @@ -1969,6 +2071,7 @@ async fn from_dsn() { )); assert_eq!(&cfg.0.user, "user1"); assert_eq!(&cfg.0.database, "db2"); + assert_eq!(&cfg.0.branch, "__default__"); assert_eq!(cfg.0.password, Some("EiPhohl7".into())); let cfg = Builder::new() diff --git a/edgedb-tokio/src/credentials.rs b/edgedb-tokio/src/credentials.rs index be136765..07628a07 100644 --- a/edgedb-tokio/src/credentials.rs +++ b/edgedb-tokio/src/credentials.rs @@ -38,6 +38,7 @@ pub struct Credentials { pub user: String, pub password: Option, pub database: Option, + pub branch: Option, pub tls_ca: Option, pub tls_security: TlsSecurity, pub(crate) file_outdated: bool, @@ -56,6 +57,8 @@ struct CredentialsCompat { #[serde(default, skip_serializing_if="Option::is_none")] database: Option, #[serde(default, skip_serializing_if="Option::is_none")] + branch: Option, + #[serde(default, skip_serializing_if="Option::is_none")] tls_cert_data: Option, // deprecated #[serde(default, skip_serializing_if="Option::is_none")] tls_ca: Option, @@ -114,6 +117,7 @@ impl Default for Credentials { user: "edgedb".into(), password: None, database: None, + branch: None, tls_ca: None, tls_security: TlsSecurity::Default, file_outdated: false, @@ -133,6 +137,7 @@ impl Serialize for Credentials { user: self.user.clone(), password: self.password.clone(), database: self.database.clone(), + branch: self.branch.clone(), tls_ca: self.tls_ca.clone(), tls_cert_data: self.tls_ca.clone(), tls_security: Some(self.tls_security), @@ -192,6 +197,7 @@ impl<'de> Deserialize<'de> for Credentials { user: creds.user, password: creds.password, database: creds.database, + branch: creds.branch, tls_ca: creds.tls_ca.or(creds.tls_cert_data.clone()), tls_security: creds.tls_security.unwrap_or( match creds.tls_verify_hostname {