diff --git a/edgedb-tokio/src/builder.rs b/edgedb-tokio/src/builder.rs index 4a1e0260..9e7ec62d 100644 --- a/edgedb-tokio/src/builder.rs +++ b/edgedb-tokio/src/builder.rs @@ -75,6 +75,7 @@ pub struct Builder { password: Option, tls_ca_file: Option, tls_security: Option, + tls_server_name: Option, client_security: Option, pem_certificates: Option, wait_until_available: Option, @@ -120,6 +121,8 @@ pub(crate) struct ConfigInner { // Pool configuration pub max_concurrency: Option, + pub tls_server_name: Option, + instance_name: Option, tls_security: TlsSecurity, client_security: ClientSecurity, @@ -492,6 +495,10 @@ impl<'a> DsnHelper<'a> { } } + async fn retrieve_tls_server_name(&mut self) -> Result, Error> { + self.retrieve_value("tls_server_name", None, |s| Ok(s)).await + } + async fn retrieve_port(&mut self) -> Result, Error> { self.retrieve_value("port", self.url.port(), |s| { s.parse().map_err(|e| { @@ -623,6 +630,13 @@ impl Builder { Ok(self) } + /// Override server name indication (SNI) in TLS handshake + pub fn tls_server_name(&mut self, tls_server_name: &str) -> Result<&mut Self, Error> { + validate_host(tls_server_name)?; + self.tls_server_name = Some(tls_server_name.to_string()); + Ok(self) + } + /// Set port to connect to pub fn port(&mut self, port: u16) -> Result<&mut Self, Error> { validate_port(port)?; @@ -790,6 +804,7 @@ impl Builder { let creds = self.credentials.as_ref(); let mut cfg = ConfigInner { address, + tls_server_name: self.tls_server_name.clone(), admin: self.admin, user: self.user.clone() .or_else(|| creds.map(|c| c.user.clone())) @@ -936,6 +951,10 @@ impl Builder { cfg.password = Some(password.clone()); } + if let Some(tls_server_name) = &self.tls_server_name { + cfg.tls_server_name = Some(tls_server_name.clone()); + } + if let Some(tls_ca_file) = &self.tls_ca_file { match read_certificates(tls_ca_file).await { Ok(pem) => cfg.pem_certificates = Some(pem), @@ -1083,6 +1102,14 @@ impl Builder { cfg.user = user; } + let tls_server_name = self.tls_server_name.clone().or_else(|| { + get_env("EDGEDB_TLS_SERVER_NAME") + .map_err(|e| errors.push(e)).ok().flatten() + }); + if let Some(tls_server_name) = tls_server_name { + cfg.tls_server_name = Some(tls_server_name); + } + let password = self.password.clone().or_else(|| { get_env("EDGEDB_PASSWORD") .map_err(|e| errors.push(e)).ok().flatten() @@ -1154,6 +1181,11 @@ impl Builder { let port = dsn.retrieve_port().await .map_err(|e| errors.push(e)).ok().flatten() .unwrap_or(DEFAULT_PORT); + match dsn.retrieve_tls_server_name().await { + Ok(Some(value)) => cfg.tls_server_name = Some(value), + Ok(None) => {}, + Err(e) => errors.push(e), + } cfg.address = Address::Tcp((host, port)); cfg.admin = dsn.admin; match dsn.retrieve_user().await { @@ -1311,6 +1343,7 @@ impl Builder { let mut cfg = ConfigInner { address: Address::Tcp((DEFAULT_HOST.into(), DEFAULT_PORT)), + tls_server_name: self.tls_server_name.clone(), admin: self.admin, user: "edgedb".into(), password: None, @@ -1646,6 +1679,7 @@ impl Config { "secretKey": self.0.secret_key, "tlsCAData": self.0.pem_certificates, "tlsSecurity": self.0.compute_tls_security().unwrap(), + "tlsServerName": self.0.tls_server_name, "serverSettings": self.0.extra_dsn_query_args, "waitUntilAvailable": self.0.wait.as_micros() as i64, }).to_string() diff --git a/edgedb-tokio/src/raw/connection.rs b/edgedb-tokio/src/raw/connection.rs index 122f1c79..7f39b5da 100644 --- a/edgedb-tokio/src/raw/connection.rs +++ b/edgedb-tokio/src/raw/connection.rs @@ -335,21 +335,27 @@ async fn connect3(cfg: &Config, tls: &TlsConnectorBox) Address::Tcp(addr@(host,_)) => { let conn = TcpStream::connect(addr).await .map_err(ClientConnectionError::with_source)?; - let is_valid_dns = DnsName::try_from(host.clone()).is_ok(); - let host = if !is_valid_dns { - // FIXME: https://github.com/rustls/rustls/issues/184 - // If self.host is neither an IP address nor a valid DNS - // name, the hacks below won't make it valid anyways. - let host = format!("{}.host-for-ip.edgedb.net", host); - // for ipv6addr - let host = host.replace(":", "-").replace("%", "-"); - if host.starts_with("-") { - Cow::from(format!("i{}", host)) - } else { - Cow::from(host) - } - } else { - Cow::from(host) + let host = match &cfg.0.tls_server_name { + Some(server_name) => { + Cow::from(server_name) + }, + None => { + if !DnsName::try_from(host.clone()).is_ok() { + // FIXME: https://github.com/rustls/rustls/issues/184 + // If self.host is neither an IP address nor a valid DNS + // name, the hacks below won't make it valid anyways. + let host = format!("{}.host-for-ip.edgedb.net", host); + // for ipv6addr + let host = host.replace(":", "-").replace("%", "-"); + if host.starts_with("-") { + Cow::from(format!("i{}", host)) + } else { + Cow::from(host) + } + } else { + Cow::from(host) + } + }, }; Ok(tls.connect(&host[..], conn).await.map_err(tls_fail)?) }