diff --git a/edgedb-tokio/src/builder.rs b/edgedb-tokio/src/builder.rs index a0be61d6..5893111d 100644 --- a/edgedb-tokio/src/builder.rs +++ b/edgedb-tokio/src/builder.rs @@ -19,6 +19,7 @@ use tokio::fs; use edgedb_protocol::model; use crate::credentials::{Credentials, TlsSecurity}; +use crate::env::{get_env, Env}; use crate::errors::{ClientError, Error, ErrorKind, ResultExt}; use crate::errors::{ClientNoCredentialsError, NoCloudConfigFound}; use crate::errors::{InterfaceError, InvalidArgumentError}; @@ -30,20 +31,50 @@ pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(60); pub const DEFAULT_POOL_SIZE: usize = 10; pub const DEFAULT_HOST: &str = "localhost"; pub const DEFAULT_PORT: u16 = 5656; -pub const COMPOUND_ENV_VARS: &[&str] = &[ - "EDGEDB_HOST", - // "EDGEDB_PORT", // port check is special because of Docker - "EDGEDB_CREDENTIALS_FILE", - "EDGEDB_INSTANCE", - "EDGEDB_DSN", -]; const DOMAIN_LABEL_MAX_LENGTH: usize = 63; const CLOUD_INSTANCE_NAME_MAX_LENGTH: usize = DOMAIN_LABEL_MAX_LENGTH - 2 + 1; // "--" -> "/" -static PORT_WARN: std::sync::Once = std::sync::Once::new(); - type Verifier = Arc; +mod sealed { + use super::*; + + /// Helper trait to extract errors and redirect them to the Vec. + pub(super) trait ErrorBuilder { + /// Convert a Result, Error> to an Option. + /// If the result is an error, it is pushed to the Vec. + fn maybe(&mut self, res: Result, Error>) -> Option; + + /// Convert a Result to an Option. + /// If the result is an error, it is pushed to the Vec. + fn check(&mut self, res: Result) -> Option; + } + + impl ErrorBuilder for Vec { + fn maybe(&mut self, res: Result, Error>) -> Option { + match res { + Ok(v) => v, + Err(e) => { + self.push(e); + None + } + } + } + + fn check(&mut self, res: Result) -> Option { + match res { + Ok(v) => Some(v), + Err(e) => { + self.push(e); + None + } + } + } + } +} + +use sealed::ErrorBuilder; + /// Client security mode. #[derive(Default, Debug, Clone, Copy)] pub enum ClientSecurity { @@ -63,6 +94,19 @@ pub enum CloudCerts { Local, } +impl CloudCerts { + pub fn root(&self) -> &'static str { + match self { + // Staging certs retrieved from + // https://letsencrypt.org/docs/staging-environment/#root-certificates + CloudCerts::Staging => include_str!("letsencrypt_staging.pem"), + // Local nebula development root cert found in + // nebula/infra/terraform/local/ca/root.certificate.pem + CloudCerts::Local => include_str!("nebula_development.pem"), + } + } +} + /// TCP keepalive configuration. #[derive(Default, Debug, Clone, Copy)] pub enum TcpKeepalive { @@ -200,26 +244,6 @@ struct Claims { issuer: Option, } -fn get_env(name: &str) -> Result, Error> { - match env::var(name) { - Ok(v) if v.is_empty() => Ok(None), - Ok(v) => Ok(Some(v)), - Err(env::VarError::NotPresent) => Ok(None), - Err(e) => Err(ClientError::with_source(e) - .context(format!("Cannot decode environment variable {:?}", name))), - } -} - -fn has_port_env() -> bool { - if let Some(port) = env::var_os("EDGEDB_PORT") { - port.to_str() - .map(|s| !s.starts_with("tcp://")) - .unwrap_or(true) - } else { - false - } -} - #[cfg(unix)] fn path_bytes(path: &Path) -> &'_ [u8] { use std::os::unix::ffi::OsStrExt; @@ -968,10 +992,7 @@ impl Builder { let mut conflict = None; if let Some(instance) = &self.instance { conflict = Some("instance"); - read_instance(cfg, instance) - .await - .map_err(|e| errors.push(e)) - .ok(); + errors.check(read_instance(cfg, instance).await); } if let Some(dsn) = &self.dsn { if let Some(conflict) = conflict { @@ -991,10 +1012,7 @@ impl Builder { ))); } conflict = Some("credentials_file"); - read_credentials(cfg, credentials_file) - .await - .map_err(|e| errors.push(e)) - .ok(); + errors.check(read_credentials(cfg, credentials_file).await); } if let Some(credentials) = &self.credentials { if let Some(conflict) = conflict { @@ -1004,9 +1022,7 @@ impl Builder { ))); } conflict = Some("credentials"); - set_credentials(cfg, credentials) - .map_err(|e| errors.push(e)) - .ok(); + errors.check(set_credentials(cfg, credentials)); } if let Some(host) = &self.host { if let Some(conflict) = conflict { @@ -1077,9 +1093,8 @@ impl Builder { } if let Some(tls_ca_file) = &self.tls_ca_file { - match read_certificates(tls_ca_file).await { - Ok(pem) => cfg.pem_certificates = Some(pem), - Err(e) => errors.push(e), + if let Some(pem) = errors.check(read_certificates(tls_ca_file).await) { + cfg.pem_certificates = Some(pem) } } @@ -1096,127 +1111,75 @@ impl Builder { } } - async fn compound_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) { - // Due to how shared-test-cases are implemented we have to check for - // conflicts first and then do the actual parsing - let mut conflict = None; - let mut check_conflict = |var_name: &'static str| { - if env::var_os(var_name).is_some() { - if let Some(cvar) = conflict { - errors.push(ClientError::with_message(format!( - "{} conflicts with {}", - var_name, cvar - ))); - } - conflict = Some(var_name); - } - }; - check_conflict("EDGEDB_INSTANCE"); - check_conflict("EDGEDB_DSN"); - check_conflict("EDGEDB_CREDENTIALS_FILE"); - check_conflict("EDGEDB_HOST"); - if let Some(port) = env::var_os("EDGEDB_PORT") { - if !port - .to_str() - .map(|s| s.starts_with("tcp://")) - .unwrap_or(false) - { - if let Some(cvar) = conflict { - if cvar != "EDGEDB_HOST" { - errors.push(ClientError::with_message(format!( - "{} conflicts with {}", - "EDGEDB_PORT", cvar - ))); - } - } - } - // note: not setting conflict to work with HOST + async fn compound_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) -> bool { + let instance = Env::instance(); + let dsn = Env::dsn(); + let credentials_file = Env::credentials_file(); + let host = Env::host(); + let port = Env::port(); + + fn has(opt: &Result, Error>) -> bool { + opt.as_ref().map(|s| s.as_ref()).ok().flatten().is_some() } - let str_env = |var_name: &'static str, errors: &mut Vec| { - get_env(var_name).map_err(|e| errors.push(e)).ok().flatten() - }; - if let Some(instance) = str_env("EDGEDB_INSTANCE", errors) { - match instance.parse() { - Ok(instance) => { - read_instance(cfg, &instance) - .await - .map_err(|e| errors.push(e)) - .ok(); - } - Err(e) => { - errors.push(ClientError::with_source(e).context("EDGEDB_INSTANCE is invalid")); - } - } + let groups = [ + (has(&instance), "GEL_INSTANCE"), + (has(&dsn), "GEL_DSN"), + (has(&credentials_file), "GEL_CREDENTIALS_FILE"), + (has(&host) || has(&port), "GEL_HOST or GEL_PORT"), + ]; + + let has_envs = groups + .into_iter() + .filter_map(|(has, name)| if has { Some(name) } else { None }) + .collect::>(); + + if has_envs.len() > 1 { + errors.push(InvalidArgumentError::with_message(format!( + "environment variable {} conflicts with {}", + has_envs[0], + has_envs[1..].join(", "), + ))); } - if let Some(dsn) = str_env("EDGEDB_DSN", errors) { - match dsn.parse() { - Ok(url) => self.read_dsn(cfg, &url, errors).await, - Err(e) => { - errors.push(ClientError::with_source(e).context("EDGEDB_DSN is invalid")); - } - } + + if let Some(instance) = errors.maybe(instance) { + errors.check(read_instance(cfg, &instance).await); } - if let Some(fpath) = str_env("EDGEDB_CREDENTIALS_FILE", errors) { - read_credentials(cfg, fpath) - .await - .map_err(|e| errors.push(e)) - .ok(); + if let Some(dsn) = errors.maybe(dsn) { + self.read_dsn(cfg, &dsn, errors).await } - if let Some(host) = str_env("EDGEDB_HOST", errors) { - match validate_host(&host) { - Ok(_) => { - cfg.address = Address::Tcp((host, DEFAULT_PORT)); - } - Err(e) => errors.push(e.context("EDGEDB_HOST is invalid")), - } + if let Some(fpath) = errors.maybe(credentials_file) { + errors.check(read_credentials(cfg, fpath).await); } - if let Some(port_str) = str_env("EDGEDB_PORT", errors) { - let port = port_str - .parse() - .map_err(ClientError::with_source) - .and_then(validate_port) - .context("EDGEDB_PORT is invalid"); - match port { - Ok(port) => { - if let Address::Tcp((_, ref mut portref)) = &mut cfg.address { - *portref = port - } - } - Err(e) => { - if port_str.starts_with("tcp://") { - PORT_WARN.call_once(|| { - log::warn!( - "Environment variable `EDGEDB_PORT` \ - contains docker-link-like definition. \ - Ignoring..." - ); - }); - } else { - errors.push(e); - } - } + if let Some(host) = errors.maybe(host) { + cfg.address = Address::Tcp((host, DEFAULT_PORT)); + } + if let Some(port) = errors.maybe(port) { + if let Address::Tcp((_, ref mut portref)) = &mut cfg.address { + *portref = port.into(); } } + + // This code needs a total rework... + + // Because an incomplete configuration trumps errors, we return "complete" if + // there are errors, so those errors can be reported. + !has_envs.is_empty() || !errors.is_empty() } async fn secret_key_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) { - cfg.secret_key = self.secret_key.clone().or_else(|| { - get_env("EDGEDB_SECRET_KEY") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + cfg.secret_key = self + .secret_key + .clone() + .or_else(|| errors.maybe(Env::secret_key())); } async fn granular_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) { let database_branch = self.database.as_ref().or(self.branch.as_ref()) .cloned() .or_else(|| { - let database = get_env("EDGEDB_DATABASE") - .map_err(|e| errors.push(e)).ok()?; - let branch = get_env("EDGEDB_BRANCH") - .map_err(|e| errors.push(e)).ok()?; + let database = errors.maybe(Env::database()); + let branch = errors.maybe(Env::branch()); if database.is_some() && branch.is_some() { errors.push(InvalidArgumentError::with_message( @@ -1232,103 +1195,52 @@ impl Builder { cfg.branch = name; } - let user = self.user.clone().or_else(|| { - get_env("EDGEDB_USER") - .and_then(|v| v.map(validate_user).transpose()) - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + let user = self.user.clone().or_else(|| errors.maybe(Env::user())); if let Some(user) = user { cfg.user = user; } - let tls_server_name = self.tls_server_name.clone().or_else(|| { - get_env("EDGEDB_TLS_SERVER_NAME") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + let tls_server_name = self + .tls_server_name + .clone() + .or_else(|| errors.maybe(Env::tls_server_name())); if let Some(tls_server_name) = tls_server_name { cfg.tls_server_name = Some(tls_server_name); } - let password = self.password.clone().or_else(|| { - get_env("EDGEDB_PASSWORD") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + let password = self + .password + .clone() + .or_else(|| errors.maybe(Env::password())); if let Some(password) = password { cfg.password = Some(password); } - let tls_ca_file = self.tls_ca_file.clone().or_else(|| { - get_env("EDGEDB_TLS_CA_FILE") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .map(|p| p.into()) - }); + let tls_ca_file = self + .tls_ca_file + .clone() + .or_else(|| errors.maybe(Env::tls_ca_file())); if let Some(tls_ca_file) = tls_ca_file { - match read_certificates(tls_ca_file).await { - Ok(pem) => cfg.pem_certificates = Some(pem), - Err(e) => errors.push(e), + if let Some(pem) = errors.check(read_certificates(tls_ca_file).await) { + cfg.pem_certificates = Some(pem) } } - let tls_ca = get_env("EDGEDB_TLS_CA") - .map_err(|e| errors.push(e)) - .ok() - .flatten(); + let tls_ca = errors.maybe(Env::tls_ca()); if let Some(pem) = tls_ca { - match validate_certs(&pem) { - Ok(()) => cfg.pem_certificates = Some(pem), - Err(e) => errors.push(e), + if let Some(()) = errors.check(validate_certs(&pem)) { + cfg.pem_certificates = Some(pem) } } - let security = get_env("EDGEDB_CLIENT_TLS_SECURITY") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .and_then(|x| { - x.parse::() - .map_err(|e| { - errors.push(e.context("EDGEDB_CLIENT_TLS_SECURITY error")); - }) - .ok() - }); + let security = errors.maybe(Env::client_tls_security()); if let Some(security) = security { cfg.tls_security = security; } - let wait = self.wait_until_available.or_else(|| { - get_env("EDGEDB_WAIT_UNTIL_AVAILABLE") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .and_then(|x| { - x.parse::() - .map_err(|e| { - errors.push( - ClientError::with_source(e) - .context("EDGEDB_WAIT_UNTIL_AVAILABLE error"), - ); - }) - .ok() - }) - .and_then(|x| { - x.try_into() - .map_err(|e| { - errors.push( - ClientError::with_source(e) - .context("EDGEDB_WAIT_UNTIL_AVAILABLE error"), - ); - }) - .ok() - }) - }); + let wait = self + .wait_until_available + .or_else(|| errors.maybe(Env::wait_until_available())); if let Some(wait) = wait { cfg.wait = wait; } @@ -1342,37 +1254,23 @@ impl Builder { return; } }; - let host = dsn - .retrieve_host() - .await - .map_err(|e| errors.push(e)) - .ok() - .flatten() + let host = errors + .maybe(dsn.retrieve_host().await) .unwrap_or_else(|| DEFAULT_HOST.into()); - let port = dsn - .retrieve_port() - .await - .map_err(|e| errors.push(e)) - .ok() - .flatten() + let port = errors + .maybe(dsn.retrieve_port().await) .unwrap_or(DEFAULT_PORT); - match dsn.retrieve_tls_server_name().await { - Ok(Some(value)) => cfg.tls_server_name = Some(value), - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_tls_server_name().await) { + cfg.tls_server_name = Some(value) } cfg.address = Address::Tcp((host, port)); cfg.admin = dsn.admin; - match dsn.retrieve_user().await { - Ok(Some(value)) => cfg.user = value, - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_user().await) { + cfg.user = value } if self.password.is_none() { - match dsn.retrieve_password().await { - Ok(Some(value)) => cfg.password = Some(value), - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_password().await) { + cfg.password = Some(value) } } else { dsn.ignore_value("password"); @@ -1396,62 +1294,43 @@ impl Builder { dsn.retrieve_branch().await }; - match database_or_branch { - Ok(Some(name)) => { + if let Some(name) = errors.maybe(database_or_branch) { + { cfg.branch.clone_from(&name); cfg.database = name; } - Ok(None) => {} - Err(e) => errors.push(e), } } else { dsn.ignore_value("branch"); dsn.ignore_value("database"); } - match dsn.retrieve_secret_key().await { - Ok(Some(value)) => cfg.secret_key = Some(value), - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_secret_key().await) { + cfg.secret_key = Some(value) } if self.tls_ca_file.is_none() { - match dsn.retrieve_tls_ca_file().await { - Ok(Some(path)) => match read_certificates(&path).await { - Ok(pem) => cfg.pem_certificates = Some(pem), - Err(e) => errors.push(e), - }, - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(path) = errors.maybe(dsn.retrieve_tls_ca_file().await) { + if let Some(pem) = errors.check(read_certificates(&path).await) { + cfg.pem_certificates = Some(pem) + } } } else { dsn.ignore_value("tls_ca_file"); } - match dsn.retrieve_tls_security().await { - Ok(Some(value)) => cfg.tls_security = value, - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_tls_security().await) { + cfg.tls_security = value } - match dsn.retrieve_wait_until_available().await { - Ok(Some(value)) => cfg.wait = value, - Ok(None) => {} - Err(e) => errors.push(e), + if let Some(value) = errors.maybe(dsn.retrieve_wait_until_available().await) { + cfg.wait = value } cfg.extra_dsn_query_args = dsn.remaining_queries(); } async fn read_project(&self, cfg: &mut ConfigInner, errors: &mut Vec) -> bool { - let pair = self - ._get_stash_path() - .await - .map_err(|e| errors.push(e)) - .ok() - .flatten(); + let pair = errors.maybe(self._get_stash_path().await); if let Some((project, stash)) = pair { - self._read_project(cfg, &project, &stash) - .await - .map_err(|e| errors.push(e)) - .ok(); + errors.check(self._read_project(cfg, &project, &stash).await); true } else { false @@ -1574,12 +1453,10 @@ impl Builder { verifier: Arc::new(tls::NullVerifier), }; - cfg.cloud_profile = self.cloud_profile.clone().or_else(|| { - get_env("EDGEDB_CLOUD_PROFILE") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - }); + cfg.cloud_profile = self + .cloud_profile + .clone() + .or_else(|| errors.maybe(Env::cloud_profile())); let complete = if self.host.is_some() || self.port.is_some() @@ -1593,53 +1470,32 @@ impl Builder { self.compound_owned(&mut cfg, &mut errors).await; self.granular_owned(&mut cfg, &mut errors).await; true - } else if COMPOUND_ENV_VARS.iter().any(|x| env::var_os(x).is_some()) || has_port_env() { - self.secret_key_env(&mut cfg, &mut errors).await; - self.compound_env(&mut cfg, &mut errors).await; - self.granular_env(&mut cfg, &mut errors).await; - true } else { self.secret_key_env(&mut cfg, &mut errors).await; - let complete = self.read_project(&mut cfg, &mut errors).await; + let complete = if self.compound_env(&mut cfg, &mut errors).await { + true + } else { + self.read_project(&mut cfg, &mut errors).await + }; self.granular_env(&mut cfg, &mut errors).await; complete }; - let security = get_env("EDGEDB_CLIENT_SECURITY") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .and_then(|x| { - x.parse::() - .map_err(|e| { - errors.push(e.context("EDGEDB_CLIENT_SECURITY error")); - }) - .ok() - }); + let security = errors.maybe(Env::client_security()); + if let Some(security) = security { cfg.client_security = security; } - let cloud_certs = get_env("_EDGEDB_CLOUD_CERTS") - .map_err(|e| errors.push(e)) - .ok() - .flatten() - .and_then(|x| { - x.parse::() - .map_err(|e| { - errors.push(e.context("_EDGEDB_CLOUD_CERTS error")); - }) - .ok() - }); + let cloud_certs = errors.maybe(Env::_cloud_certs()); if let Some(cloud_certs) = cloud_certs { cfg.cloud_certs = Some(cloud_certs); } // we don't overwrite this param in cfg because we want // `with_pem_certificates` to bump security to Strict - let tls_security = cfg - .compute_tls_security() - .map_err(|e| errors.push(e)) + let tls_security = errors + .check(cfg.compute_tls_security()) .unwrap_or(TlsSecurity::Strict); cfg.verifier = cfg.make_verifier(tls_security); @@ -2088,16 +1944,8 @@ impl ConfigInner { roots: webpki_roots::TLS_SERVER_ROOTS.into(), }; if let Some(certs) = self.cloud_certs { - let data = match certs { - // Staging certs retrieved from - // https://letsencrypt.org/docs/staging-environment/#root-certificates - CloudCerts::Staging => include_str!("letsencrypt_staging.pem"), - // Local nebula development root cert found in - // nebula/infra/terraform/local/ca/root.certificate.pem - CloudCerts::Local => include_str!("nebula_development.pem"), - }; root_store.extend( - tls::read_root_cert_pem(data) + tls::read_root_cert_pem(certs.root()) .expect("embedded certs are correct") .roots, ); diff --git a/edgedb-tokio/src/env.rs b/edgedb-tokio/src/env.rs new file mode 100644 index 00000000..7c3e7c9b --- /dev/null +++ b/edgedb-tokio/src/env.rs @@ -0,0 +1,226 @@ +use std::fmt::Debug; +use std::io; +use std::num::NonZeroU16; +use std::time::Duration; +use std::{env, path::PathBuf, str::FromStr}; + +use edgedb_protocol::model; +use url::Url; + +use crate::errors::{ClientError, Error, ErrorKind}; +use crate::{builder::CloudCerts, ClientSecurity, InstanceName, TlsSecurity}; + +#[cfg_attr(feature = "unstable", macro_export)] +macro_rules! define_env { + ($( + #[doc=$doc:expr] + #[env($($env_name:expr),+)] + $(#[preprocess=$preprocess:expr])? + $(#[parse=$parse:expr])? + $(#[validate=$validate:expr])? + $name:ident: $type:ty + ),* $(,)?) => { + #[derive(Debug, Clone)] + pub struct Env { + } + + impl Env { + $( + #[doc = $doc] + pub fn $name() -> ::std::result::Result<::std::option::Option<$type>, $crate::Error> { + const ENV_NAMES: &[&str] = &[$(stringify!($env_name)),+]; + let Some((name, s)) = $crate::env::get_envs(ENV_NAMES)? else { + return Ok(None); + }; + $(let Some(s) = $preprocess(s) else { + return Ok(None); + };)? + + // This construct lets us choose between $parse and std::str::FromStr + // without requiring all types to implement FromStr. + #[allow(unused_labels)] + let value: $type = 'block: { + $( + break 'block $parse(&name, &s)?; + + // Disable the fallback parser + #[cfg(all(debug_assertions, not(debug_assertions)))] + )? + $crate::env::parse(&name, &s)? + }; + + $($validate(name, &value)?;)? + Ok(Some(value)) + } + )* + } + }; +} + +define_env!( + /// The host to connect to. + #[env(GEL_HOST, EDGEDB_HOST)] + #[validate=validate_host] + host: String, + + /// The port to connect to. + #[env(GEL_PORT, EDGEDB_PORT)] + #[preprocess=ignore_docker_tcp_port] + port: NonZeroU16, + + /// The database name to connect to. + #[env(GEL_DATABASE, EDGEDB_DATABASE)] + #[validate=non_empty_string] + database: String, + + /// The branch name to connect to. + #[env(GEL_BRANCH, EDGEDB_BRANCH)] + #[validate=non_empty_string] + branch: String, + + /// The username to connect as. + #[env(GEL_USER, EDGEDB_USER)] + #[validate=non_empty_string] + user: String, + + /// The password to use for authentication. + #[env(GEL_PASSWORD, EDGEDB_PASSWORD)] + password: String, + + /// TLS server name to verify. + #[env(GEL_TLS_SERVER_NAME, EDGEDB_TLS_SERVER_NAME)] + tls_server_name: String, + + /// Path to credentials file. + #[env(GEL_CREDENTIALS_FILE, EDGEDB_CREDENTIALS_FILE)] + credentials_file: String, + + /// Instance name to connect to. + #[env(GEL_INSTANCE, EDGEDB_INSTANCE)] + instance: InstanceName, + + /// Connection DSN string. + #[env(GEL_DSN, EDGEDB_DSN)] + dsn: Url, + + /// Secret key for authentication. + #[env(GEL_SECRET_KEY, EDGEDB_SECRET_KEY)] + secret_key: String, + + /// Client security mode. + #[env(GEL_CLIENT_SECURITY, EDGEDB_CLIENT_SECURITY)] + client_security: ClientSecurity, + + /// TLS security mode. + #[env(GEL_CLIENT_TLS_SECURITY, EDGEDB_CLIENT_TLS_SECURITY)] + client_tls_security: TlsSecurity, + + /// Path to TLS CA certificate file. + #[env(GEL_TLS_CA, EDGEDB_TLS_CA)] + tls_ca: String, + + /// Path to TLS CA certificate file. + #[env(GEL_TLS_CA_FILE, EDGEDB_TLS_CA_FILE)] + tls_ca_file: PathBuf, + + /// Cloud profile name. + #[env(GEL_CLOUD_PROFILE, EDGEDB_CLOUD_PROFILE)] + cloud_profile: String, + + /// Cloud certificates mode. + #[env(_GEL_CLOUD_CERTS, _EDGEDB_CLOUD_CERTS)] + _cloud_certs: CloudCerts, + + /// How long to wait for server to become available. + #[env(GEL_WAIT_UNTIL_AVAILABLE, EDGEDB_WAIT_UNTIL_AVAILABLE)] + #[parse=parse_duration] + wait_until_available: Duration, +); + +fn ignore_docker_tcp_port(s: String) -> Option { + static PORT_WARN: std::sync::Once = std::sync::Once::new(); + + if s.starts_with("tcp://") { + PORT_WARN.call_once(|| { + eprintln!("GEL_PORT/EDGEDB_PORT is ignored when using Docker TCP port"); + }); + None + } else { + Some(s) + } +} + +fn non_empty_string(var: &str, s: &str) -> Result<(), Error> { + if s.is_empty() { + Err(create_var_error(var, "empty string")) + } else { + Ok(()) + } +} + +fn validate_host(var: &str, s: &str) -> Result<(), Error> { + if s.is_empty() { + return Err(create_var_error(var, "invalid host: empty string")); + } else if s.contains(',') { + return Err(create_var_error(var, "invalid host: multiple hosts")); + } + Ok(()) +} + +#[inline(never)] +#[doc(hidden)] +pub fn parse(var: &str, s: &str) -> Result +where + ::Err: Debug, +{ + Ok(s.parse().map_err(|e| create_var_error(var, e))?) +} + +#[inline(never)] +pub(crate) fn get_env(name: &str) -> Result, Error> { + let var = env::var(name); + match var { + Ok(v) if v.is_empty() => Ok(None), + Ok(v) => Ok(Some(v)), + Err(env::VarError::NotPresent) => Ok(None), + Err(e) => Err(create_var_error(name, e)), + } +} + +#[inline(never)] +#[doc(hidden)] +pub fn get_envs(names: &'static [&'static str]) -> Result, Error> { + let mut value = None; + let mut found_vars = Vec::new(); + + for name in names { + if let Some(val) = get_env(name)? { + found_vars.push(format!("{}={}", name, val)); + if value.is_none() { + value = Some((*name, val)); + } + } + } + + if found_vars.len() > 1 { + log::warn!( + "Multiple environment variables set: {}", + found_vars.join(", ") + ); + } + + Ok(value) +} + +fn parse_duration(var: &str, s: &str) -> Result { + let duration = model::Duration::from_str(s).map_err(|e| create_var_error(var, e))?; + + duration.try_into().map_err(|e| create_var_error(var, e)) +} + +fn create_var_error(var: &str, e: impl Debug) -> Error { + ClientError::with_source(io::Error::new( + io::ErrorKind::InvalidInput, + format!("{var} is invalid: {e:?}"), + )) +} diff --git a/edgedb-tokio/src/lib.rs b/edgedb-tokio/src/lib.rs index 9b88829a..3f75f8f8 100644 --- a/edgedb-tokio/src/lib.rs +++ b/edgedb-tokio/src/lib.rs @@ -111,25 +111,28 @@ reader. warn(missing_docs, missing_debug_implementations) )] -#[cfg(feature = "unstable")] -pub mod credentials; -#[cfg(feature = "unstable")] -pub mod raw; -#[cfg(feature = "unstable")] -pub mod server_params; -#[cfg(feature = "unstable")] -pub mod tls; - -#[cfg(not(feature = "unstable"))] -mod credentials; -#[cfg(not(feature = "unstable"))] -mod raw; -#[cfg(not(feature = "unstable"))] -mod server_params; -#[cfg(not(feature = "unstable"))] -mod tls; - -mod builder; +macro_rules! unstable_pub_mods { + ($(mod $mod_name:ident;)*) => { + $( + #[cfg(feature = "unstable")] + pub mod $mod_name; + #[cfg(not(feature = "unstable"))] + mod $mod_name; + )* + } +} + +// If the unstable feature is enabled, the modules will be public. +// If the unstable feature is not enabled, the modules will be private. +unstable_pub_mods! { + mod builder; + mod credentials; + mod raw; + mod server_params; + mod tls; + mod env; +} + mod client; mod errors; mod options;