Skip to content

Commit

Permalink
Branch connection argument support (#302)
Browse files Browse the repository at this point in the history
* support branch argument

* clarify `branch` and `database` in credentials

* fix logic of picking branch over database in DSN

* remove async closure

* fix tests
  • Loading branch information
quinchs authored Mar 19, 2024
1 parent 2c5ba18 commit 0bf0949
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
/edgeql_python.cpython-*.so
__pycache__
/Cargo.lock
/.idea
113 changes: 108 additions & 5 deletions edgedb-tokio/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub struct Builder {
unix_path: Option<PathBuf>,
user: Option<String>,
database: Option<String>,
branch: Option<String>,
password: Option<String>,
tls_ca_file: Option<PathBuf>,
tls_security: Option<TlsSecurity>,
Expand Down Expand Up @@ -109,6 +110,7 @@ pub(crate) struct ConfigInner {
pub secret_key: Option<String>,
pub cloud_profile: Option<String>,
pub database: String,
pub branch: String,
pub verifier: Verifier,
pub wait: Duration,
pub connect_timeout: Duration,
Expand Down Expand Up @@ -538,6 +540,21 @@ impl<'a> DsnHelper<'a> {
}).await
}

async fn retrieve_branch(&mut self) -> Result<Option<String>, Error> {
let v = self.url.path().strip_prefix("/").and_then(|s| {
if s.is_empty() {
None
} else {
Some(s.to_owned())
}
});
self.retrieve_value("branch", v, |s| {
let s = s.strip_prefix("/").unwrap_or(&s);
validate_branch(&s)?;
Ok(s.to_owned())
}).await
}

async fn retrieve_secret_key(&mut self) -> Result<Option<String>, Error> {
self.retrieve_value("secret_key", None, |s| Ok(s)).await
}
Expand Down Expand Up @@ -678,6 +695,13 @@ impl Builder {
Ok(self)
}

/// Set the branch name.
pub fn branch(&mut self, branch: &str) -> Result<&mut Self, Error> {
validate_branch(branch)?;
self.branch = Some(branch.into());
Ok(self)
}

/// Set certificate authority for TLS from file
///
/// Note: file is not read immediately but is read when configuration is
Expand Down Expand Up @@ -817,6 +841,9 @@ impl Builder {
database: self.database.clone()
.or_else(|| creds.map(|c| c.database.clone()).flatten())
.unwrap_or_else(|| "edgedb".into()),
branch: self.branch.clone()
.or_else(|| creds.map(|c| c.branch.clone()).flatten())
.unwrap_or_else(|| "__default__".into()),
instance_name: None,
wait: self.wait_until_available.unwrap_or(DEFAULT_WAIT),
connect_timeout: self.connect_timeout
Expand Down Expand Up @@ -934,6 +961,13 @@ impl Builder {
let full_path = resolve_unix(unix_path, port, self.admin);
cfg.address = Address::Unix(full_path);
}
if let Some(database) = &self.database {
if let Some(branch) = &self.branch {
errors.push(InvalidArgumentError::with_message(format!(
"database {} conflicts with branch {}", database, branch
)))
}
}
}

async fn granular_owned(&self, cfg: &mut ConfigInner,
Expand All @@ -943,6 +977,10 @@ impl Builder {
cfg.database = database.clone();
}

if let Some(branch) = &self.branch {
cfg.branch = branch.clone();
}

if let Some(user) = &self.user {
cfg.user = user.clone();
}
Expand Down Expand Up @@ -1093,6 +1131,16 @@ impl Builder {
cfg.database = database;
}

let branch = self.branch.clone().or_else(|| {
get_env("EDGEDB_BRANCH")
.and_then(|v| v.map(validate_branch).transpose())
.map_err(|e| errors.push(e)).ok().flatten()
});

if let Some(branch) = branch {
cfg.branch = branch;
}

let user = self.user.clone().or_else(|| {
get_env("EDGEDB_USER")
.and_then(|v| v.map(validate_user).transpose())
Expand Down Expand Up @@ -1202,15 +1250,46 @@ impl Builder {
} else {
dsn.ignore_value("password");
}
if self.database.is_none() {

let has_branch_option = dsn.query.contains_key("branch") || dsn.query.contains_key("branch_env") || dsn.query.contains_key("branch_file");
let has_database_option = dsn.query.contains_key("database") || dsn.query.contains_key("database_env") || dsn.query.contains_key("database_file");

if has_branch_option {
if has_database_option {
errors.push(InvalidArgumentError::with_message(
"Invalid DSN: `database` and `branch` cannot be present at the same time"
));
} else if self.database.is_some() {
errors.push(InvalidArgumentError::with_message(
"`branch` in DSN and `database` are mutually exclusive"
));
} else {
match dsn.retrieve_branch().await {
Ok(Some(value)) => cfg.branch = value,
Ok(None) => {},
Err(e) => errors.push(e)
}
}
} else if self.branch.is_some() {
if has_database_option {
errors.push(InvalidArgumentError::with_message(
"`database` in DSN and `branch` are mutually exclusive"
));
} else {
match dsn.retrieve_branch().await {
Ok(Some(value)) => cfg.branch = value,
Ok(None) => {},
Err(e) => errors.push(e)
}
}
} else {
match dsn.retrieve_database().await {
Ok(Some(value)) => cfg.database = value,
Ok(None) => {},
Err(e) => errors.push(e),
Err(e) => errors.push(e)
}
} else {
dsn.ignore_value("database");
}

match dsn.retrieve_secret_key().await {
Ok(Some(value)) => cfg.secret_key = Some(value),
Ok(None) => {},
Expand Down Expand Up @@ -1351,6 +1430,7 @@ impl Builder {
cloud_profile: None,
cloud_certs: None,
database: "edgedb".into(),
branch: "__default__".into(),
instance_name: None,
wait: self.wait_until_available.unwrap_or(DEFAULT_WAIT),
connect_timeout: self.connect_timeout
Expand Down Expand Up @@ -1566,6 +1646,7 @@ fn set_credentials(cfg: &mut ConfigInner, creds: &Credentials)
cfg.user = creds.user.clone();
cfg.password = creds.password.clone();
cfg.database = creds.database.clone().unwrap_or_else(|| "edgedb".into());
cfg.branch = creds.database.clone().unwrap_or_else(|| "__default__".into());
cfg.tls_security = creds.tls_security;
cfg.creds_file_outdated = creds.file_outdated;
Ok(())
Expand Down Expand Up @@ -1602,6 +1683,15 @@ fn validate_port(port: u16) -> Result<u16, Error> {
Ok(port)
}

fn validate_branch<T: AsRef<str>>(branch: T) -> Result<T, Error> {
if branch.as_ref().is_empty() {
return Err(InvalidArgumentError::with_message(
"invalid branch: empty string"
));
}
Ok(branch)
}

fn validate_database<T: AsRef<str>>(database: T) -> Result<T, Error> {
if database.as_ref().is_empty() {
return Err(InvalidArgumentError::with_message(
Expand Down Expand Up @@ -1658,7 +1748,8 @@ impl Config {
port: *port,
user: self.0.user.clone(),
password: self.0.password.clone(),
database: Some( self.0.database.clone()),
database: if self.0.branch == "__default__" { Some(self.0.database.clone()) } else { None },
branch: if self.0.branch == "__default__" { None } else { Some(self.0.branch.clone()) },
tls_ca: self.0.pem_certificates.clone(),
tls_security: self.0.tls_security,
file_outdated: false,
Expand All @@ -1674,6 +1765,7 @@ impl Config {
Address::Unix(path) => serde_json::json!(path.to_str().unwrap()),
},
"database": self.0.database,
"branch": self.0.branch,
"user": self.0.user,
"password": self.0.password,
"secretKey": self.0.secret_key,
Expand Down Expand Up @@ -1756,6 +1848,16 @@ impl Config {
Ok(self)
}

pub fn with_branch(mut self, branch: &str) -> Result<Config, Error> {
if branch.is_empty() {
return Err(InvalidArgumentError::with_message(
"invalid branch: empty string"
));
}
Arc::make_mut(&mut self.0).branch = branch.to_owned();
Ok(self)
}

/// Return the same config with changed wait until available timeout
#[cfg(any(feature="unstable", feature="test"))]
pub fn with_wait_until_available(mut self, wait: Duration) -> Config {
Expand Down Expand Up @@ -1969,6 +2071,7 @@ async fn from_dsn() {
));
assert_eq!(&cfg.0.user, "user1");
assert_eq!(&cfg.0.database, "db2");
assert_eq!(&cfg.0.branch, "__default__");
assert_eq!(cfg.0.password, Some("EiPhohl7".into()));

let cfg = Builder::new()
Expand Down
6 changes: 6 additions & 0 deletions edgedb-tokio/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub struct Credentials {
pub user: String,
pub password: Option<String>,
pub database: Option<String>,
pub branch: Option<String>,
pub tls_ca: Option<String>,
pub tls_security: TlsSecurity,
pub(crate) file_outdated: bool,
Expand All @@ -56,6 +57,8 @@ struct CredentialsCompat {
#[serde(default, skip_serializing_if="Option::is_none")]
database: Option<String>,
#[serde(default, skip_serializing_if="Option::is_none")]
branch: Option<String>,
#[serde(default, skip_serializing_if="Option::is_none")]
tls_cert_data: Option<String>, // deprecated
#[serde(default, skip_serializing_if="Option::is_none")]
tls_ca: Option<String>,
Expand Down Expand Up @@ -114,6 +117,7 @@ impl Default for Credentials {
user: "edgedb".into(),
password: None,
database: None,
branch: None,
tls_ca: None,
tls_security: TlsSecurity::Default,
file_outdated: false,
Expand All @@ -133,6 +137,7 @@ impl Serialize for Credentials {
user: self.user.clone(),
password: self.password.clone(),
database: self.database.clone(),
branch: self.branch.clone(),
tls_ca: self.tls_ca.clone(),
tls_cert_data: self.tls_ca.clone(),
tls_security: Some(self.tls_security),
Expand Down Expand Up @@ -192,6 +197,7 @@ impl<'de> Deserialize<'de> for Credentials {
user: creds.user,
password: creds.password,
database: creds.database,
branch: creds.branch,
tls_ca: creds.tls_ca.or(creds.tls_cert_data.clone()),
tls_security: creds.tls_security.unwrap_or(
match creds.tls_verify_hostname {
Expand Down

0 comments on commit 0bf0949

Please sign in to comment.