From 4214a078436c54adbf857058f44f0d63aaa13a86 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Wed, 13 Nov 2024 11:56:34 -0700 Subject: [PATCH] Support both GEL_ and EDGEDB_ vars (#359) This allows edgedb-cli to pass all of the shared testcases. To simplify this logic we extract most of the env-var parsing to env.rs. This isn't a full cleanup of this code, but it makes it a little easier to parse what's going on. Long term we'll need to untangle this error parsing code as it's quite complex and hard to follow. There's a bit of code here where we need to "lie" that the configuration is complete when we find errors in the environment variables themselves. Ideally this would be done when we surface the errors instead, but that's a more complicated lift. --- edgedb-tokio/src/builder.rs | 494 +++++++++++++----------------------- edgedb-tokio/src/env.rs | 226 +++++++++++++++++ edgedb-tokio/src/lib.rs | 41 +-- 3 files changed, 419 insertions(+), 342 deletions(-) create mode 100644 edgedb-tokio/src/env.rs diff --git a/edgedb-tokio/src/builder.rs b/edgedb-tokio/src/builder.rs index a0be61d6..5893111d 100644 --- a/edgedb-tokio/src/builder.rs +++ b/edgedb-tokio/src/builder.rs @@ -19,6 +19,7 @@ use tokio::fs; use edgedb_protocol::model; use crate::credentials::{Credentials, TlsSecurity}; +use crate::env::{get_env, Env}; use crate::errors::{ClientError, Error, ErrorKind, ResultExt}; use crate::errors::{ClientNoCredentialsError, NoCloudConfigFound}; use crate::errors::{InterfaceError, InvalidArgumentError}; @@ -30,20 +31,50 @@ pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(60); pub const DEFAULT_POOL_SIZE: usize = 10; pub const DEFAULT_HOST: &str = "localhost"; pub const DEFAULT_PORT: u16 = 5656; -pub const COMPOUND_ENV_VARS: &[&str] = &[ - "EDGEDB_HOST", - // "EDGEDB_PORT", // port check is special because of Docker - "EDGEDB_CREDENTIALS_FILE", - "EDGEDB_INSTANCE", - "EDGEDB_DSN", -]; const DOMAIN_LABEL_MAX_LENGTH: usize = 63; const CLOUD_INSTANCE_NAME_MAX_LENGTH: usize = DOMAIN_LABEL_MAX_LENGTH - 2 + 1; // "--" -> "/" -static PORT_WARN: std::sync::Once = std::sync::Once::new(); - type Verifier = Arc; +mod sealed { + use super::*; + + /// Helper trait to extract errors and redirect them to the Vec. + pub(super) trait ErrorBuilder { + /// Convert a Result, Error> to an Option. + /// If the result is an error, it is pushed to the Vec. + fn maybe(&mut self, res: Result, Error>) -> Option; + + /// Convert a Result to an Option. + /// If the result is an error, it is pushed to the Vec. + fn check(&mut self, res: Result) -> Option; + } + + impl ErrorBuilder for Vec { + fn maybe(&mut self, res: Result, Error>) -> Option { + match res { + Ok(v) => v, + Err(e) => { + self.push(e); + None + } + } + } + + fn check(&mut self, res: Result) -> Option { + match res { + Ok(v) => Some(v), + Err(e) => { + self.push(e); + None + } + } + } + } +} + +use sealed::ErrorBuilder; + /// Client security mode. #[derive(Default, Debug, Clone, Copy)] pub enum ClientSecurity { @@ -63,6 +94,19 @@ pub enum CloudCerts { Local, } +impl CloudCerts { + pub fn root(&self) -> &'static str { + match self { + // Staging certs retrieved from + // https://letsencrypt.org/docs/staging-environment/#root-certificates + CloudCerts::Staging => include_str!("letsencrypt_staging.pem"), + // Local nebula development root cert found in + // nebula/infra/terraform/local/ca/root.certificate.pem + CloudCerts::Local => include_str!("nebula_development.pem"), + } + } +} + /// TCP keepalive configuration. #[derive(Default, Debug, Clone, Copy)] pub enum TcpKeepalive { @@ -200,26 +244,6 @@ struct Claims { issuer: Option, } -fn get_env(name: &str) -> Result, Error> { - match env::var(name) { - Ok(v) if v.is_empty() => Ok(None), - Ok(v) => Ok(Some(v)), - Err(env::VarError::NotPresent) => Ok(None), - Err(e) => Err(ClientError::with_source(e) - .context(format!("Cannot decode environment variable {:?}", name))), - } -} - -fn has_port_env() -> bool { - if let Some(port) = env::var_os("EDGEDB_PORT") { - port.to_str() - .map(|s| !s.starts_with("tcp://")) - .unwrap_or(true) - } else { - false - } -} - #[cfg(unix)] fn path_bytes(path: &Path) -> &'_ [u8] { use std::os::unix::ffi::OsStrExt; @@ -968,10 +992,7 @@ impl Builder { let mut conflict = None; if let Some(instance) = &self.instance { conflict = Some("instance"); - read_instance(cfg, instance) - .await - .map_err(|e| errors.push(e)) - .ok(); + errors.check(read_instance(cfg, instance).await); } if let Some(dsn) = &self.dsn { if let Some(conflict) = conflict { @@ -991,10 +1012,7 @@ impl Builder { ))); } conflict = Some("credentials_file"); - read_credentials(cfg, credentials_file) - .await - .map_err(|e| errors.push(e)) - .ok(); + errors.check(read_credentials(cfg, credentials_file).await); } if let Some(credentials) = &self.credentials { if let Some(conflict) = conflict { @@ -1004,9 +1022,7 @@ impl Builder { ))); } conflict = Some("credentials"); - set_credentials(cfg, credentials) - .map_err(|e| errors.push(e)) - .ok(); + errors.check(set_credentials(cfg, credentials)); } if let Some(host) = &self.host { if let Some(conflict) = conflict { @@ -1077,9 +1093,8 @@ impl Builder { } if let Some(tls_ca_file) = &self.tls_ca_file { - match read_certificates(tls_ca_file).await { - Ok(pem) => cfg.pem_certificates = Some(pem), - Err(e) => errors.push(e), + if let Some(pem) = errors.check(read_certificates(tls_ca_file).await) { + cfg.pem_certificates = Some(pem) } } @@ -1096,127 +1111,75 @@ impl Builder { } } - async fn compound_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) { - // Due to how shared-test-cases are implemented we have to check for - // conflicts first and then do the actual parsing - let mut conflict = None; - let mut check_conflict = |var_name: &'static str| { - if env::var_os(var_name).is_some() { - if let Some(cvar) = conflict { - errors.push(ClientError::with_message(format!( - "{} conflicts with {}", - var_name, cvar - ))); - } - conflict = Some(var_name); - } - }; - check_conflict("EDGEDB_INSTANCE"); - check_conflict("EDGEDB_DSN"); - check_conflict("EDGEDB_CREDENTIALS_FILE"); - check_conflict("EDGEDB_HOST"); - if let Some(port) = env::var_os("EDGEDB_PORT") { - if !port - .to_str() - .map(|s| s.starts_with("tcp://")) - .unwrap_or(false) - { - if let Some(cvar) = conflict { - if cvar != "EDGEDB_HOST" { - errors.push(ClientError::with_message(format!( - "{} conflicts with {}", - "EDGEDB_PORT", cvar - ))); - } - } - } - // note: not setting conflict to work with HOST + async fn compound_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) -> bool { + let instance = Env::instance(); + let dsn = Env::dsn(); + let credentials_file = Env::credentials_file(); + let host = Env::host(); + let port = Env::port(); + + fn has(opt: &Result, Error>) -> bool { + opt.as_ref().map(|s| s.as_ref()).ok().flatten().is_some() } - let str_env = |var_name: &'static str, errors: &mut Vec| { - get_env(var_name).map_err(|e| errors.push(e)).ok().flatten() - }; - if let Some(instance) = str_env("EDGEDB_INSTANCE", errors) { - match instance.parse() { - Ok(instance) => { - read_instance(cfg, &instance) - .await - .map_err(|e| errors.push(e)) - .ok(); - } - Err(e) => { - errors.push(ClientError::with_source(e).context("EDGEDB_INSTANCE is invalid")); - } - } + let groups = [ + (has(&instance), "GEL_INSTANCE"), + (has(&dsn), "GEL_DSN"), + (has(&credentials_file), "GEL_CREDENTIALS_FILE"), + (has(&host) || has(&port), "GEL_HOST or GEL_PORT"), + ]; + + let has_envs = groups + .into_iter() + .filter_map(|(has, name)| if has { Some(name) } else { None }) + .collect::>(); + + if has_envs.len() > 1 { + errors.push(InvalidArgumentError::with_message(format!( + "environment variable {} conflicts with {}", + has_envs[0], + has_envs[1..].join(", "), + ))); } - if let Some(dsn) = str_env("EDGEDB_DSN", errors) { - match dsn.parse() { - Ok(url) => self.read_dsn(cfg, &url, errors).await, - Err(e) => { - errors.push(ClientError::with_source(e).context("EDGEDB_DSN is invalid")); - } - } + + if let Some(instance) = errors.maybe(instance) { + errors.check(read_instance(cfg, &instance).await); } - if let Some(fpath) = str_env("EDGEDB_CREDENTIALS_FILE", errors) { - read_credentials(cfg, fpath) - .await - .map_err(|e| errors.push(e)) - .ok(); + if let Some(dsn) = errors.maybe(dsn) { + self.read_dsn(cfg, &dsn, errors).await } - if let Some(host) = str_env("EDGEDB_HOST", errors) { - match validate_host(&host) { - Ok(_) => { - cfg.address = Address::Tcp((host, DEFAULT_PORT)); - } - Err(e) => errors.push(e.context("EDGEDB_HOST is invalid")), - } + if let Some(fpath) = errors.maybe(credentials_file) { + errors.check(read_credentials(cfg, fpath).await); } - if let Some(port_str) = str_env("EDGEDB_PORT", errors) { - let port = port_str - .parse() - .map_err(ClientError::with_source) - .and_then(validate_port) - .context("EDGEDB_PORT is invalid"); - match port { - Ok(port) => { - if let Address::Tcp((_, ref mut portref)) = &mut cfg.address { - *portref = port - } - } - Err(e) => { - if port_str.starts_with("tcp://") { - PORT_WARN.call_once(|| { - log::warn!( - "Environment variable `EDGEDB_PORT` \ - contains docker-link-like definition. \ - Ignoring..." - ); - }); - } else { - errors.push(e); - } - } + if let Some(host) = errors.maybe(host) { + cfg.address = Address::Tcp((host, DEFAULT_PORT)); + } + if let Some(port) = errors.maybe(port) { + if let Address::Tcp((_, ref mut portref)) = &mut cfg.address { + *portref = port.into(); } } + + // This code needs a total rework... + + // Because an incomplete configuration trumps errors, we return "complete" if + // there are errors, so those errors can be reported. + !has_envs.is_empty() || !errors.is_empty() } async fn secret_key_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) { - cfg.secret_key = self.secret_key.clone().or_else(|| { - get_env("EDGEDB_SECRET_KEY") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + cfg.secret_key = self + .secret_key + .clone() + .or_else(|| errors.maybe(Env::secret_key())); } async fn granular_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) { let database_branch = self.database.as_ref().or(self.branch.as_ref()) .cloned() .or_else(|| { - let database = get_env("EDGEDB_DATABASE") - .map_err(|e| errors.push(e)).ok()?; - let branch = get_env("EDGEDB_BRANCH") - .map_err(|e| errors.push(e)).ok()?; + let database = errors.maybe(Env::database()); + let branch = errors.maybe(Env::branch()); if database.is_some() && branch.is_some() { errors.push(InvalidArgumentError::with_message( @@ -1232,103 +1195,52 @@ impl Builder { cfg.branch = name; } - let user = self.user.clone().or_else(|| { - get_env("EDGEDB_USER") - .and_then(|v| v.map(validate_user).transpose()) - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + let user = self.user.clone().or_else(|| errors.maybe(Env::user())); if let Some(user) = user { cfg.user = user; } - let tls_server_name = self.tls_server_name.clone().or_else(|| { - get_env("EDGEDB_TLS_SERVER_NAME") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + let tls_server_name = self + .tls_server_name + .clone() + .or_else(|| errors.maybe(Env::tls_server_name())); if let Some(tls_server_name) = tls_server_name { cfg.tls_server_name = Some(tls_server_name); } - let password = self.password.clone().or_else(|| { - get_env("EDGEDB_PASSWORD") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + let password = self + .password + .clone() + .or_else(|| errors.maybe(Env::password())); if let Some(password) = password { cfg.password = Some(password); } - let tls_ca_file = self.tls_ca_file.clone().or_else(|| { - get_env("EDGEDB_TLS_CA_FILE") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .map(|p| p.into()) - }); + let tls_ca_file = self + .tls_ca_file + .clone() + .or_else(|| errors.maybe(Env::tls_ca_file())); if let Some(tls_ca_file) = tls_ca_file { - match read_certificates(tls_ca_file).await { - Ok(pem) => cfg.pem_certificates = Some(pem), - Err(e) => errors.push(e), + if let Some(pem) = errors.check(read_certificates(tls_ca_file).await) { + cfg.pem_certificates = Some(pem) } } - let tls_ca = get_env("EDGEDB_TLS_CA") - .map_err(|e| errors.push(e)) - .ok() - .flatten(); + let tls_ca = errors.maybe(Env::tls_ca()); if let Some(pem) = tls_ca { - match validate_certs(&pem) { - Ok(()) => cfg.pem_certificates = Some(pem), - Err(e) => errors.push(e), + if let Some(()) = errors.check(validate_certs(&pem)) { + cfg.pem_certificates = Some(pem) } } - let security = get_env("EDGEDB_CLIENT_TLS_SECURITY") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .and_then(|x| { - x.parse::() - .map_err(|e| { - errors.push(e.context("EDGEDB_CLIENT_TLS_SECURITY error")); - }) - .ok() - }); + let security = errors.maybe(Env::client_tls_security()); if let Some(security) = security { cfg.tls_security = security; } - let wait = self.wait_until_available.or_else(|| { - get_env("EDGEDB_WAIT_UNTIL_AVAILABLE") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .and_then(|x| { - x.parse::() - .map_err(|e| { - errors.push( - ClientError::with_source(e) - .context("EDGEDB_WAIT_UNTIL_AVAILABLE error"), - ); - }) - .ok() - }) - .and_then(|x| { - x.try_into() - .map_err(|e| { - errors.push( - ClientError::with_source(e) - .context("EDGEDB_WAIT_UNTIL_AVAILABLE error"), - ); - }) - .ok() - }) - }); + let wait = self + .wait_until_available + .or_else(|| errors.maybe(Env::wait_until_available())); if let Some(wait) = wait { cfg.wait = wait; } @@ -1342,37 +1254,23 @@ impl Builder { return; } }; - let host = dsn - .retrieve_host() - .await - .map_err(|e| errors.push(e)) - .ok() - .flatten() + let host = errors + .maybe(dsn.retrieve_host().await) .unwrap_or_else(|| DEFAULT_HOST.into()); - let port = dsn - .retrieve_port() - .await - .map_err(|e| errors.push(e)) - .ok() - .flatten() + let port = errors + .maybe(dsn.retrieve_port().await) .unwrap_or(DEFAULT_PORT); - match dsn.retrieve_tls_server_name().await { - Ok(Some(value)) => cfg.tls_server_name = Some(value), - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_tls_server_name().await) { + cfg.tls_server_name = Some(value) } cfg.address = Address::Tcp((host, port)); cfg.admin = dsn.admin; - match dsn.retrieve_user().await { - Ok(Some(value)) => cfg.user = value, - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_user().await) { + cfg.user = value } if self.password.is_none() { - match dsn.retrieve_password().await { - Ok(Some(value)) => cfg.password = Some(value), - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_password().await) { + cfg.password = Some(value) } } else { dsn.ignore_value("password"); @@ -1396,62 +1294,43 @@ impl Builder { dsn.retrieve_branch().await }; - match database_or_branch { - Ok(Some(name)) => { + if let Some(name) = errors.maybe(database_or_branch) { + { cfg.branch.clone_from(&name); cfg.database = name; } - Ok(None) => {} - Err(e) => errors.push(e), } } else { dsn.ignore_value("branch"); dsn.ignore_value("database"); } - match dsn.retrieve_secret_key().await { - Ok(Some(value)) => cfg.secret_key = Some(value), - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_secret_key().await) { + cfg.secret_key = Some(value) } if self.tls_ca_file.is_none() { - match dsn.retrieve_tls_ca_file().await { - Ok(Some(path)) => match read_certificates(&path).await { - Ok(pem) => cfg.pem_certificates = Some(pem), - Err(e) => errors.push(e), - }, - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(path) = errors.maybe(dsn.retrieve_tls_ca_file().await) { + if let Some(pem) = errors.check(read_certificates(&path).await) { + cfg.pem_certificates = Some(pem) + } } } else { dsn.ignore_value("tls_ca_file"); } - match dsn.retrieve_tls_security().await { - Ok(Some(value)) => cfg.tls_security = value, - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_tls_security().await) { + cfg.tls_security = value } - match dsn.retrieve_wait_until_available().await { - Ok(Some(value)) => cfg.wait = value, - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_wait_until_available().await) { + cfg.wait = value } cfg.extra_dsn_query_args = dsn.remaining_queries(); } async fn read_project(&self, cfg: &mut ConfigInner, errors: &mut Vec) -> bool { - let pair = self - ._get_stash_path() - .await - .map_err(|e| errors.push(e)) - .ok() - .flatten(); + let pair = errors.maybe(self._get_stash_path().await); if let Some((project, stash)) = pair { - self._read_project(cfg, &project, &stash) - .await - .map_err(|e| errors.push(e)) - .ok(); + errors.check(self._read_project(cfg, &project, &stash).await); true } else { false @@ -1574,12 +1453,10 @@ impl Builder { verifier: Arc::new(tls::NullVerifier), }; - cfg.cloud_profile = self.cloud_profile.clone().or_else(|| { - get_env("EDGEDB_CLOUD_PROFILE") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + cfg.cloud_profile = self + .cloud_profile + .clone() + .or_else(|| errors.maybe(Env::cloud_profile())); let complete = if self.host.is_some() || self.port.is_some() @@ -1593,53 +1470,32 @@ impl Builder { self.compound_owned(&mut cfg, &mut errors).await; self.granular_owned(&mut cfg, &mut errors).await; true - } else if COMPOUND_ENV_VARS.iter().any(|x| env::var_os(x).is_some()) || has_port_env() { - self.secret_key_env(&mut cfg, &mut errors).await; - self.compound_env(&mut cfg, &mut errors).await; - self.granular_env(&mut cfg, &mut errors).await; - true } else { self.secret_key_env(&mut cfg, &mut errors).await; - let complete = self.read_project(&mut cfg, &mut errors).await; + let complete = if self.compound_env(&mut cfg, &mut errors).await { + true + } else { + self.read_project(&mut cfg, &mut errors).await + }; self.granular_env(&mut cfg, &mut errors).await; complete }; - let security = get_env("EDGEDB_CLIENT_SECURITY") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .and_then(|x| { - x.parse::() - .map_err(|e| { - errors.push(e.context("EDGEDB_CLIENT_SECURITY error")); - }) - .ok() - }); + let security = errors.maybe(Env::client_security()); + if let Some(security) = security { cfg.client_security = security; } - let cloud_certs = get_env("_EDGEDB_CLOUD_CERTS") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .and_then(|x| { - x.parse::() - .map_err(|e| { - errors.push(e.context("_EDGEDB_CLOUD_CERTS error")); - }) - .ok() - }); + let cloud_certs = errors.maybe(Env::_cloud_certs()); if let Some(cloud_certs) = cloud_certs { cfg.cloud_certs = Some(cloud_certs); } // we don't overwrite this param in cfg because we want // `with_pem_certificates` to bump security to Strict - let tls_security = cfg - .compute_tls_security() - .map_err(|e| errors.push(e)) + let tls_security = errors + .check(cfg.compute_tls_security()) .unwrap_or(TlsSecurity::Strict); cfg.verifier = cfg.make_verifier(tls_security); @@ -2088,16 +1944,8 @@ impl ConfigInner { roots: webpki_roots::TLS_SERVER_ROOTS.into(), }; if let Some(certs) = self.cloud_certs { - let data = match certs { - // Staging certs retrieved from - // https://letsencrypt.org/docs/staging-environment/#root-certificates - CloudCerts::Staging => include_str!("letsencrypt_staging.pem"), - // Local nebula development root cert found in - // nebula/infra/terraform/local/ca/root.certificate.pem - CloudCerts::Local => include_str!("nebula_development.pem"), - }; root_store.extend( - tls::read_root_cert_pem(data) + tls::read_root_cert_pem(certs.root()) .expect("embedded certs are correct") .roots, ); diff --git a/edgedb-tokio/src/env.rs b/edgedb-tokio/src/env.rs new file mode 100644 index 00000000..7c3e7c9b --- /dev/null +++ b/edgedb-tokio/src/env.rs @@ -0,0 +1,226 @@ +use std::fmt::Debug; +use std::io; +use std::num::NonZeroU16; +use std::time::Duration; +use std::{env, path::PathBuf, str::FromStr}; + +use edgedb_protocol::model; +use url::Url; + +use crate::errors::{ClientError, Error, ErrorKind}; +use crate::{builder::CloudCerts, ClientSecurity, InstanceName, TlsSecurity}; + +#[cfg_attr(feature = "unstable", macro_export)] +macro_rules! define_env { + ($( + #[doc=$doc:expr] + #[env($($env_name:expr),+)] + $(#[preprocess=$preprocess:expr])? + $(#[parse=$parse:expr])? + $(#[validate=$validate:expr])? + $name:ident: $type:ty + ),* $(,)?) => { + #[derive(Debug, Clone)] + pub struct Env { + } + + impl Env { + $( + #[doc = $doc] + pub fn $name() -> ::std::result::Result<::std::option::Option<$type>, $crate::Error> { + const ENV_NAMES: &[&str] = &[$(stringify!($env_name)),+]; + let Some((name, s)) = $crate::env::get_envs(ENV_NAMES)? else { + return Ok(None); + }; + $(let Some(s) = $preprocess(s) else { + return Ok(None); + };)? + + // This construct lets us choose between $parse and std::str::FromStr + // without requiring all types to implement FromStr. + #[allow(unused_labels)] + let value: $type = 'block: { + $( + break 'block $parse(&name, &s)?; + + // Disable the fallback parser + #[cfg(all(debug_assertions, not(debug_assertions)))] + )? + $crate::env::parse(&name, &s)? + }; + + $($validate(name, &value)?;)? + Ok(Some(value)) + } + )* + } + }; +} + +define_env!( + /// The host to connect to. + #[env(GEL_HOST, EDGEDB_HOST)] + #[validate=validate_host] + host: String, + + /// The port to connect to. + #[env(GEL_PORT, EDGEDB_PORT)] + #[preprocess=ignore_docker_tcp_port] + port: NonZeroU16, + + /// The database name to connect to. + #[env(GEL_DATABASE, EDGEDB_DATABASE)] + #[validate=non_empty_string] + database: String, + + /// The branch name to connect to. + #[env(GEL_BRANCH, EDGEDB_BRANCH)] + #[validate=non_empty_string] + branch: String, + + /// The username to connect as. + #[env(GEL_USER, EDGEDB_USER)] + #[validate=non_empty_string] + user: String, + + /// The password to use for authentication. + #[env(GEL_PASSWORD, EDGEDB_PASSWORD)] + password: String, + + /// TLS server name to verify. + #[env(GEL_TLS_SERVER_NAME, EDGEDB_TLS_SERVER_NAME)] + tls_server_name: String, + + /// Path to credentials file. + #[env(GEL_CREDENTIALS_FILE, EDGEDB_CREDENTIALS_FILE)] + credentials_file: String, + + /// Instance name to connect to. + #[env(GEL_INSTANCE, EDGEDB_INSTANCE)] + instance: InstanceName, + + /// Connection DSN string. + #[env(GEL_DSN, EDGEDB_DSN)] + dsn: Url, + + /// Secret key for authentication. + #[env(GEL_SECRET_KEY, EDGEDB_SECRET_KEY)] + secret_key: String, + + /// Client security mode. + #[env(GEL_CLIENT_SECURITY, EDGEDB_CLIENT_SECURITY)] + client_security: ClientSecurity, + + /// TLS security mode. + #[env(GEL_CLIENT_TLS_SECURITY, EDGEDB_CLIENT_TLS_SECURITY)] + client_tls_security: TlsSecurity, + + /// Path to TLS CA certificate file. + #[env(GEL_TLS_CA, EDGEDB_TLS_CA)] + tls_ca: String, + + /// Path to TLS CA certificate file. + #[env(GEL_TLS_CA_FILE, EDGEDB_TLS_CA_FILE)] + tls_ca_file: PathBuf, + + /// Cloud profile name. + #[env(GEL_CLOUD_PROFILE, EDGEDB_CLOUD_PROFILE)] + cloud_profile: String, + + /// Cloud certificates mode. + #[env(_GEL_CLOUD_CERTS, _EDGEDB_CLOUD_CERTS)] + _cloud_certs: CloudCerts, + + /// How long to wait for server to become available. + #[env(GEL_WAIT_UNTIL_AVAILABLE, EDGEDB_WAIT_UNTIL_AVAILABLE)] + #[parse=parse_duration] + wait_until_available: Duration, +); + +fn ignore_docker_tcp_port(s: String) -> Option { + static PORT_WARN: std::sync::Once = std::sync::Once::new(); + + if s.starts_with("tcp://") { + PORT_WARN.call_once(|| { + eprintln!("GEL_PORT/EDGEDB_PORT is ignored when using Docker TCP port"); + }); + None + } else { + Some(s) + } +} + +fn non_empty_string(var: &str, s: &str) -> Result<(), Error> { + if s.is_empty() { + Err(create_var_error(var, "empty string")) + } else { + Ok(()) + } +} + +fn validate_host(var: &str, s: &str) -> Result<(), Error> { + if s.is_empty() { + return Err(create_var_error(var, "invalid host: empty string")); + } else if s.contains(',') { + return Err(create_var_error(var, "invalid host: multiple hosts")); + } + Ok(()) +} + +#[inline(never)] +#[doc(hidden)] +pub fn parse(var: &str, s: &str) -> Result +where + ::Err: Debug, +{ + Ok(s.parse().map_err(|e| create_var_error(var, e))?) +} + +#[inline(never)] +pub(crate) fn get_env(name: &str) -> Result, Error> { + let var = env::var(name); + match var { + Ok(v) if v.is_empty() => Ok(None), + Ok(v) => Ok(Some(v)), + Err(env::VarError::NotPresent) => Ok(None), + Err(e) => Err(create_var_error(name, e)), + } +} + +#[inline(never)] +#[doc(hidden)] +pub fn get_envs(names: &'static [&'static str]) -> Result, Error> { + let mut value = None; + let mut found_vars = Vec::new(); + + for name in names { + if let Some(val) = get_env(name)? { + found_vars.push(format!("{}={}", name, val)); + if value.is_none() { + value = Some((*name, val)); + } + } + } + + if found_vars.len() > 1 { + log::warn!( + "Multiple environment variables set: {}", + found_vars.join(", ") + ); + } + + Ok(value) +} + +fn parse_duration(var: &str, s: &str) -> Result { + let duration = model::Duration::from_str(s).map_err(|e| create_var_error(var, e))?; + + duration.try_into().map_err(|e| create_var_error(var, e)) +} + +fn create_var_error(var: &str, e: impl Debug) -> Error { + ClientError::with_source(io::Error::new( + io::ErrorKind::InvalidInput, + format!("{var} is invalid: {e:?}"), + )) +} diff --git a/edgedb-tokio/src/lib.rs b/edgedb-tokio/src/lib.rs index 9b88829a..3f75f8f8 100644 --- a/edgedb-tokio/src/lib.rs +++ b/edgedb-tokio/src/lib.rs @@ -111,25 +111,28 @@ reader. warn(missing_docs, missing_debug_implementations) )] -#[cfg(feature = "unstable")] -pub mod credentials; -#[cfg(feature = "unstable")] -pub mod raw; -#[cfg(feature = "unstable")] -pub mod server_params; -#[cfg(feature = "unstable")] -pub mod tls; - -#[cfg(not(feature = "unstable"))] -mod credentials; -#[cfg(not(feature = "unstable"))] -mod raw; -#[cfg(not(feature = "unstable"))] -mod server_params; -#[cfg(not(feature = "unstable"))] -mod tls; - -mod builder; +macro_rules! unstable_pub_mods { + ($(mod $mod_name:ident;)*) => { + $( + #[cfg(feature = "unstable")] + pub mod $mod_name; + #[cfg(not(feature = "unstable"))] + mod $mod_name; + )* + } +} + +// If the unstable feature is enabled, the modules will be public. +// If the unstable feature is not enabled, the modules will be private. +unstable_pub_mods! { + mod builder; + mod credentials; + mod raw; + mod server_params; + mod tls; + mod env; +} + mod client; mod errors; mod options;