Skip to content

Commit

Permalink
feat: add support for tls connection (#15)
Browse files Browse the repository at this point in the history
* feat: add support for tls connection

* test: update test case api
  • Loading branch information
sunng87 authored Feb 29, 2024
1 parent 35c09ba commit d21dbcf
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ rand = "0.8"
snafu = "0.7"
tokio = { version = "1", features = ["rt", "time"] }
tokio-stream = { version = "0.1", features = ["net"] }
tonic = { version = "0.10", features = ["tls"] }
tonic = { version = "0.10", features = ["tls", "tls-roots"] }
tower = "0.4"

[build-dependencies]
Expand Down
18 changes: 16 additions & 2 deletions examples/ingest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,30 @@ use derive_new::new;
use greptimedb_client::api::v1::*;
use greptimedb_client::helpers::schema::*;
use greptimedb_client::helpers::values::*;
use greptimedb_client::{Client, Database, DEFAULT_SCHEMA_NAME};
use greptimedb_client::{
ChannelConfig, ChannelManager, Client, ClientTlsOption, Database, DEFAULT_SCHEMA_NAME,
};

#[tokio::main]
async fn main() {
let greptimedb_endpoint =
std::env::var("GREPTIMEDB_ENDPOINT").unwrap_or_else(|_| "localhost:4001".to_owned());
let greptimedb_dbname =
std::env::var("GREPTIMEDB_DBNAME").unwrap_or_else(|_| DEFAULT_SCHEMA_NAME.to_owned());
let greptimedb_secure = std::env::var("GREPTIMEDB_TLS")
.map(|s| s == "1")
.unwrap_or(false);

let grpc_client = if greptimedb_secure {
let channel_config = ChannelConfig::default().client_tls_config(ClientTlsOption::default());

let channel_manager = ChannelManager::with_tls_config(channel_config)
.expect("Failed to create channel manager");
Client::with_manager_and_urls(channel_manager, vec![&greptimedb_endpoint])
} else {
Client::with_urls(vec![&greptimedb_endpoint])
};

let grpc_client = Client::with_urls(vec![&greptimedb_endpoint]);
let client = Database::new_with_dbname(greptimedb_dbname, grpc_client);

let records = weather_records();
Expand Down
62 changes: 37 additions & 25 deletions src/channel_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,23 @@ impl ChannelManager {
msg: "no config input",
})?;

let server_root_ca_cert = std::fs::read_to_string(path_config.server_ca_cert_path)
.context(InvalidConfigFilePathSnafu)?;
let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
let client_cert = std::fs::read_to_string(path_config.client_cert_path)
.context(InvalidConfigFilePathSnafu)?;
let client_key = std::fs::read_to_string(path_config.client_key_path)
.context(InvalidConfigFilePathSnafu)?;
let client_identity = Identity::from_pem(client_cert, client_key);

cm.client_tls_config = Some(
ClientTlsConfig::new()
.ca_certificate(server_root_ca_cert)
.identity(client_identity),
);
let mut client_tls_config = ClientTlsConfig::new();
if let Some(server_ca_cert_path) = path_config.server_ca_cert_path {
let ca_cert = Certificate::from_pem(
std::fs::read_to_string(server_ca_cert_path).context(InvalidConfigFilePathSnafu)?,
);
client_tls_config = client_tls_config.ca_certificate(ca_cert);
}

if let (Some(cert), Some(key)) = (path_config.client_cert_path, path_config.client_key_path)
{
let client_cert = std::fs::read_to_string(cert).context(InvalidConfigFilePathSnafu)?;
let client_key = std::fs::read_to_string(key).context(InvalidConfigFilePathSnafu)?;
let client_identity = Identity::from_pem(client_cert, client_key);
client_tls_config = client_tls_config.identity(client_identity);
}

cm.client_tls_config = Some(client_tls_config);

Ok(cm)
}
Expand Down Expand Up @@ -152,7 +155,13 @@ impl ChannelManager {
}

fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
let mut endpoint = Endpoint::new(format!("http://{addr}")).context(CreateChannelSnafu)?;
let scheme = if self.client_tls_config.is_some() {
"https"
} else {
"http"
};
let mut endpoint =
Endpoint::new(format!("{scheme}://{addr}")).context(CreateChannelSnafu)?;

if let Some(dur) = self.config.timeout {
endpoint = endpoint.timeout(dur);
Expand Down Expand Up @@ -198,11 +207,14 @@ impl ChannelManager {
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct ClientTlsOption {
pub server_ca_cert_path: PathBuf,
pub client_cert_path: PathBuf,
pub client_key_path: PathBuf,
/// Path to server CA file, use system CA when not configured
pub server_ca_cert_path: Option<PathBuf>,
/// the file path to client certificate
pub client_cert_path: Option<PathBuf>,
/// the file path to client private key
pub client_key_path: Option<PathBuf>,
}

#[derive(Clone, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -523,9 +535,9 @@ mod tests {
.tcp_keepalive(Duration::from_secs(2))
.tcp_nodelay(false)
.client_tls_config(ClientTlsOption {
server_ca_cert_path: "some_server_path".into(),
client_cert_path: "some_cert_path".into(),
client_key_path: "some_key_path".into(),
server_ca_cert_path: Some("some_server_path".into()),
client_cert_path: Some("some_cert_path".into()),
client_key_path: Some("some_key_path".into()),
});

assert_eq!(
Expand All @@ -543,9 +555,9 @@ mod tests {
tcp_keepalive: Some(Duration::from_secs(2)),
tcp_nodelay: false,
client_tls: Some(ClientTlsOption {
server_ca_cert_path: "some_server_path".into(),
client_cert_path: "some_cert_path".into(),
client_key_path: "some_key_path".into(),
server_ca_cert_path: Some("some_server_path".into()),
client_cert_path: Some("some_cert_path".into()),
client_key_path: Some("some_key_path".into()),
}),
},
cfg
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub mod helpers;
pub mod load_balance;
mod stream_insert;

pub use self::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption};
pub use self::client::Client;
pub use self::database::Database;
pub use self::error::{Error, Result};
Expand Down

0 comments on commit d21dbcf

Please sign in to comment.