Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for tls connection #15

Merged
merged 2 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ rand = "0.8"
snafu = "0.7"
tokio = { version = "1", features = ["rt", "time"] }
tokio-stream = { version = "0.1", features = ["net"] }
tonic = { version = "0.10", features = ["tls"] }
tonic = { version = "0.10", features = ["tls", "tls-roots"] }
tower = "0.4"

[build-dependencies]
Expand Down
18 changes: 16 additions & 2 deletions examples/ingest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,30 @@ use derive_new::new;
use greptimedb_client::api::v1::*;
use greptimedb_client::helpers::schema::*;
use greptimedb_client::helpers::values::*;
use greptimedb_client::{Client, Database, DEFAULT_SCHEMA_NAME};
use greptimedb_client::{
ChannelConfig, ChannelManager, Client, ClientTlsOption, Database, DEFAULT_SCHEMA_NAME,
};

#[tokio::main]
async fn main() {
let greptimedb_endpoint =
std::env::var("GREPTIMEDB_ENDPOINT").unwrap_or_else(|_| "localhost:4001".to_owned());
let greptimedb_dbname =
std::env::var("GREPTIMEDB_DBNAME").unwrap_or_else(|_| DEFAULT_SCHEMA_NAME.to_owned());
let greptimedb_secure = std::env::var("GREPTIMEDB_TLS")
.map(|s| s == "1")
.unwrap_or(false);

let grpc_client = if greptimedb_secure {
let channel_config = ChannelConfig::default().client_tls_config(ClientTlsOption::default());

let channel_manager = ChannelManager::with_tls_config(channel_config)
.expect("Failed to create channel manager");
Client::with_manager_and_urls(channel_manager, vec![&greptimedb_endpoint])
} else {
Client::with_urls(vec![&greptimedb_endpoint])
};

let grpc_client = Client::with_urls(vec![&greptimedb_endpoint]);
let client = Database::new_with_dbname(greptimedb_dbname, grpc_client);

let records = weather_records();
Expand Down
62 changes: 37 additions & 25 deletions src/channel_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,23 @@ impl ChannelManager {
msg: "no config input",
})?;

let server_root_ca_cert = std::fs::read_to_string(path_config.server_ca_cert_path)
.context(InvalidConfigFilePathSnafu)?;
let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
let client_cert = std::fs::read_to_string(path_config.client_cert_path)
.context(InvalidConfigFilePathSnafu)?;
let client_key = std::fs::read_to_string(path_config.client_key_path)
.context(InvalidConfigFilePathSnafu)?;
let client_identity = Identity::from_pem(client_cert, client_key);

cm.client_tls_config = Some(
ClientTlsConfig::new()
.ca_certificate(server_root_ca_cert)
.identity(client_identity),
);
let mut client_tls_config = ClientTlsConfig::new();
if let Some(server_ca_cert_path) = path_config.server_ca_cert_path {
let ca_cert = Certificate::from_pem(
std::fs::read_to_string(server_ca_cert_path).context(InvalidConfigFilePathSnafu)?,
);
client_tls_config = client_tls_config.ca_certificate(ca_cert);
}

if let (Some(cert), Some(key)) = (path_config.client_cert_path, path_config.client_key_path)
{
let client_cert = std::fs::read_to_string(cert).context(InvalidConfigFilePathSnafu)?;
let client_key = std::fs::read_to_string(key).context(InvalidConfigFilePathSnafu)?;
let client_identity = Identity::from_pem(client_cert, client_key);
client_tls_config = client_tls_config.identity(client_identity);
}

cm.client_tls_config = Some(client_tls_config);

Ok(cm)
}
Expand Down Expand Up @@ -152,7 +155,13 @@ impl ChannelManager {
}

fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
let mut endpoint = Endpoint::new(format!("http://{addr}")).context(CreateChannelSnafu)?;
let scheme = if self.client_tls_config.is_some() {
"https"
} else {
"http"
};
let mut endpoint =
Endpoint::new(format!("{scheme}://{addr}")).context(CreateChannelSnafu)?;

if let Some(dur) = self.config.timeout {
endpoint = endpoint.timeout(dur);
Expand Down Expand Up @@ -198,11 +207,14 @@ impl ChannelManager {
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct ClientTlsOption {
pub server_ca_cert_path: PathBuf,
pub client_cert_path: PathBuf,
pub client_key_path: PathBuf,
/// Path to server CA file, use system CA when not configured
pub server_ca_cert_path: Option<PathBuf>,
/// the file path to client certificate
pub client_cert_path: Option<PathBuf>,
/// the file path to client private key
pub client_key_path: Option<PathBuf>,
}

#[derive(Clone, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -523,9 +535,9 @@ mod tests {
.tcp_keepalive(Duration::from_secs(2))
.tcp_nodelay(false)
.client_tls_config(ClientTlsOption {
server_ca_cert_path: "some_server_path".into(),
client_cert_path: "some_cert_path".into(),
client_key_path: "some_key_path".into(),
server_ca_cert_path: Some("some_server_path".into()),
client_cert_path: Some("some_cert_path".into()),
client_key_path: Some("some_key_path".into()),
});

assert_eq!(
Expand All @@ -543,9 +555,9 @@ mod tests {
tcp_keepalive: Some(Duration::from_secs(2)),
tcp_nodelay: false,
client_tls: Some(ClientTlsOption {
server_ca_cert_path: "some_server_path".into(),
client_cert_path: "some_cert_path".into(),
client_key_path: "some_key_path".into(),
server_ca_cert_path: Some("some_server_path".into()),
client_cert_path: Some("some_cert_path".into()),
client_key_path: Some("some_key_path".into()),
}),
},
cfg
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub mod helpers;
pub mod load_balance;
mod stream_insert;

pub use self::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption};
pub use self::client::Client;
pub use self::database::Database;
pub use self::error::{Error, Result};
Expand Down
Loading