From a6d73090d7095c9ad733314aa764c84bf128b23e Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Wed, 8 May 2024 20:06:35 +0100 Subject: [PATCH 1/6] feat(flags): Do token validation and extract distinct id --- Cargo.lock | 39 ++++++++++++++++++++--- feature-flags/Cargo.toml | 1 + feature-flags/src/api.rs | 4 +++ feature-flags/src/config.rs | 2 +- feature-flags/src/lib.rs | 1 + feature-flags/src/redis.rs | 54 +++++++++++++++----------------- feature-flags/src/v0_endpoint.rs | 18 +++++------ feature-flags/src/v0_request.rs | 13 +++++--- 8 files changed, 83 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0f475fa..8642ade 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -707,6 +707,7 @@ dependencies = [ "redis", "reqwest 0.12.3", "serde", + "serde-pickle", "serde_json", "thiserror", "tokio", @@ -1395,6 +1396,12 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +[[package]] +name = "iter-read" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c397ca3ea05ad509c4ec451fea28b4771236a376ca1c69fd5143aae0cf8f93c4" + [[package]] name = "itertools" version = "0.12.1" @@ -1680,6 +1687,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -1705,11 +1722,10 @@ checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-integer" -version = "0.1.45" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" dependencies = [ - "autocfg", "num-traits", ] @@ -1726,9 +1742,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", "libm", @@ -2533,6 +2549,19 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-pickle" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c762ad136a26407c6a80825813600ceeab5e613660d93d79a41f0ec877171e71" +dependencies = [ + "byteorder", + "iter-read", + "num-bigint", + "num-traits", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.196" diff --git a/feature-flags/Cargo.toml b/feature-flags/Cargo.toml index ddfe070..1e0c111 100644 --- a/feature-flags/Cargo.toml +++ b/feature-flags/Cargo.toml @@ -24,6 +24,7 @@ redis = { version = "0.23.3", features = [ serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } +serde-pickle = { version = "1.1.1"} [lints] workspace = true diff --git a/feature-flags/src/api.rs b/feature-flags/src/api.rs index c94eed6..ebad1f5 100644 --- a/feature-flags/src/api.rs +++ b/feature-flags/src/api.rs @@ -25,6 +25,9 @@ pub enum FlagError { #[error("failed to parse request: {0}")] RequestParsingError(#[from] serde_json::Error), + #[error("failed to parse redis data: {0}")] + DataParsingError(#[from] serde_pickle::Error), + #[error("Empty distinct_id in request")] EmptyDistinctId, #[error("No distinct_id in request")] @@ -44,6 +47,7 @@ impl IntoResponse for FlagError { match self { FlagError::RequestDecodingError(_) | FlagError::RequestParsingError(_) + | FlagError::DataParsingError(_) | FlagError::EmptyDistinctId | FlagError::MissingDistinctId => (StatusCode::BAD_REQUEST, self.to_string()), diff --git a/feature-flags/src/config.rs b/feature-flags/src/config.rs index 3fa6f50..cc7ad37 100644 --- a/feature-flags/src/config.rs +++ b/feature-flags/src/config.rs @@ -4,7 +4,7 @@ use envconfig::Envconfig; #[derive(Envconfig, Clone)] pub struct Config { - #[envconfig(default = "127.0.0.1:0")] + #[envconfig(default = "127.0.0.1:3001")] pub address: SocketAddr, #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] diff --git a/feature-flags/src/lib.rs b/feature-flags/src/lib.rs index 9175b5c..71a5e69 100644 --- a/feature-flags/src/lib.rs +++ b/feature-flags/src/lib.rs @@ -5,3 +5,4 @@ pub mod router; pub mod server; pub mod v0_endpoint; pub mod v0_request; +pub mod team; diff --git a/feature-flags/src/redis.rs b/feature-flags/src/redis.rs index 8c03820..70b7146 100644 --- a/feature-flags/src/redis.rs +++ b/feature-flags/src/redis.rs @@ -2,7 +2,7 @@ use std::time::Duration; use anyhow::Result; use async_trait::async_trait; -use redis::AsyncCommands; +use redis::{AsyncCommands, RedisError}; use tokio::time::timeout; // average for all commands is <10ms, check grafana @@ -10,12 +10,15 @@ const REDIS_TIMEOUT_MILLISECS: u64 = 10; /// A simple redis wrapper /// Copied from capture/src/redis.rs. -/// TODO: Modify this to support hincrby, get, and set commands. +/// TODO: Modify this to support hincrby #[async_trait] pub trait Client { // A very simplified wrapper, but works for our usage async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result>; + + async fn get(&self, k: String) -> Result; + async fn set(&self, k: String, v: String) -> Result<()>; } pub struct RedisClient { @@ -40,38 +43,31 @@ impl Client for RedisClient { Ok(fut?) } -} -// TODO: Find if there's a better way around this. -#[derive(Clone)] -pub struct MockRedisClient { - zrangebyscore_ret: Vec, -} + async fn get(&self, k: String) -> Result { + let mut conn = self.client.get_async_connection().await?; -impl MockRedisClient { - pub fn new() -> MockRedisClient { - MockRedisClient { - zrangebyscore_ret: Vec::new(), - } - } + let results = conn.get(k.clone()); + // TODO: Is this safe? Should we be doing something else for error handling here? + let fut: Result, RedisError> = timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?; - pub fn zrangebyscore_ret(&mut self, ret: Vec) -> Self { - self.zrangebyscore_ret = ret; + // TRICKY: We serialise data to json, then django pickles it. + // Here we deserialize the bytes using serde_pickle, to get the json string. + let string_response: String = serde_pickle::from_slice(&fut?, Default::default())?; - self.clone() + Ok(string_response) } -} -impl Default for MockRedisClient { - fn default() -> Self { - Self::new() - } -} + async fn set(&self, k: String, v: String) -> Result<()> { + // TRICKY: We serialise data to json, then django pickles it. + // Here we serialize the json string to bytes using serde_pickle. + let bytes = serde_pickle::to_vec(&v, Default::default())?; -#[async_trait] -impl Client for MockRedisClient { - // A very simplified wrapper, but works for our usage - async fn zrangebyscore(&self, _k: String, _min: String, _max: String) -> Result> { - Ok(self.zrangebyscore_ret.clone()) + let mut conn = self.client.get_async_connection().await?; + + let results = conn.set(k, bytes); + let fut = timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?; + + Ok(fut?) } -} +} \ No newline at end of file diff --git a/feature-flags/src/v0_endpoint.rs b/feature-flags/src/v0_endpoint.rs index 8f77611..4a46d45 100644 --- a/feature-flags/src/v0_endpoint.rs +++ b/feature-flags/src/v0_endpoint.rs @@ -33,7 +33,7 @@ use crate::{ )] #[debug_handler] pub async fn flags( - _state: State, + state: State, InsecureClientIp(ip): InsecureClientIp, meta: Query, headers: HeaderMap, @@ -59,19 +59,19 @@ pub async fn flags( .get("content-type") .map_or("", |v| v.to_str().unwrap_or("")) { - "application/x-www-form-urlencoded" => { - return Err(FlagError::RequestDecodingError(String::from( - "invalid form data", - ))); + "application/json" => { + tracing::Span::current().record("content_type", "application/json"); + FlagRequest::from_bytes(body) } ct => { - tracing::Span::current().record("content_type", ct); - - FlagRequest::from_bytes(body) + return Err(FlagError::RequestDecodingError(format!( + "unsupported content type: {}", + ct + ))); } }?; - let token = request.extract_and_verify_token()?; + let token = request.extract_and_verify_token(state.redis.clone()).await?; tracing::Span::current().record("token", &token); diff --git a/feature-flags/src/v0_request.rs b/feature-flags/src/v0_request.rs index f2269df..2954b2e 100644 --- a/feature-flags/src/v0_request.rs +++ b/feature-flags/src/v0_request.rs @@ -1,11 +1,11 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use bytes::Bytes; use serde::{Deserialize, Serialize}; use serde_json::Value; use tracing::instrument; -use crate::api::FlagError; +use crate::{api::FlagError, redis::Client, team::Team}; #[derive(Deserialize, Default)] pub struct FlagsQueryParams { @@ -54,15 +54,18 @@ impl FlagRequest { Ok(serde_json::from_str::(&payload)?) } - pub fn extract_and_verify_token(&self) -> Result { + pub async fn extract_and_verify_token(&self, redis_client: Arc) -> Result { let token = match self { FlagRequest { token: Some(token), .. } => token.to_string(), _ => return Err(FlagError::NoTokenError), }; - // TODO: Get tokens from redis, confirm this one is valid - // validate_token(&token)?; + + let team = Team::from_redis(redis_client, token.clone()).await?; + + // TODO: Remove this, is useless, doing just for now because + tracing::Span::current().record("team_id", &team.id); Ok(token) } } From d0e9bc04bef92e2570ab6c27d69a2ddeee282672 Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Wed, 8 May 2024 20:07:57 +0100 Subject: [PATCH 2/6] add mod --- feature-flags/src/team.rs | 133 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 feature-flags/src/team.rs diff --git a/feature-flags/src/team.rs b/feature-flags/src/team.rs new file mode 100644 index 0000000..cfa54c3 --- /dev/null +++ b/feature-flags/src/team.rs @@ -0,0 +1,133 @@ +use std::sync::Arc; + +use crate::{api::FlagError, redis::Client}; + +use serde::{Deserialize, Serialize}; +use tracing::instrument; + + +// TRICKY: I'm still not sure where the :1: is coming from. +// The Django prefix is `posthog` only. +// It's from here: https://docs.djangoproject.com/en/4.2/topics/cache/#cache-versioning +// F&!£%% on the bright side we don't use this functionality yet. +// Will rely on integration tests to catch this. +const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:"; + +// TODO: Check what happens if json has extra stuff, does serde ignore it? Yes +// Make sure we don't serialize and store team data in redis. Let main decide endpoint control this... +// and track misses. Revisit if this becomes an issue. +// because otherwise very annoying to keep this in sync with main django which has a lot of extra fields we need here. +// will lead to inconsistent behaviour. +// This is turning out to be very annoying, because we have django key prefixes to be mindful of as well. +// Wonder if it would be better to make these caches independent? This generates that new problem of CRUD happening in Django, +// which needs to update this cache immediately, so they can't really ever be independent. +// True for both team cache and flags cache. Hmm. Just I guess need to add tests around the key prefixes... +#[derive(Debug, Deserialize, Serialize)] +pub struct Team { + pub id: i64, + pub name: String, + pub api_token: String, +} + +impl Team { + /// Validates a token, and returns a team if it exists. + /// + + #[instrument(skip_all)] + pub async fn from_redis( + client: Arc, + token: String, + ) -> Result { + + // TODO: Instead of failing here, i.e. if not in redis, fallback to pg + let serialized_team = client + .get( + format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token) + ) + .await + .map_err(|e| { + tracing::error!("failed to fetch data: {}", e); + // TODO: Can be other errors if serde_pickle destructuring fails? + FlagError::TokenValidationError + })?; + + let team: Team = serde_json::from_str(&serialized_team).map_err(|e| { + tracing::error!("failed to parse data to team: {}", e); + // TODO: Internal error, shouldn't send back to client + FlagError::RequestParsingError(e) + })?; + + Ok(team) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use anyhow::Error; + + use crate::redis::RedisClient; + use rand::{distributions::Alphanumeric, Rng}; + + use super::*; + + fn random_string(prefix: &str, length: usize) -> String { + let suffix: String = rand::thread_rng() + .sample_iter(Alphanumeric) + .take(length) + .map(char::from) + .collect(); + format!("{}{}", prefix, suffix) + } + + async fn insert_new_team_in_redis(client: Arc) -> Result { + let id = rand::thread_rng().gen_range(0..10_000_000); + let token = random_string("phc_", 12); + let team = Team { + id: id, + name: "team".to_string(), + api_token: token, + }; + + let serialized_team = serde_json::to_string(&team)?; + client + .set( + format!("{TEAM_TOKEN_CACHE_PREFIX}{}", team.api_token.clone()), + serialized_team, + ) + .await?; + + Ok(team) + } + + #[tokio::test] + async fn test_fetch_team_from_redis() { + let client = RedisClient::new("redis://localhost:6379/".to_string()) + .expect("Failed to create redis client"); + let client = Arc::new(client); + + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + + let target_token = team.api_token; + + let team_from_redis = Team::from_redis(client.clone(), target_token.clone()).await.unwrap(); + assert_eq!( + team_from_redis.api_token, target_token + ); + assert_eq!( + team_from_redis.id, team.id + ); + } + + #[tokio::test] + async fn test_fetch_invalid_team_from_redis() { + let client = RedisClient::new("redis://localhost:6379/".to_string()) + .expect("Failed to create redis client"); + let client = Arc::new(client); + + match Team::from_redis(client.clone(), "banana".to_string()).await { + Err(FlagError::TokenValidationError) => (), + _ => panic!("Expected TokenValidationError"), + }; + } +} From 327074cde9d40f0790439e1f1f84e0ab9a5a4626 Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Thu, 9 May 2024 11:02:35 +0100 Subject: [PATCH 3/6] add more tests, common scaffolding --- feature-flags/src/lib.rs | 8 +++ feature-flags/src/team.rs | 49 +++---------------- feature-flags/src/test_utils.rs | 43 ++++++++++++++++ .../tests/{common.rs => common/mod.rs} | 24 ++++----- feature-flags/tests/test_flags.rs | 42 ++++++++++++++-- 5 files changed, 111 insertions(+), 55 deletions(-) create mode 100644 feature-flags/src/test_utils.rs rename feature-flags/tests/{common.rs => common/mod.rs} (77%) diff --git a/feature-flags/src/lib.rs b/feature-flags/src/lib.rs index 71a5e69..c9d07cc 100644 --- a/feature-flags/src/lib.rs +++ b/feature-flags/src/lib.rs @@ -6,3 +6,11 @@ pub mod server; pub mod v0_endpoint; pub mod v0_request; pub mod team; + +// Test modules don't need to be compiled with main binary +// #[cfg(test)] +// TODO: To use in integration tests, we need to compile with binary +// or make it a separate feature using cfg(feature = "integration-tests") +// and then use this feature only in tests. +// For now, ok to just include in binary +pub mod test_utils; diff --git a/feature-flags/src/team.rs b/feature-flags/src/team.rs index cfa54c3..d55aa93 100644 --- a/feature-flags/src/team.rs +++ b/feature-flags/src/team.rs @@ -11,7 +11,7 @@ use tracing::instrument; // It's from here: https://docs.djangoproject.com/en/4.2/topics/cache/#cache-versioning // F&!£%% on the bright side we don't use this functionality yet. // Will rely on integration tests to catch this. -const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:"; +pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:"; // TODO: Check what happens if json has extra stuff, does serde ignore it? Yes // Make sure we don't serialize and store team data in redis. Let main decide endpoint control this... @@ -63,48 +63,13 @@ impl Team { #[cfg(test)] mod tests { - use std::sync::Arc; - use anyhow::Error; - - use crate::redis::RedisClient; - use rand::{distributions::Alphanumeric, Rng}; - + use crate::test_utils::{insert_new_team_in_redis, setup_redis_client}; use super::*; - fn random_string(prefix: &str, length: usize) -> String { - let suffix: String = rand::thread_rng() - .sample_iter(Alphanumeric) - .take(length) - .map(char::from) - .collect(); - format!("{}{}", prefix, suffix) - } - - async fn insert_new_team_in_redis(client: Arc) -> Result { - let id = rand::thread_rng().gen_range(0..10_000_000); - let token = random_string("phc_", 12); - let team = Team { - id: id, - name: "team".to_string(), - api_token: token, - }; - - let serialized_team = serde_json::to_string(&team)?; - client - .set( - format!("{TEAM_TOKEN_CACHE_PREFIX}{}", team.api_token.clone()), - serialized_team, - ) - .await?; - - Ok(team) - } #[tokio::test] async fn test_fetch_team_from_redis() { - let client = RedisClient::new("redis://localhost:6379/".to_string()) - .expect("Failed to create redis client"); - let client = Arc::new(client); + let client = setup_redis_client(None); let team = insert_new_team_in_redis(client.clone()).await.unwrap(); @@ -121,10 +86,12 @@ mod tests { #[tokio::test] async fn test_fetch_invalid_team_from_redis() { - let client = RedisClient::new("redis://localhost:6379/".to_string()) - .expect("Failed to create redis client"); - let client = Arc::new(client); + let client = setup_redis_client(None); + // TODO: It's not ideal that this can fail on random errors like connection refused. + // Is there a way to be more specific throughout this code? + // Or maybe I shouldn't be mapping conn refused to token validation error, and instead handling it as a + // top level 500 error instead of 400 right now. match Team::from_redis(client.clone(), "banana".to_string()).await { Err(FlagError::TokenValidationError) => (), _ => panic!("Expected TokenValidationError"), diff --git a/feature-flags/src/test_utils.rs b/feature-flags/src/test_utils.rs new file mode 100644 index 0000000..1a91c8b --- /dev/null +++ b/feature-flags/src/test_utils.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; +use anyhow::Error; + +use crate::{redis::{Client, RedisClient}, team::{self, Team}}; +use rand::{distributions::Alphanumeric, Rng}; + +pub fn random_string(prefix: &str, length: usize) -> String { + let suffix: String = rand::thread_rng() + .sample_iter(Alphanumeric) + .take(length) + .map(char::from) + .collect(); + format!("{}{}", prefix, suffix) +} + +pub async fn insert_new_team_in_redis(client: Arc) -> Result { + let id = rand::thread_rng().gen_range(0..10_000_000); + let token = random_string("phc_", 12); + let team = Team { + id: id, + name: "team".to_string(), + api_token: token, + }; + + let serialized_team = serde_json::to_string(&team)?; + client + .set( + format!("{}{}", team::TEAM_TOKEN_CACHE_PREFIX, team.api_token.clone()), + serialized_team, + ) + .await?; + + Ok(team) +} + +pub fn setup_redis_client(url: Option) -> Arc { + let redis_url = match url { + Some(value) => value, + None => "redis://localhost:6379/".to_string(), + }; + let client = RedisClient::new(redis_url).expect("Failed to create redis client"); + Arc::new(client) +} \ No newline at end of file diff --git a/feature-flags/tests/common.rs b/feature-flags/tests/common/mod.rs similarity index 77% rename from feature-flags/tests/common.rs rename to feature-flags/tests/common/mod.rs index f66a11f..5a63285 100644 --- a/feature-flags/tests/common.rs +++ b/feature-flags/tests/common/mod.rs @@ -4,8 +4,7 @@ use std::string::ToString; use std::sync::Arc; use once_cell::sync::Lazy; -use rand::distributions::Alphanumeric; -use rand::Rng; +use reqwest::header::CONTENT_TYPE; use tokio::net::TcpListener; use tokio::sync::Notify; @@ -44,6 +43,18 @@ impl ServerHandle { client .post(format!("http://{:?}/flags", self.addr)) .body(body) + .header(CONTENT_TYPE, "application/json") + .send() + .await + .expect("failed to send request") + } + + pub async fn send_invalid_header_for_flags_request>(&self, body: T) -> reqwest::Response { + let client = reqwest::Client::new(); + client + .post(format!("http://{:?}/flags", self.addr)) + .body(body) + .header(CONTENT_TYPE, "xyz") .send() .await .expect("failed to send request") @@ -55,12 +66,3 @@ impl Drop for ServerHandle { self.shutdown.notify_one() } } - -pub fn random_string(prefix: &str, length: usize) -> String { - let suffix: String = rand::thread_rng() - .sample_iter(Alphanumeric) - .take(length) - .map(char::from) - .collect(); - format!("{}_{}", prefix, suffix) -} diff --git a/feature-flags/tests/test_flags.rs b/feature-flags/tests/test_flags.rs index 82f41f0..5302ea9 100644 --- a/feature-flags/tests/test_flags.rs +++ b/feature-flags/tests/test_flags.rs @@ -5,14 +5,20 @@ use reqwest::StatusCode; use serde_json::{json, Value}; use crate::common::*; -mod common; + +use feature_flags::test_utils::{insert_new_team_in_redis, setup_redis_client}; + +pub mod common; #[tokio::test] async fn it_sends_flag_request() -> Result<()> { - let token = random_string("token", 16); + let config = DEFAULT_CONFIG.clone(); + let distinct_id = "user_distinct_id".to_string(); - let config = DEFAULT_CONFIG.clone(); + let client = setup_redis_client(Some(config.redis_url.clone())); + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let token = team.api_token; let server = ServerHandle::for_config(config).await; @@ -41,3 +47,33 @@ async fn it_sends_flag_request() -> Result<()> { Ok(()) } + + +#[tokio::test] +async fn it_rejects_invalid_headers_flag_request() -> Result<()> { + let config = DEFAULT_CONFIG.clone(); + + let distinct_id = "user_distinct_id".to_string(); + + let client = setup_redis_client(Some(config.redis_url.clone())); + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let token = team.api_token; + + let server = ServerHandle::for_config(config).await; + + let payload = json!({ + "token": token, + "distinct_id": distinct_id, + "groups": {"group1": "group1"} + }); + let res = server.send_invalid_header_for_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::BAD_REQUEST, res.status()); + + // We don't want to deserialize the data into a flagResponse struct here, + // because we want to assert the shape of the raw json data. + let response_text = res.text().await?; + + assert_eq!(response_text, "failed to decode request: unsupported content type: xyz"); + + Ok(()) +} \ No newline at end of file From 838dd2c471ac4ddeecd51021e7b1e8ea5ab4a45d Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Thu, 9 May 2024 11:19:05 +0100 Subject: [PATCH 4/6] lint --- feature-flags/src/lib.rs | 2 +- feature-flags/src/redis.rs | 5 +++-- feature-flags/src/team.rs | 26 +++++++++----------------- feature-flags/src/test_utils.rs | 15 +++++++++++---- feature-flags/src/v0_endpoint.rs | 4 +++- feature-flags/src/v0_request.rs | 9 ++++++--- feature-flags/tests/common/mod.rs | 5 ++++- feature-flags/tests/test_flags.rs | 12 ++++++++---- 8 files changed, 45 insertions(+), 33 deletions(-) diff --git a/feature-flags/src/lib.rs b/feature-flags/src/lib.rs index c9d07cc..195a55c 100644 --- a/feature-flags/src/lib.rs +++ b/feature-flags/src/lib.rs @@ -3,9 +3,9 @@ pub mod config; pub mod redis; pub mod router; pub mod server; +pub mod team; pub mod v0_endpoint; pub mod v0_request; -pub mod team; // Test modules don't need to be compiled with main binary // #[cfg(test)] diff --git a/feature-flags/src/redis.rs b/feature-flags/src/redis.rs index 70b7146..3f6dd7f 100644 --- a/feature-flags/src/redis.rs +++ b/feature-flags/src/redis.rs @@ -49,7 +49,8 @@ impl Client for RedisClient { let results = conn.get(k.clone()); // TODO: Is this safe? Should we be doing something else for error handling here? - let fut: Result, RedisError> = timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?; + let fut: Result, RedisError> = + timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?; // TRICKY: We serialise data to json, then django pickles it. // Here we deserialize the bytes using serde_pickle, to get the json string. @@ -70,4 +71,4 @@ impl Client for RedisClient { Ok(fut?) } -} \ No newline at end of file +} diff --git a/feature-flags/src/team.rs b/feature-flags/src/team.rs index d55aa93..54f7318 100644 --- a/feature-flags/src/team.rs +++ b/feature-flags/src/team.rs @@ -5,7 +5,6 @@ use crate::{api::FlagError, redis::Client}; use serde::{Deserialize, Serialize}; use tracing::instrument; - // TRICKY: I'm still not sure where the :1: is coming from. // The Django prefix is `posthog` only. // It's from here: https://docs.djangoproject.com/en/4.2/topics/cache/#cache-versioning @@ -31,26 +30,22 @@ pub struct Team { impl Team { /// Validates a token, and returns a team if it exists. - /// - + #[instrument(skip_all)] pub async fn from_redis( client: Arc, token: String, ) -> Result { - // TODO: Instead of failing here, i.e. if not in redis, fallback to pg let serialized_team = client - .get( - format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token) - ) + .get(format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token)) .await .map_err(|e| { tracing::error!("failed to fetch data: {}", e); // TODO: Can be other errors if serde_pickle destructuring fails? FlagError::TokenValidationError })?; - + let team: Team = serde_json::from_str(&serialized_team).map_err(|e| { tracing::error!("failed to parse data to team: {}", e); // TODO: Internal error, shouldn't send back to client @@ -63,9 +58,8 @@ impl Team { #[cfg(test)] mod tests { - use crate::test_utils::{insert_new_team_in_redis, setup_redis_client}; use super::*; - + use crate::test_utils::{insert_new_team_in_redis, setup_redis_client}; #[tokio::test] async fn test_fetch_team_from_redis() { @@ -75,13 +69,11 @@ mod tests { let target_token = team.api_token; - let team_from_redis = Team::from_redis(client.clone(), target_token.clone()).await.unwrap(); - assert_eq!( - team_from_redis.api_token, target_token - ); - assert_eq!( - team_from_redis.id, team.id - ); + let team_from_redis = Team::from_redis(client.clone(), target_token.clone()) + .await + .unwrap(); + assert_eq!(team_from_redis.api_token, target_token); + assert_eq!(team_from_redis.id, team.id); } #[tokio::test] diff --git a/feature-flags/src/test_utils.rs b/feature-flags/src/test_utils.rs index 1a91c8b..1604079 100644 --- a/feature-flags/src/test_utils.rs +++ b/feature-flags/src/test_utils.rs @@ -1,7 +1,10 @@ -use std::sync::Arc; use anyhow::Error; +use std::sync::Arc; -use crate::{redis::{Client, RedisClient}, team::{self, Team}}; +use crate::{ + redis::{Client, RedisClient}, + team::{self, Team}, +}; use rand::{distributions::Alphanumeric, Rng}; pub fn random_string(prefix: &str, length: usize) -> String { @@ -25,7 +28,11 @@ pub async fn insert_new_team_in_redis(client: Arc) -> Result) -> Arc { }; let client = RedisClient::new(redis_url).expect("Failed to create redis client"); Arc::new(client) -} \ No newline at end of file +} diff --git a/feature-flags/src/v0_endpoint.rs b/feature-flags/src/v0_endpoint.rs index 4a46d45..bbd7ff3 100644 --- a/feature-flags/src/v0_endpoint.rs +++ b/feature-flags/src/v0_endpoint.rs @@ -71,7 +71,9 @@ pub async fn flags( } }?; - let token = request.extract_and_verify_token(state.redis.clone()).await?; + let token = request + .extract_and_verify_token(state.redis.clone()) + .await?; tracing::Span::current().record("token", &token); diff --git a/feature-flags/src/v0_request.rs b/feature-flags/src/v0_request.rs index 2954b2e..f75ef56 100644 --- a/feature-flags/src/v0_request.rs +++ b/feature-flags/src/v0_request.rs @@ -54,17 +54,20 @@ impl FlagRequest { Ok(serde_json::from_str::(&payload)?) } - pub async fn extract_and_verify_token(&self, redis_client: Arc) -> Result { + pub async fn extract_and_verify_token( + &self, + redis_client: Arc, + ) -> Result { let token = match self { FlagRequest { token: Some(token), .. } => token.to_string(), _ => return Err(FlagError::NoTokenError), }; - + let team = Team::from_redis(redis_client, token.clone()).await?; - // TODO: Remove this, is useless, doing just for now because + // TODO: Remove this, is useless, doing just for now because tracing::Span::current().record("team_id", &team.id); Ok(token) } diff --git a/feature-flags/tests/common/mod.rs b/feature-flags/tests/common/mod.rs index 5a63285..c8644fe 100644 --- a/feature-flags/tests/common/mod.rs +++ b/feature-flags/tests/common/mod.rs @@ -49,7 +49,10 @@ impl ServerHandle { .expect("failed to send request") } - pub async fn send_invalid_header_for_flags_request>(&self, body: T) -> reqwest::Response { + pub async fn send_invalid_header_for_flags_request>( + &self, + body: T, + ) -> reqwest::Response { let client = reqwest::Client::new(); client .post(format!("http://{:?}/flags", self.addr)) diff --git a/feature-flags/tests/test_flags.rs b/feature-flags/tests/test_flags.rs index 5302ea9..2ceba24 100644 --- a/feature-flags/tests/test_flags.rs +++ b/feature-flags/tests/test_flags.rs @@ -48,7 +48,6 @@ async fn it_sends_flag_request() -> Result<()> { Ok(()) } - #[tokio::test] async fn it_rejects_invalid_headers_flag_request() -> Result<()> { let config = DEFAULT_CONFIG.clone(); @@ -66,14 +65,19 @@ async fn it_rejects_invalid_headers_flag_request() -> Result<()> { "distinct_id": distinct_id, "groups": {"group1": "group1"} }); - let res = server.send_invalid_header_for_flags_request(payload.to_string()).await; + let res = server + .send_invalid_header_for_flags_request(payload.to_string()) + .await; assert_eq!(StatusCode::BAD_REQUEST, res.status()); // We don't want to deserialize the data into a flagResponse struct here, // because we want to assert the shape of the raw json data. let response_text = res.text().await?; - assert_eq!(response_text, "failed to decode request: unsupported content type: xyz"); + assert_eq!( + response_text, + "failed to decode request: unsupported content type: xyz" + ); Ok(()) -} \ No newline at end of file +} From ad04232e0aeaf505a26c7776ff4ee84e1136d7f4 Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Thu, 9 May 2024 14:15:37 +0100 Subject: [PATCH 5/6] clean up --- feature-flags/src/api.rs | 13 ++-- feature-flags/src/redis.rs | 35 +++++++++-- feature-flags/src/team.rs | 105 ++++++++++++++++++++++--------- feature-flags/src/test_utils.rs | 2 +- feature-flags/src/v0_endpoint.rs | 3 + feature-flags/src/v0_request.rs | 80 ++++++++++++++++++++--- 6 files changed, 192 insertions(+), 46 deletions(-) diff --git a/feature-flags/src/api.rs b/feature-flags/src/api.rs index ebad1f5..ccf4735 100644 --- a/feature-flags/src/api.rs +++ b/feature-flags/src/api.rs @@ -25,9 +25,6 @@ pub enum FlagError { #[error("failed to parse request: {0}")] RequestParsingError(#[from] serde_json::Error), - #[error("failed to parse redis data: {0}")] - DataParsingError(#[from] serde_pickle::Error), - #[error("Empty distinct_id in request")] EmptyDistinctId, #[error("No distinct_id in request")] @@ -40,6 +37,11 @@ pub enum FlagError { #[error("rate limited")] RateLimited, + + #[error("failed to parse redis cache data")] + DataParsingError, + #[error("redis unavailable")] + RedisUnavailable, } impl IntoResponse for FlagError { @@ -47,7 +49,6 @@ impl IntoResponse for FlagError { match self { FlagError::RequestDecodingError(_) | FlagError::RequestParsingError(_) - | FlagError::DataParsingError(_) | FlagError::EmptyDistinctId | FlagError::MissingDistinctId => (StatusCode::BAD_REQUEST, self.to_string()), @@ -56,6 +57,10 @@ impl IntoResponse for FlagError { } FlagError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, self.to_string()), + + FlagError::DataParsingError | FlagError::RedisUnavailable => { + (StatusCode::SERVICE_UNAVAILABLE, self.to_string()) + } } .into_response() } diff --git a/feature-flags/src/redis.rs b/feature-flags/src/redis.rs index 3f6dd7f..3aeec47 100644 --- a/feature-flags/src/redis.rs +++ b/feature-flags/src/redis.rs @@ -3,11 +3,26 @@ use std::time::Duration; use anyhow::Result; use async_trait::async_trait; use redis::{AsyncCommands, RedisError}; +use thiserror::Error; use tokio::time::timeout; // average for all commands is <10ms, check grafana const REDIS_TIMEOUT_MILLISECS: u64 = 10; +#[derive(Error, Debug)] +pub enum CustomRedisError { + #[error("Not found in redis")] + NotFound, + + #[error("Pickle error: {0}")] + PickleError(#[from] serde_pickle::Error), + + #[error("Redis error: {0}")] + Other(#[from] RedisError), + + #[error("Timeout error")] + Timeout(#[from] tokio::time::error::Elapsed), +} /// A simple redis wrapper /// Copied from capture/src/redis.rs. /// TODO: Modify this to support hincrby @@ -17,7 +32,7 @@ pub trait Client { // A very simplified wrapper, but works for our usage async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result>; - async fn get(&self, k: String) -> Result; + async fn get(&self, k: String) -> Result; async fn set(&self, k: String, v: String) -> Result<()>; } @@ -44,14 +59,26 @@ impl Client for RedisClient { Ok(fut?) } - async fn get(&self, k: String) -> Result { + // TODO: Ask Xavier if there's a better way to handle this. + // The problem: I want to match on the error type from this function, and do appropriate things like 400 or 500 response. + // Buuut, if I use anyhow::Error, I can't reverse-coerce into a NotFound or serde_pickle::Error. + // Thus, I need to create a custom error enum of all possible errors + my own custom not found, so I can match on it. + // Is this the canonical way? + async fn get(&self, k: String) -> Result { let mut conn = self.client.get_async_connection().await?; - let results = conn.get(k.clone()); - // TODO: Is this safe? Should we be doing something else for error handling here? + let results = conn.get(k); let fut: Result, RedisError> = timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?; + // return NotFound error when empty or not found + if match &fut { + Ok(v) => v.is_empty(), + Err(_) => false, + } { + return Err(CustomRedisError::NotFound); + } + // TRICKY: We serialise data to json, then django pickles it. // Here we deserialize the bytes using serde_pickle, to get the json string. let string_response: String = serde_pickle::from_slice(&fut?, Default::default())?; diff --git a/feature-flags/src/team.rs b/feature-flags/src/team.rs index 54f7318..ac62ea9 100644 --- a/feature-flags/src/team.rs +++ b/feature-flags/src/team.rs @@ -1,26 +1,16 @@ -use std::sync::Arc; - -use crate::{api::FlagError, redis::Client}; - use serde::{Deserialize, Serialize}; +use std::sync::Arc; use tracing::instrument; -// TRICKY: I'm still not sure where the :1: is coming from. -// The Django prefix is `posthog` only. -// It's from here: https://docs.djangoproject.com/en/4.2/topics/cache/#cache-versioning -// F&!£%% on the bright side we don't use this functionality yet. -// Will rely on integration tests to catch this. +use crate::{ + api::FlagError, + redis::{Client, CustomRedisError}, +}; + +// TRICKY: This cache data is coming from django-redis. If it ever goes out of sync, we'll bork. +// TODO: Add integration tests across repos to ensure this doesn't happen. pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:"; -// TODO: Check what happens if json has extra stuff, does serde ignore it? Yes -// Make sure we don't serialize and store team data in redis. Let main decide endpoint control this... -// and track misses. Revisit if this becomes an issue. -// because otherwise very annoying to keep this in sync with main django which has a lot of extra fields we need here. -// will lead to inconsistent behaviour. -// This is turning out to be very annoying, because we have django key prefixes to be mindful of as well. -// Wonder if it would be better to make these caches independent? This generates that new problem of CRUD happening in Django, -// which needs to update this cache immediately, so they can't really ever be independent. -// True for both team cache and flags cache. Hmm. Just I guess need to add tests around the key prefixes... #[derive(Debug, Deserialize, Serialize)] pub struct Team { pub id: i64, @@ -40,16 +30,21 @@ impl Team { let serialized_team = client .get(format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token)) .await - .map_err(|e| { - tracing::error!("failed to fetch data: {}", e); - // TODO: Can be other errors if serde_pickle destructuring fails? - FlagError::TokenValidationError + .map_err(|e| match e { + CustomRedisError::NotFound => FlagError::TokenValidationError, + CustomRedisError::PickleError(_) => { + tracing::error!("failed to fetch data: {}", e); + FlagError::DataParsingError + } + _ => { + tracing::error!("Unknown redis error: {}", e); + FlagError::RedisUnavailable + } })?; let team: Team = serde_json::from_str(&serialized_team).map_err(|e| { tracing::error!("failed to parse data to team: {}", e); - // TODO: Internal error, shouldn't send back to client - FlagError::RequestParsingError(e) + FlagError::DataParsingError })?; Ok(team) @@ -58,8 +53,14 @@ impl Team { #[cfg(test)] mod tests { + use rand::Rng; + use redis::AsyncCommands; + use super::*; - use crate::test_utils::{insert_new_team_in_redis, setup_redis_client}; + use crate::{ + team, + test_utils::{insert_new_team_in_redis, random_string, setup_redis_client}, + }; #[tokio::test] async fn test_fetch_team_from_redis() { @@ -80,13 +81,59 @@ mod tests { async fn test_fetch_invalid_team_from_redis() { let client = setup_redis_client(None); - // TODO: It's not ideal that this can fail on random errors like connection refused. - // Is there a way to be more specific throughout this code? - // Or maybe I shouldn't be mapping conn refused to token validation error, and instead handling it as a - // top level 500 error instead of 400 right now. match Team::from_redis(client.clone(), "banana".to_string()).await { Err(FlagError::TokenValidationError) => (), _ => panic!("Expected TokenValidationError"), }; } + + #[tokio::test] + async fn test_cant_connect_to_redis_error_is_not_token_validation_error() { + let client = setup_redis_client(Some("redis://localhost:1111/".to_string())); + + match Team::from_redis(client.clone(), "banana".to_string()).await { + Err(FlagError::RedisUnavailable) => (), + _ => panic!("Expected RedisUnavailable"), + }; + } + + #[tokio::test] + async fn test_corrupted_data_in_redis_is_handled() { + // TODO: Extend this test with fallback to pg + let id = rand::thread_rng().gen_range(0..10_000_000); + let token = random_string("phc_", 12); + let team = Team { + id, + name: "team".to_string(), + api_token: token, + }; + let serialized_team = serde_json::to_string(&team).expect("Failed to serialise team"); + + // manually insert non-pickled data in redis + let client = + redis::Client::open("redis://localhost:6379/").expect("Failed to create redis client"); + let mut conn = client + .get_async_connection() + .await + .expect("Failed to get redis connection"); + conn.set::( + format!( + "{}{}", + team::TEAM_TOKEN_CACHE_PREFIX, + team.api_token.clone() + ), + serialized_team, + ) + .await + .expect("Failed to write data to redis"); + + // now get client connection for data + let client = setup_redis_client(None); + + match Team::from_redis(client.clone(), team.api_token.clone()).await { + Err(FlagError::DataParsingError) => (), + Err(other) => panic!("Expected DataParsingError, got {:?}", other), + Ok(_) => panic!("Expected DataParsingError"), + }; + } } diff --git a/feature-flags/src/test_utils.rs b/feature-flags/src/test_utils.rs index 1604079..75db86d 100644 --- a/feature-flags/src/test_utils.rs +++ b/feature-flags/src/test_utils.rs @@ -20,7 +20,7 @@ pub async fn insert_new_team_in_redis(client: Arc) -> Result Result { tracing::debug!(len = bytes.len(), "decoding new request"); @@ -65,10 +62,77 @@ impl FlagRequest { _ => return Err(FlagError::NoTokenError), }; - let team = Team::from_redis(redis_client, token.clone()).await?; + // validate token + Team::from_redis(redis_client, token.clone()).await?; + + // TODO: fallback when token not found in redis - // TODO: Remove this, is useless, doing just for now because - tracing::Span::current().record("team_id", &team.id); Ok(token) } + + pub fn extract_distinct_id(&self) -> Result { + let distinct_id = match &self.distinct_id { + None => return Err(FlagError::MissingDistinctId), + Some(id) => id, + }; + + match distinct_id.len() { + 0 => Err(FlagError::EmptyDistinctId), + 1..=200 => Ok(distinct_id.to_owned()), + _ => Ok(distinct_id.chars().take(200).collect()), + } + } +} + +#[cfg(test)] +mod tests { + use crate::api::FlagError; + use crate::v0_request::FlagRequest; + use bytes::Bytes; + use serde_json::json; + + #[test] + fn empty_distinct_id_not_accepted() { + let json = json!({ + "distinct_id": "", + "token": "my_token1", + }); + let bytes = Bytes::from(json.to_string()); + + let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); + + match flag_payload.extract_distinct_id() { + Err(FlagError::EmptyDistinctId) => (), + _ => panic!("expected empty distinct id error"), + }; + } + + #[test] + fn too_large_distinct_id_is_truncated() { + let json = json!({ + "distinct_id": std::iter::repeat("a").take(210).collect::(), + "token": "my_token1", + }); + let bytes = Bytes::from(json.to_string()); + + let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); + + assert_eq!(flag_payload.extract_distinct_id().unwrap().len(), 200); + } + + #[test] + fn distinct_id_is_returned_correctly() { + let json = json!({ + "$distinct_id": "alakazam", + "token": "my_token1", + }); + let bytes = Bytes::from(json.to_string()); + + let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); + + match flag_payload.extract_distinct_id() { + Ok(id) => assert_eq!(id, "alakazam"), + _ => panic!("expected distinct id"), + }; + } } From 2ca642a085fda6980279ca7148886de2f28114df Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Wed, 29 May 2024 14:23:39 +0100 Subject: [PATCH 6/6] address comment --- feature-flags/src/redis.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/feature-flags/src/redis.rs b/feature-flags/src/redis.rs index 3aeec47..89dde42 100644 --- a/feature-flags/src/redis.rs +++ b/feature-flags/src/redis.rs @@ -59,11 +59,6 @@ impl Client for RedisClient { Ok(fut?) } - // TODO: Ask Xavier if there's a better way to handle this. - // The problem: I want to match on the error type from this function, and do appropriate things like 400 or 500 response. - // Buuut, if I use anyhow::Error, I can't reverse-coerce into a NotFound or serde_pickle::Error. - // Thus, I need to create a custom error enum of all possible errors + my own custom not found, so I can match on it. - // Is this the canonical way? async fn get(&self, k: String) -> Result { let mut conn = self.client.get_async_connection().await?;