diff --git a/Cargo.lock b/Cargo.lock index 5caeea4..0f475fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -182,7 +182,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.1.0", + "hyper 1.3.1", "hyper-util", "itoa", "matchit", @@ -273,7 +273,7 @@ dependencies = [ "bytes", "http 1.1.0", "http-body 1.0.0", - "hyper 1.1.0", + "hyper 1.3.1", "reqwest 0.11.24", "serde", "tokio", @@ -352,9 +352,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "capture" @@ -691,6 +691,29 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +[[package]] +name = "feature-flags" +version = "0.1.0" +dependencies = [ + "anyhow", + "assert-json-diff", + "async-trait", + "axum 0.7.5", + "axum-client-ip", + "bytes", + "envconfig", + "once_cell", + "rand", + "redis", + "reqwest 0.12.3", + "serde", + "serde_json", + "thiserror", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "finl_unicode" version = "1.2.0" @@ -1226,9 +1249,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.1.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5aa53871fc917b1a9ed87b683a5d86db645e23acb32c2e0785a353e522fb75" +checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" dependencies = [ "bytes", "futures-channel", @@ -1240,6 +1263,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", + "smallvec", "tokio", "want", ] @@ -1278,7 +1302,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper 1.1.0", + "hyper 1.3.1", "hyper-util", "native-tls", "tokio", @@ -1297,7 +1321,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.1.0", + "hyper 1.3.1", "pin-project-lite", "socket2 0.5.5", "tokio", @@ -1519,7 +1543,7 @@ checksum = "5d58e362dc7206e9456ddbcdbd53c71ba441020e62104703075a69151e38d85f" dependencies = [ "base64 0.22.0", "http-body-util", - "hyper 1.1.0", + "hyper 1.3.1", "hyper-tls", "hyper-util", "indexmap 2.2.2", @@ -2310,7 +2334,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.1.0", + "hyper 1.3.1", "hyper-tls", "hyper-util", "ipnet", @@ -3054,9 +3078,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.36.0" +version = "1.37.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" +checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" dependencies = [ "backtrace", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 180355b..ea5d041 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ resolver = "2" members = [ "capture", "common/health", + "feature-flags", "hook-api", "hook-common", "hook-janitor", @@ -49,7 +50,7 @@ opentelemetry-otlp = "0.15.0" opentelemetry_sdk = { version = "0.22.1", features = ["trace", "rt-tokio"] } rand = "0.8.5" rdkafka = { version = "0.36.0", features = ["cmake-build", "ssl", "tracing"] } -reqwest = { version = "0.12.3", features = ["stream"] } +reqwest = { version = "0.12.3", features = ["json", "stream"] } serde = { version = "1.0", features = ["derive"] } serde_derive = { version = "1.0" } serde_json = { version = "1.0" } diff --git a/feature-flags/Cargo.toml b/feature-flags/Cargo.toml new file mode 100644 index 0000000..ddfe070 --- /dev/null +++ b/feature-flags/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "feature-flags" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +axum = { workspace = true } +axum-client-ip = { workspace = true } +envconfig = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +bytes = { workspace = true } +rand = { workspace = true } +redis = { version = "0.23.3", features = [ + "tokio-comp", + "cluster", + "cluster-async", +] } +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } + +[lints] +workspace = true + +[dev-dependencies] +assert-json-diff = { workspace = true } +once_cell = "1.18.0" +reqwest = { workspace = true } + diff --git a/feature-flags/src/api.rs b/feature-flags/src/api.rs new file mode 100644 index 0000000..c94eed6 --- /dev/null +++ b/feature-flags/src/api.rs @@ -0,0 +1,58 @@ +use std::collections::HashMap; + +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] +pub enum FlagsResponseCode { + Ok = 1, +} + +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct FlagsResponse { + pub error_while_computing_flags: bool, + // TODO: better typing here, support bool responses + pub feature_flags: HashMap, +} + +#[derive(Error, Debug)] +pub enum FlagError { + #[error("failed to decode request: {0}")] + RequestDecodingError(String), + #[error("failed to parse request: {0}")] + RequestParsingError(#[from] serde_json::Error), + + #[error("Empty distinct_id in request")] + EmptyDistinctId, + #[error("No distinct_id in request")] + MissingDistinctId, + + #[error("No api_key in request")] + NoTokenError, + #[error("API key is not valid")] + TokenValidationError, + + #[error("rate limited")] + RateLimited, +} + +impl IntoResponse for FlagError { + fn into_response(self) -> Response { + match self { + FlagError::RequestDecodingError(_) + | FlagError::RequestParsingError(_) + | FlagError::EmptyDistinctId + | FlagError::MissingDistinctId => (StatusCode::BAD_REQUEST, self.to_string()), + + FlagError::NoTokenError | FlagError::TokenValidationError => { + (StatusCode::UNAUTHORIZED, self.to_string()) + } + + FlagError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, self.to_string()), + } + .into_response() + } +} diff --git a/feature-flags/src/config.rs b/feature-flags/src/config.rs new file mode 100644 index 0000000..3fa6f50 --- /dev/null +++ b/feature-flags/src/config.rs @@ -0,0 +1,24 @@ +use std::net::SocketAddr; + +use envconfig::Envconfig; + +#[derive(Envconfig, Clone)] +pub struct Config { + #[envconfig(default = "127.0.0.1:0")] + pub address: SocketAddr, + + #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] + pub write_database_url: String, + + #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] + pub read_database_url: String, + + #[envconfig(default = "1024")] + pub max_concurrent_jobs: usize, + + #[envconfig(default = "100")] + pub max_pg_connections: u32, + + #[envconfig(default = "redis://localhost:6379/")] + pub redis_url: String, +} diff --git a/feature-flags/src/lib.rs b/feature-flags/src/lib.rs new file mode 100644 index 0000000..9175b5c --- /dev/null +++ b/feature-flags/src/lib.rs @@ -0,0 +1,7 @@ +pub mod api; +pub mod config; +pub mod redis; +pub mod router; +pub mod server; +pub mod v0_endpoint; +pub mod v0_request; diff --git a/feature-flags/src/main.rs b/feature-flags/src/main.rs new file mode 100644 index 0000000..980db69 --- /dev/null +++ b/feature-flags/src/main.rs @@ -0,0 +1,39 @@ +use envconfig::Envconfig; +use tokio::signal; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Layer}; + +use feature_flags::config::Config; +use feature_flags::server::serve; + +async fn shutdown() { + let mut term = signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to register SIGTERM handler"); + + let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt()) + .expect("failed to register SIGINT handler"); + + tokio::select! { + _ = term.recv() => {}, + _ = interrupt.recv() => {}, + }; + + tracing::info!("Shutting down gracefully..."); +} + +#[tokio::main] +async fn main() { + let config = Config::init_from_env().expect("Invalid configuration:"); + + // Basic logging for now: + // - stdout with a level configured by the RUST_LOG envvar (default=ERROR) + let log_layer = tracing_subscriber::fmt::layer().with_filter(EnvFilter::from_default_env()); + tracing_subscriber::registry().with(log_layer).init(); + + // Open the TCP port and start the server + let listener = tokio::net::TcpListener::bind(config.address) + .await + .expect("could not bind port"); + serve(config, listener, shutdown()).await +} diff --git a/feature-flags/src/redis.rs b/feature-flags/src/redis.rs new file mode 100644 index 0000000..8c03820 --- /dev/null +++ b/feature-flags/src/redis.rs @@ -0,0 +1,77 @@ +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use redis::AsyncCommands; +use tokio::time::timeout; + +// average for all commands is <10ms, check grafana +const REDIS_TIMEOUT_MILLISECS: u64 = 10; + +/// A simple redis wrapper +/// Copied from capture/src/redis.rs. +/// TODO: Modify this to support hincrby, get, and set commands. + +#[async_trait] +pub trait Client { + // A very simplified wrapper, but works for our usage + async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result>; +} + +pub struct RedisClient { + client: redis::Client, +} + +impl RedisClient { + pub fn new(addr: String) -> Result { + let client = redis::Client::open(addr)?; + + Ok(RedisClient { client }) + } +} + +#[async_trait] +impl Client for RedisClient { + async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result> { + let mut conn = self.client.get_async_connection().await?; + + let results = conn.zrangebyscore(k, min, max); + let fut = timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?; + + Ok(fut?) + } +} + +// TODO: Find if there's a better way around this. +#[derive(Clone)] +pub struct MockRedisClient { + zrangebyscore_ret: Vec, +} + +impl MockRedisClient { + pub fn new() -> MockRedisClient { + MockRedisClient { + zrangebyscore_ret: Vec::new(), + } + } + + pub fn zrangebyscore_ret(&mut self, ret: Vec) -> Self { + self.zrangebyscore_ret = ret; + + self.clone() + } +} + +impl Default for MockRedisClient { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl Client for MockRedisClient { + // A very simplified wrapper, but works for our usage + async fn zrangebyscore(&self, _k: String, _min: String, _max: String) -> Result> { + Ok(self.zrangebyscore_ret.clone()) + } +} diff --git a/feature-flags/src/router.rs b/feature-flags/src/router.rs new file mode 100644 index 0000000..8824d44 --- /dev/null +++ b/feature-flags/src/router.rs @@ -0,0 +1,19 @@ +use std::sync::Arc; + +use axum::{routing::post, Router}; + +use crate::{redis::Client, v0_endpoint}; + +#[derive(Clone)] +pub struct State { + pub redis: Arc, + // TODO: Add pgClient when ready +} + +pub fn router(redis: Arc) -> Router { + let state = State { redis }; + + Router::new() + .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) + .with_state(state) +} diff --git a/feature-flags/src/server.rs b/feature-flags/src/server.rs new file mode 100644 index 0000000..ffe6b0e --- /dev/null +++ b/feature-flags/src/server.rs @@ -0,0 +1,31 @@ +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use tokio::net::TcpListener; + +use crate::config::Config; + +use crate::redis::RedisClient; +use crate::router; + +pub async fn serve(config: Config, listener: TcpListener, shutdown: F) +where + F: Future + Send + 'static, +{ + let redis_client = + Arc::new(RedisClient::new(config.redis_url).expect("failed to create redis client")); + + let app = router::router(redis_client); + + // run our app with hyper + // `axum::Server` is a re-export of `hyper::Server` + tracing::info!("listening on {:?}", listener.local_addr().unwrap()); + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .with_graceful_shutdown(shutdown) + .await + .unwrap() +} diff --git a/feature-flags/src/v0_endpoint.rs b/feature-flags/src/v0_endpoint.rs new file mode 100644 index 0000000..8f77611 --- /dev/null +++ b/feature-flags/src/v0_endpoint.rs @@ -0,0 +1,89 @@ +use std::collections::HashMap; + +use axum::{debug_handler, Json}; +use bytes::Bytes; +// TODO: stream this instead +use axum::extract::{MatchedPath, Query, State}; +use axum::http::{HeaderMap, Method}; +use axum_client_ip::InsecureClientIp; +use tracing::instrument; + +use crate::{ + api::{FlagError, FlagsResponse}, + router, + v0_request::{FlagRequest, FlagsQueryParams}, +}; + +/// Feature flag evaluation endpoint. +/// Only supports a specific shape of data, and rejects any malformed data. + +#[instrument( + skip_all, + fields( + path, + token, + batch_size, + user_agent, + content_encoding, + content_type, + version, + compression, + historical_migration + ) +)] +#[debug_handler] +pub async fn flags( + _state: State, + InsecureClientIp(ip): InsecureClientIp, + meta: Query, + headers: HeaderMap, + method: Method, + path: MatchedPath, + body: Bytes, +) -> Result, FlagError> { + let user_agent = headers + .get("user-agent") + .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); + let content_encoding = headers + .get("content-encoding") + .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); + + tracing::Span::current().record("user_agent", user_agent); + tracing::Span::current().record("content_encoding", content_encoding); + tracing::Span::current().record("version", meta.version.clone()); + tracing::Span::current().record("method", method.as_str()); + tracing::Span::current().record("path", path.as_str().trim_end_matches('/')); + tracing::Span::current().record("ip", ip.to_string()); + + let request = match headers + .get("content-type") + .map_or("", |v| v.to_str().unwrap_or("")) + { + "application/x-www-form-urlencoded" => { + return Err(FlagError::RequestDecodingError(String::from( + "invalid form data", + ))); + } + ct => { + tracing::Span::current().record("content_type", ct); + + FlagRequest::from_bytes(body) + } + }?; + + let token = request.extract_and_verify_token()?; + + tracing::Span::current().record("token", &token); + + tracing::debug!("request: {:?}", request); + + // TODO: Some actual processing for evaluating the feature flag + + Ok(Json(FlagsResponse { + error_while_computing_flags: false, + feature_flags: HashMap::from([ + ("beta-feature".to_string(), "variant-1".to_string()), + ("rollout-flag".to_string(), true.to_string()), + ]), + })) +} diff --git a/feature-flags/src/v0_request.rs b/feature-flags/src/v0_request.rs new file mode 100644 index 0000000..f2269df --- /dev/null +++ b/feature-flags/src/v0_request.rs @@ -0,0 +1,68 @@ +use std::collections::HashMap; + +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tracing::instrument; + +use crate::api::FlagError; + +#[derive(Deserialize, Default)] +pub struct FlagsQueryParams { + #[serde(alias = "v")] + pub version: Option, +} + +#[derive(Default, Debug, Deserialize, Serialize)] +pub struct FlagRequest { + #[serde( + alias = "$token", + alias = "api_key", + skip_serializing_if = "Option::is_none" + )] + pub token: Option, + #[serde(alias = "$distinct_id", skip_serializing_if = "Option::is_none")] + pub distinct_id: Option, + pub geoip_disable: Option, + #[serde(default)] + pub person_properties: Option>, + #[serde(default)] + pub groups: Option>, + // TODO: better type this since we know its going to be a nested json + #[serde(default)] + pub group_properties: Option>, + #[serde(alias = "$anon_distinct_id", skip_serializing_if = "Option::is_none")] + pub anon_distinct_id: Option, +} + +impl FlagRequest { + /// Takes a request payload and tries to decompress and unmarshall it. + /// While posthog-js sends a compression query param, a sizable portion of requests + /// fail due to it being missing when the body is compressed. + /// Instead of trusting the parameter, we peek at the payload's first three bytes to + /// detect gzip, fallback to uncompressed utf8 otherwise. + #[instrument(skip_all)] + pub fn from_bytes(bytes: Bytes) -> Result { + tracing::debug!(len = bytes.len(), "decoding new request"); + // TODO: Add base64 decoding + let payload = String::from_utf8(bytes.into()).map_err(|e| { + tracing::error!("failed to decode body: {}", e); + FlagError::RequestDecodingError(String::from("invalid body encoding")) + })?; + + tracing::debug!(json = payload, "decoded event data"); + Ok(serde_json::from_str::(&payload)?) + } + + pub fn extract_and_verify_token(&self) -> Result { + let token = match self { + FlagRequest { + token: Some(token), .. + } => token.to_string(), + _ => return Err(FlagError::NoTokenError), + }; + // TODO: Get tokens from redis, confirm this one is valid + // validate_token(&token)?; + Ok(token) + } +} diff --git a/feature-flags/tests/common.rs b/feature-flags/tests/common.rs new file mode 100644 index 0000000..f66a11f --- /dev/null +++ b/feature-flags/tests/common.rs @@ -0,0 +1,66 @@ +use std::net::SocketAddr; +use std::str::FromStr; +use std::string::ToString; +use std::sync::Arc; + +use once_cell::sync::Lazy; +use rand::distributions::Alphanumeric; +use rand::Rng; +use tokio::net::TcpListener; +use tokio::sync::Notify; + +use feature_flags::config::Config; +use feature_flags::server::serve; + +pub static DEFAULT_CONFIG: Lazy = Lazy::new(|| Config { + address: SocketAddr::from_str("127.0.0.1:0").unwrap(), + redis_url: "redis://localhost:6379/".to_string(), + write_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(), + read_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(), + max_concurrent_jobs: 1024, + max_pg_connections: 100, +}); + +pub struct ServerHandle { + pub addr: SocketAddr, + shutdown: Arc, +} + +impl ServerHandle { + pub async fn for_config(config: Config) -> ServerHandle { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let notify = Arc::new(Notify::new()); + let shutdown = notify.clone(); + + tokio::spawn(async move { + serve(config, listener, async move { notify.notified().await }).await + }); + ServerHandle { addr, shutdown } + } + + pub async fn send_flags_request>(&self, body: T) -> reqwest::Response { + let client = reqwest::Client::new(); + client + .post(format!("http://{:?}/flags", self.addr)) + .body(body) + .send() + .await + .expect("failed to send request") + } +} + +impl Drop for ServerHandle { + fn drop(&mut self) { + self.shutdown.notify_one() + } +} + +pub fn random_string(prefix: &str, length: usize) -> String { + let suffix: String = rand::thread_rng() + .sample_iter(Alphanumeric) + .take(length) + .map(char::from) + .collect(); + format!("{}_{}", prefix, suffix) +} diff --git a/feature-flags/tests/test_flags.rs b/feature-flags/tests/test_flags.rs new file mode 100644 index 0000000..82f41f0 --- /dev/null +++ b/feature-flags/tests/test_flags.rs @@ -0,0 +1,43 @@ +use anyhow::Result; +use assert_json_diff::assert_json_include; + +use reqwest::StatusCode; +use serde_json::{json, Value}; + +use crate::common::*; +mod common; + +#[tokio::test] +async fn it_sends_flag_request() -> Result<()> { + let token = random_string("token", 16); + let distinct_id = "user_distinct_id".to_string(); + + let config = DEFAULT_CONFIG.clone(); + + let server = ServerHandle::for_config(config).await; + + let payload = json!({ + "token": token, + "distinct_id": distinct_id, + "groups": {"group1": "group1"} + }); + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::OK, res.status()); + + // We don't want to deserialize the data into a flagResponse struct here, + // because we want to assert the shape of the raw json data. + let json_data = res.json::().await?; + + assert_json_include!( + actual: json_data, + expected: json!({ + "errorWhileComputingFlags": false, + "featureFlags": { + "beta-feature": "variant-1", + "rollout-flag": "true", + } + }) + ); + + Ok(()) +}