diff --git a/feature-flags/src/api.rs b/feature-flags/src/api.rs
index ebad1f5..ccf4735 100644
--- a/feature-flags/src/api.rs
+++ b/feature-flags/src/api.rs
@@ -25,9 +25,6 @@ pub enum FlagError {
     #[error("failed to parse request: {0}")]
     RequestParsingError(#[from] serde_json::Error),
 
-    #[error("failed to parse redis data: {0}")]
-    DataParsingError(#[from] serde_pickle::Error),
-
     #[error("Empty distinct_id in request")]
     EmptyDistinctId,
     #[error("No distinct_id in request")]
@@ -40,6 +37,11 @@ pub enum FlagError {
 
     #[error("rate limited")]
     RateLimited,
+
+    #[error("failed to parse redis cache data")]
+    DataParsingError,
+    #[error("redis unavailable")]
+    RedisUnavailable,
 }
 
 impl IntoResponse for FlagError {
@@ -47,7 +49,6 @@
     fn into_response(self) -> Response {
         match self {
             FlagError::RequestDecodingError(_)
             | FlagError::RequestParsingError(_)
-            | FlagError::DataParsingError(_)
             | FlagError::EmptyDistinctId
             | FlagError::MissingDistinctId => (StatusCode::BAD_REQUEST, self.to_string()),
@@ -56,6 +57,10 @@
             }
             FlagError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, self.to_string()),
+
+            FlagError::DataParsingError | FlagError::RedisUnavailable => {
+                (StatusCode::SERVICE_UNAVAILABLE, self.to_string())
+            }
         }
         .into_response()
     }
diff --git a/feature-flags/src/redis.rs b/feature-flags/src/redis.rs
index 3f6dd7f..3aeec47 100644
--- a/feature-flags/src/redis.rs
+++ b/feature-flags/src/redis.rs
@@ -3,11 +3,26 @@ use std::time::Duration;
 use anyhow::Result;
 use async_trait::async_trait;
 use redis::{AsyncCommands, RedisError};
+use thiserror::Error;
 use tokio::time::timeout;
 
 // average for all commands is <10ms, check grafana
 const REDIS_TIMEOUT_MILLISECS: u64 = 10;
 
+#[derive(Error, Debug)]
+pub enum CustomRedisError {
+    #[error("Not found in redis")]
+    NotFound,
+
+    #[error("Pickle error: {0}")]
+    PickleError(#[from] serde_pickle::Error),
+
+    #[error("Redis error: {0}")]
+    Other(#[from] RedisError),
+
+    #[error("Timeout error")]
+    Timeout(#[from] tokio::time::error::Elapsed),
+}
 /// A simple redis wrapper
 /// Copied from capture/src/redis.rs.
 /// TODO: Modify this to support hincrby
@@ -17,7 +32,7 @@ pub trait Client {
     // A very simplified wrapper, but works for our usage
     async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>>;
 
-    async fn get(&self, k: String) -> Result<String>;
+    async fn get(&self, k: String) -> Result<String, CustomRedisError>;
     async fn set(&self, k: String, v: String) -> Result<()>;
 }
 
@@ -44,14 +59,26 @@ impl Client for RedisClient {
 
         Ok(fut?)
     }
 
-    async fn get(&self, k: String) -> Result<String> {
+    // TODO: Ask Xavier if there's a better way to handle this.
+    // The problem: I want to match on the error type from this function, and do appropriate things like 400 or 500 response.
+    // Buuut, if I use anyhow::Error, I can't reverse-coerce into a NotFound or serde_pickle::Error.
+    // Thus, I need to create a custom error enum of all possible errors + my own custom not found, so I can match on it.
+    // Is this the canonical way?
+    async fn get(&self, k: String) -> Result<String, CustomRedisError> {
         let mut conn = self.client.get_async_connection().await?;
 
-        let results = conn.get(k.clone());
-        // TODO: Is this safe? Should we be doing something else for error handling here?
+        let results = conn.get(k);
         let fut: Result<Vec<u8>, RedisError> =
             timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?;
 
+        // return NotFound error when empty or not found
+        if match &fut {
+            Ok(v) => v.is_empty(),
+            Err(_) => false,
+        } {
+            return Err(CustomRedisError::NotFound);
+        }
+
         // TRICKY: We serialise data to json, then django pickles it.
         // Here we deserialize the bytes using serde_pickle, to get the json string.
         let string_response: String = serde_pickle::from_slice(&fut?, Default::default())?;
diff --git a/feature-flags/src/team.rs b/feature-flags/src/team.rs
index 54f7318..ac62ea9 100644
--- a/feature-flags/src/team.rs
+++ b/feature-flags/src/team.rs
@@ -1,26 +1,16 @@
-use std::sync::Arc;
-
-use crate::{api::FlagError, redis::Client};
-
 use serde::{Deserialize, Serialize};
+use std::sync::Arc;
 use tracing::instrument;
 
-// TRICKY: I'm still not sure where the :1: is coming from.
-// The Django prefix is `posthog` only.
-// It's from here: https://docs.djangoproject.com/en/4.2/topics/cache/#cache-versioning
-// F&!£%% on the bright side we don't use this functionality yet.
-// Will rely on integration tests to catch this.
+use crate::{
+    api::FlagError,
+    redis::{Client, CustomRedisError},
+};
+
+// TRICKY: This cache data is coming from django-redis. If it ever goes out of sync, we'll bork.
+// TODO: Add integration tests across repos to ensure this doesn't happen.
 pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:";
 
-// TODO: Check what happens if json has extra stuff, does serde ignore it? Yes
-// Make sure we don't serialize and store team data in redis. Let main decide endpoint control this...
-// and track misses. Revisit if this becomes an issue.
-// because otherwise very annoying to keep this in sync with main django which has a lot of extra fields we need here.
-// will lead to inconsistent behaviour.
-// This is turning out to be very annoying, because we have django key prefixes to be mindful of as well.
-// Wonder if it would be better to make these caches independent? This generates that new problem of CRUD happening in Django,
-// which needs to update this cache immediately, so they can't really ever be independent.
-// True for both team cache and flags cache. Hmm. Just I guess need to add tests around the key prefixes...
 #[derive(Debug, Deserialize, Serialize)]
 pub struct Team {
     pub id: i64,
@@ -40,16 +30,21 @@ impl Team {
         let serialized_team = client
             .get(format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token))
             .await
-            .map_err(|e| {
-                tracing::error!("failed to fetch data: {}", e);
-                // TODO: Can be other errors if serde_pickle destructuring fails?
-                FlagError::TokenValidationError
+            .map_err(|e| match e {
+                CustomRedisError::NotFound => FlagError::TokenValidationError,
+                CustomRedisError::PickleError(_) => {
+                    tracing::error!("failed to fetch data: {}", e);
+                    FlagError::DataParsingError
+                }
+                _ => {
+                    tracing::error!("Unknown redis error: {}", e);
+                    FlagError::RedisUnavailable
+                }
             })?;
 
         let team: Team = serde_json::from_str(&serialized_team).map_err(|e| {
             tracing::error!("failed to parse data to team: {}", e);
-            // TODO: Internal error, shouldn't send back to client
-            FlagError::RequestParsingError(e)
+            FlagError::DataParsingError
         })?;
 
         Ok(team)
@@ -58,8 +53,14 @@
 
 #[cfg(test)]
 mod tests {
+    use rand::Rng;
+    use redis::AsyncCommands;
+
     use super::*;
-    use crate::test_utils::{insert_new_team_in_redis, setup_redis_client};
+    use crate::{
+        team,
+        test_utils::{insert_new_team_in_redis, random_string, setup_redis_client},
+    };
 
     #[tokio::test]
     async fn test_fetch_team_from_redis() {
@@ -80,13 +81,59 @@
     async fn test_fetch_invalid_team_from_redis() {
         let client = setup_redis_client(None);
 
-        // TODO: It's not ideal that this can fail on random errors like connection refused.
-        // Is there a way to be more specific throughout this code?
-        // Or maybe I shouldn't be mapping conn refused to token validation error, and instead handling it as a
-        // top level 500 error instead of 400 right now.
         match Team::from_redis(client.clone(), "banana".to_string()).await {
             Err(FlagError::TokenValidationError) => (),
             _ => panic!("Expected TokenValidationError"),
         };
     }
+
+    #[tokio::test]
+    async fn test_cant_connect_to_redis_error_is_not_token_validation_error() {
+        let client = setup_redis_client(Some("redis://localhost:1111/".to_string()));
+
+        match Team::from_redis(client.clone(), "banana".to_string()).await {
+            Err(FlagError::RedisUnavailable) => (),
+            _ => panic!("Expected RedisUnavailable"),
+        };
+    }
+
+    #[tokio::test]
+    async fn test_corrupted_data_in_redis_is_handled() {
+        // TODO: Extend this test with fallback to pg
+        let id = rand::thread_rng().gen_range(0..10_000_000);
+        let token = random_string("phc_", 12);
+        let team = Team {
+            id,
+            name: "team".to_string(),
+            api_token: token,
+        };
+        let serialized_team = serde_json::to_string(&team).expect("Failed to serialise team");
+
+        // manually insert non-pickled data in redis
+        let client =
+            redis::Client::open("redis://localhost:6379/").expect("Failed to create redis client");
+        let mut conn = client
+            .get_async_connection()
+            .await
+            .expect("Failed to get redis connection");
+        conn.set::<String, String, ()>(
+            format!(
+                "{}{}",
+                team::TEAM_TOKEN_CACHE_PREFIX,
+                team.api_token.clone()
+            ),
+            serialized_team,
+        )
+        .await
+        .expect("Failed to write data to redis");
+
+        // now get client connection for data
+        let client = setup_redis_client(None);
+
+        match Team::from_redis(client.clone(), team.api_token.clone()).await {
+            Err(FlagError::DataParsingError) => (),
+            Err(other) => panic!("Expected DataParsingError, got {:?}", other),
+            Ok(_) => panic!("Expected DataParsingError"),
+        };
+    }
 }
diff --git a/feature-flags/src/test_utils.rs b/feature-flags/src/test_utils.rs
index 1604079..75db86d 100644
--- a/feature-flags/src/test_utils.rs
+++ b/feature-flags/src/test_utils.rs
@@ -20,7 +20,7 @@ pub async fn insert_new_team_in_redis(client: Arc<RedisClient>) -> Result<Team, Error> {
diff --git a/feature-flags/src/v0_request.rs b/feature-flags/src/v0_request.rs
--- a/feature-flags/src/v0_request.rs
+++ b/feature-flags/src/v0_request.rs
     pub fn from_bytes(bytes: Bytes) -> Result<FlagRequest, FlagError> {
         tracing::debug!(len = bytes.len(), "decoding new request");
@@ -65,10 +62,77 @@ impl FlagRequest {
             _ => return Err(FlagError::NoTokenError),
         };
 
-        let team = Team::from_redis(redis_client, token.clone()).await?;
+        // validate token
+        Team::from_redis(redis_client, token.clone()).await?;
+
+        // TODO: fallback when token not found in redis
 
-        // TODO: Remove this, is useless, doing just for now because
-        tracing::Span::current().record("team_id", &team.id);
 
         Ok(token)
     }
+
+    pub fn extract_distinct_id(&self) -> Result<String, FlagError> {
+        let distinct_id = match &self.distinct_id {
+            None => return Err(FlagError::MissingDistinctId),
+            Some(id) => id,
+        };
+
+        match distinct_id.len() {
+            0 => Err(FlagError::EmptyDistinctId),
+            1..=200 => Ok(distinct_id.to_owned()),
+            _ => Ok(distinct_id.chars().take(200).collect()),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::api::FlagError;
+    use crate::v0_request::FlagRequest;
+    use bytes::Bytes;
+    use serde_json::json;
+
+    #[test]
+    fn empty_distinct_id_not_accepted() {
+        let json = json!({
+            "distinct_id": "",
+            "token": "my_token1",
+        });
+        let bytes = Bytes::from(json.to_string());
+
+        let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request");
+
+        match flag_payload.extract_distinct_id() {
+            Err(FlagError::EmptyDistinctId) => (),
+            _ => panic!("expected empty distinct id error"),
+        };
+    }
+
+    #[test]
+    fn too_large_distinct_id_is_truncated() {
+        let json = json!({
+            "distinct_id": std::iter::repeat("a").take(210).collect::<String>(),
+            "token": "my_token1",
+        });
+        let bytes = Bytes::from(json.to_string());
+
+        let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request");
+
+        assert_eq!(flag_payload.extract_distinct_id().unwrap().len(), 200);
+    }
+
+    #[test]
+    fn distinct_id_is_returned_correctly() {
+        let json = json!({
+            "$distinct_id": "alakazam",
+            "token": "my_token1",
+        });
+        let bytes = Bytes::from(json.to_string());
+
+        let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request");
+
+        match flag_payload.extract_distinct_id() {
+            Ok(id) => assert_eq!(id, "alakazam"),
+            _ => panic!("expected distinct id"),
+        };
+    }
 }
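On the "Ask Xavier" TODO in redis.rs: anyhow::Error does support downcast_ref, so the 400-vs-500 decision could also be made at the call site without a dedicated enum. A rough sketch of that alternative, purely illustrative and not part of this patch (the classify name and the string labels are made up; it assumes the same anyhow, redis, serde_pickle and tokio crates already used here):

// Hypothetical alternative to CustomRedisError: keep get() returning
// anyhow::Result<String> and classify the failure at the call site by downcasting.
fn classify(err: &anyhow::Error) -> &'static str {
    if err.downcast_ref::<serde_pickle::Error>().is_some() {
        // corrupted or non-pickled cache data, maps to DataParsingError / 503
        "data parsing error"
    } else if err.downcast_ref::<redis::RedisError>().is_some()
        || err.downcast_ref::<tokio::time::error::Elapsed>().is_some()
    {
        // connection refused, timeouts, etc., maps to RedisUnavailable / 503
        "redis unavailable"
    } else {
        "unknown"
    }
}

The trade-off this makes visible: every call site has to know the full set of possible source types, and "not found" still needs its own sentinel because an empty reply is not an error at all, which is exactly what the CustomRedisError::NotFound variant centralises.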
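On the TRICKY comment about the cache encoding (JSON first, then pickled by django-redis), a minimal round-trip sketch of what the new get() path has to undo. Illustrative only, with made-up values; it assumes serde_pickle's default options are representative of what django-redis writes:

fn pickle_round_trip() -> anyhow::Result<()> {
    // What Django effectively stores: a JSON string, pickled on the way into redis.
    let json = serde_json::json!({"id": 1, "name": "team", "api_token": "phc_example"}).to_string();
    let pickled: Vec<u8> = serde_pickle::to_vec(&json, Default::default())?;

    // What get() in redis.rs does on the way out: unpickle back to the JSON string,
    // which Team::from_redis then parses with serde_json.
    let unpickled: String = serde_pickle::from_slice(&pickled, Default::default())?;
    assert_eq!(json, unpickled);
    Ok(())
}

This is also why test_corrupted_data_in_redis_is_handled writes the raw JSON string without pickling it: serde_pickle::from_slice fails on it, and from_redis surfaces that as FlagError::DataParsingError rather than TokenValidationError.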