Skip to content

Commit

Permalink
Allow specifying SNI separately from host address (#300)
Browse files Browse the repository at this point in the history
When the target instance address cannot be resolved correctly
from a hostname, but the SNI is still desirable for TLS verification
and/or tenant selection reasons, provide a way to specify the SNI value
via the `tls_server_name` (`EDGEDB_TLS_SERVER_NAME`) connection parameter.
  • Loading branch information
elprans authored Mar 14, 2024
1 parent 3e08a28 commit 2c5ba18
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 15 deletions.
34 changes: 34 additions & 0 deletions edgedb-tokio/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub struct Builder {
password: Option<String>,
tls_ca_file: Option<PathBuf>,
tls_security: Option<TlsSecurity>,
tls_server_name: Option<String>,
client_security: Option<ClientSecurity>,
pem_certificates: Option<String>,
wait_until_available: Option<Duration>,
Expand Down Expand Up @@ -120,6 +121,8 @@ pub(crate) struct ConfigInner {
// Pool configuration
pub max_concurrency: Option<usize>,

pub tls_server_name: Option<String>,

instance_name: Option<InstanceName>,
tls_security: TlsSecurity,
client_security: ClientSecurity,
Expand Down Expand Up @@ -492,6 +495,10 @@ impl<'a> DsnHelper<'a> {
}
}

async fn retrieve_tls_server_name(&mut self) -> Result<Option<String>, Error> {
self.retrieve_value("tls_server_name", None, |s| Ok(s)).await
}

async fn retrieve_port(&mut self) -> Result<Option<u16>, Error> {
self.retrieve_value("port", self.url.port(), |s| {
s.parse().map_err(|e| {
Expand Down Expand Up @@ -623,6 +630,13 @@ impl Builder {
Ok(self)
}

/// Override server name indication (SNI) in TLS handshake
pub fn tls_server_name(&mut self, tls_server_name: &str) -> Result<&mut Self, Error> {
validate_host(tls_server_name)?;
self.tls_server_name = Some(tls_server_name.to_string());
Ok(self)
}

/// Set port to connect to
pub fn port(&mut self, port: u16) -> Result<&mut Self, Error> {
validate_port(port)?;
Expand Down Expand Up @@ -790,6 +804,7 @@ impl Builder {
let creds = self.credentials.as_ref();
let mut cfg = ConfigInner {
address,
tls_server_name: self.tls_server_name.clone(),
admin: self.admin,
user: self.user.clone()
.or_else(|| creds.map(|c| c.user.clone()))
Expand Down Expand Up @@ -936,6 +951,10 @@ impl Builder {
cfg.password = Some(password.clone());
}

if let Some(tls_server_name) = &self.tls_server_name {
cfg.tls_server_name = Some(tls_server_name.clone());
}

if let Some(tls_ca_file) = &self.tls_ca_file {
match read_certificates(tls_ca_file).await {
Ok(pem) => cfg.pem_certificates = Some(pem),
Expand Down Expand Up @@ -1083,6 +1102,14 @@ impl Builder {
cfg.user = user;
}

let tls_server_name = self.tls_server_name.clone().or_else(|| {
get_env("EDGEDB_TLS_SERVER_NAME")
.map_err(|e| errors.push(e)).ok().flatten()
});
if let Some(tls_server_name) = tls_server_name {
cfg.tls_server_name = Some(tls_server_name);
}

let password = self.password.clone().or_else(|| {
get_env("EDGEDB_PASSWORD")
.map_err(|e| errors.push(e)).ok().flatten()
Expand Down Expand Up @@ -1154,6 +1181,11 @@ impl Builder {
let port = dsn.retrieve_port().await
.map_err(|e| errors.push(e)).ok().flatten()
.unwrap_or(DEFAULT_PORT);
match dsn.retrieve_tls_server_name().await {
Ok(Some(value)) => cfg.tls_server_name = Some(value),
Ok(None) => {},
Err(e) => errors.push(e),
}
cfg.address = Address::Tcp((host, port));
cfg.admin = dsn.admin;
match dsn.retrieve_user().await {
Expand Down Expand Up @@ -1311,6 +1343,7 @@ impl Builder {

let mut cfg = ConfigInner {
address: Address::Tcp((DEFAULT_HOST.into(), DEFAULT_PORT)),
tls_server_name: self.tls_server_name.clone(),
admin: self.admin,
user: "edgedb".into(),
password: None,
Expand Down Expand Up @@ -1646,6 +1679,7 @@ impl Config {
"secretKey": self.0.secret_key,
"tlsCAData": self.0.pem_certificates,
"tlsSecurity": self.0.compute_tls_security().unwrap(),
"tlsServerName": self.0.tls_server_name,
"serverSettings": self.0.extra_dsn_query_args,
"waitUntilAvailable": self.0.wait.as_micros() as i64,
}).to_string()
Expand Down
36 changes: 21 additions & 15 deletions edgedb-tokio/src/raw/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,21 +335,27 @@ async fn connect3(cfg: &Config, tls: &TlsConnectorBox)
Address::Tcp(addr@(host,_)) => {
let conn = TcpStream::connect(addr).await
.map_err(ClientConnectionError::with_source)?;
let is_valid_dns = DnsName::try_from(host.clone()).is_ok();
let host = if !is_valid_dns {
// FIXME: https://github.com/rustls/rustls/issues/184
// If self.host is neither an IP address nor a valid DNS
// name, the hacks below won't make it valid anyways.
let host = format!("{}.host-for-ip.edgedb.net", host);
// for ipv6addr
let host = host.replace(":", "-").replace("%", "-");
if host.starts_with("-") {
Cow::from(format!("i{}", host))
} else {
Cow::from(host)
}
} else {
Cow::from(host)
let host = match &cfg.0.tls_server_name {
Some(server_name) => {
Cow::from(server_name)
},
None => {
if !DnsName::try_from(host.clone()).is_ok() {
// FIXME: https://github.com/rustls/rustls/issues/184
// If self.host is neither an IP address nor a valid DNS
// name, the hacks below won't make it valid anyways.
let host = format!("{}.host-for-ip.edgedb.net", host);
// for ipv6addr
let host = host.replace(":", "-").replace("%", "-");
if host.starts_with("-") {
Cow::from(format!("i{}", host))
} else {
Cow::from(host)
}
} else {
Cow::from(host)
}
},
};
Ok(tls.connect(&host[..], conn).await.map_err(tls_fail)?)
}
Expand Down

0 comments on commit 2c5ba18

Please sign in to comment.