diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 117d4bae8a717..fa17aa0e7e504 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1165,9 +1165,12 @@ dependencies = [ "async-trait", "axum 0.7.5", "axum-client-ip", + "base64 0.22.0", "bytes", "common-alloc", "envconfig", + "flate2", + "futures", "maxminddb", "once_cell", "rand", @@ -1216,7 +1219,7 @@ checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" dependencies = [ "futures-core", "futures-sink", - "spin 0.9.8", + "spin", ] [[package]] @@ -2057,11 +2060,11 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ - "spin 0.5.2", + "spin", ] [[package]] @@ -3223,7 +3226,7 @@ dependencies = [ "cfg-if", "getrandom", "libc", - "spin 0.9.8", + "spin", "untrusted", "windows-sys 0.52.0", ] @@ -3604,12 +3607,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "spin" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" - [[package]] name = "spin" version = "0.9.8" diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 3d898dfdbfa72..9bb1cd2bf7a23 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -31,6 +31,8 @@ regex = "1.10.4" maxminddb = "0.17" sqlx = { workspace = true } uuid = { workspace = true } +base64.workspace = true +flate2.workspace = true common-alloc = { path = "../common/alloc" } [lints] @@ -39,4 +41,4 @@ workspace = true [dev-dependencies] assert-json-diff = { workspace = true } reqwest = { workspace = true } - +futures = "0.3.30" diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index 285e09edc5c77..2bf8f265e30ae 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -96,6 +96,8 @@ pub enum FlagError { DatabaseUnavailable, #[error("Timed out while fetching data")] TimeoutError, + #[error("No group type mappings")] + NoGroupTypeMappings, } impl IntoResponse for FlagError { @@ -167,6 +169,13 @@ impl IntoResponse for FlagError { "The request timed out. This could be due to high load or network issues. Please try again later.".to_string(), ) } + FlagError::NoGroupTypeMappings => { + tracing::error!("No group type mappings: {:?}", self); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "No group type mappings found. This is likely a configuration issue. Please contact support.".to_string(), + ) + } } .into_response() } diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index df0b1998cd1bf..7cf66c88f7b99 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -7,11 +7,6 @@ use tracing::instrument; // TODO: Add integration tests across repos to ensure this doesn't happen. pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_"; -// TODO: Hmm, revisit when dealing with groups, but seems like -// ideal to just treat it as a u8 and do our own validation on top -#[derive(Debug, Deserialize)] -pub enum GroupTypeIndex {} - #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum OperatorType { @@ -42,7 +37,7 @@ pub struct PropertyFilter { pub operator: Option, #[serde(rename = "type")] pub prop_type: String, - pub group_type_index: Option, + pub group_type_index: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -68,7 +63,7 @@ pub struct MultivariateFlagOptions { pub struct FlagFilters { pub groups: Vec, pub multivariate: Option, - pub aggregation_group_type_index: Option, + pub aggregation_group_type_index: Option, pub payloads: Option, pub super_groups: Option>, } @@ -101,7 +96,7 @@ pub struct FeatureFlagRow { } impl FeatureFlag { - pub fn get_group_type_index(&self) -> Option { + pub fn get_group_type_index(&self) -> Option { self.filters.aggregation_group_type_index } diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 88911c90bb7be..d4487874a2081 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,49 +1,168 @@ use crate::{ - api::FlagError, + api::{FlagError, FlagValue, FlagsResponse}, database::Client as DatabaseClient, - flag_definitions::{FeatureFlag, FlagGroupType, PropertyFilter}, + flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, PropertyFilter}, property_matching::match_property, }; +use anyhow::Result; use serde_json::Value; use sha1::{Digest, Sha1}; -use std::{collections::HashMap, fmt::Write, sync::Arc}; +use sqlx::FromRow; +use std::collections::{HashMap, HashSet}; +use std::fmt::Write; +use std::sync::Arc; +use tracing::error; + +type TeamId = i32; +type DatabaseClientArc = Arc; +type GroupTypeIndex = i32; #[derive(Debug, PartialEq, Eq)] pub struct FeatureFlagMatch { pub matches: bool, pub variant: Option, - //reason - //condition_index - //payload } -#[derive(Debug, sqlx::FromRow)] -pub struct Person { - pub properties: sqlx::types::Json>, +#[derive(Debug, FromRow)] +pub struct GroupTypeMapping { + pub group_type: String, + pub group_type_index: GroupTypeIndex, +} + +/// This struct is a cache for group type mappings, which are stored in a DB. We use these mappings +/// to look up group names based on the group aggregation indices stored on flag filters, which lets us +/// perform group property matching. We cache them per request so that we can perform multiple flag evaluations +/// without needing to fetch the mappings from the DB each time. +/// Typically, the mappings look like this: +/// +/// let group_types = vec![ +/// ("project", 0), +/// ("organization", 1), +/// ("instance", 2), +/// ("customer", 3), +/// ("team", 4), ]; +/// +/// But for backwards compatibility, we also support whatever mappings may lie in the table. +/// These mappings are ingested via the plugin server. +#[derive(Clone)] +pub struct GroupTypeMappingCache { + team_id: TeamId, + failed_to_fetch_flags: bool, + group_types_to_indexes: HashMap, + group_indexes_to_types: HashMap, + database_client: DatabaseClientArc, +} + +impl GroupTypeMappingCache { + pub fn new(team_id: TeamId, database_client: DatabaseClientArc) -> Self { + GroupTypeMappingCache { + team_id, + failed_to_fetch_flags: false, + group_types_to_indexes: HashMap::new(), + group_indexes_to_types: HashMap::new(), + database_client, + } + } + + pub async fn group_type_to_group_type_index_map( + &mut self, + ) -> Result, FlagError> { + if self.failed_to_fetch_flags { + return Err(FlagError::DatabaseUnavailable); + } + + if !self.group_types_to_indexes.is_empty() { + return Ok(self.group_types_to_indexes.clone()); + } + + let database_client = self.database_client.clone(); + let team_id = self.team_id; + let mapping = match self + .fetch_group_type_mapping(database_client, team_id) + .await + { + Ok(mapping) if !mapping.is_empty() => mapping, + Ok(_) => { + self.failed_to_fetch_flags = true; + return Err(FlagError::NoGroupTypeMappings); + } + Err(e) => { + self.failed_to_fetch_flags = true; + return Err(e); + } + }; + self.group_types_to_indexes = mapping.clone(); + + Ok(mapping) + } + + pub async fn group_type_index_to_group_type_map( + &mut self, + ) -> Result, FlagError> { + if !self.group_indexes_to_types.is_empty() { + return Ok(self.group_indexes_to_types.clone()); + } + + let types_to_indexes = self.group_type_to_group_type_index_map().await?; + let result: HashMap = + types_to_indexes.into_iter().map(|(k, v)| (v, k)).collect(); + + if !result.is_empty() { + self.group_indexes_to_types = result.clone(); + Ok(result) + } else { + Err(FlagError::NoGroupTypeMappings) + } + } + + async fn fetch_group_type_mapping( + &mut self, + database_client: DatabaseClientArc, + team_id: TeamId, + ) -> Result, FlagError> { + let mut conn = database_client.as_ref().get_connection().await?; + + let query = r#" + SELECT group_type, group_type_index + FROM posthog_grouptypemapping + WHERE team_id = $1 + "#; + + let rows = sqlx::query_as::<_, GroupTypeMapping>(query) + .bind(team_id) + .fetch_all(&mut *conn) + .await?; + + let mapping: HashMap = rows + .into_iter() + .map(|row| (row.group_type, row.group_type_index)) + .collect(); + + if mapping.is_empty() { + Err(FlagError::NoGroupTypeMappings) + } else { + Ok(mapping) + } + } +} + +/// This struct is a cache for group and person properties fetched from the database. +/// We cache them per request so that we can perform multiple flag evaluations without needing +/// to fetch the properties from the DB each time. +#[derive(Clone, Default, Debug)] +pub struct PropertiesCache { + person_properties: Option>, + group_properties: HashMap>, } -// TODO: Rework FeatureFlagMatcher - python has a pretty awkward interface, where we pass in all flags, and then again -// the flag to match. I don't think there's any reason anymore to store the flags in the matcher, since we can just -// pass the flag to match directly to the get_match method. This will also make the matcher more stateless. -// Potentially, we could also make the matcher a long-lived object, with caching for group keys and such. -// It just takes in the flag and distinct_id and returns the match... -// Or, make this fully stateless -// and have a separate cache struct for caching group keys, cohort definitions, etc. - and check size, if we can keep it in memory -// for all teams. If not, we can have a LRU cache, or a cache that stores only the most recent N keys. -// But, this can be a future refactor, for now just focusing on getting the basic matcher working, write lots and lots of tests -// and then we can easily refactor stuff around. -// #[derive(Debug)] +#[derive(Clone)] pub struct FeatureFlagMatcher { - // pub flags: Vec, pub distinct_id: String, - pub database_client: Option>, - // TODO do I need cached_properties, or do I get them from the request? - // like, in python I get them from the request. Hmm. Let me try that. - // OH, or is this the FlagMatcherCache. Yeah, so this is the flag matcher cache - cached_properties: Option>, - person_property_overrides: Option>, - // TODO handle group properties - // group_property_overrides: Option>>, + pub team_id: TeamId, + pub database_client: DatabaseClientArc, + group_type_mapping_cache: GroupTypeMappingCache, + properties_cache: PropertiesCache, + groups: HashMap, } const LONG_SCALE: u64 = 0xfffffffffffffff; @@ -51,25 +170,201 @@ const LONG_SCALE: u64 = 0xfffffffffffffff; impl FeatureFlagMatcher { pub fn new( distinct_id: String, - database_client: Option>, - person_property_overrides: Option>, - // group_property_overrides: Option>>, + team_id: TeamId, + database_client: DatabaseClientArc, + group_type_mapping_cache: Option, + properties_cache: Option, + groups: Option>, ) -> Self { FeatureFlagMatcher { - // flags, distinct_id, - database_client, - cached_properties: None, - person_property_overrides, - // group_property_overrides, + team_id, + database_client: database_client.clone(), + group_type_mapping_cache: group_type_mapping_cache + .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, database_client.clone())), + properties_cache: properties_cache.unwrap_or_default(), + groups: groups.unwrap_or_default(), + } + } + + /// Evaluate feature flags for a given distinct_id + /// - Returns a map of feature flag keys to their values + /// - If an error occurs while evaluating a flag, it will be logged and the flag will be omitted from the result + pub async fn evaluate_feature_flags( + &mut self, + feature_flags: FeatureFlagList, + person_property_overrides: Option>, + group_property_overrides: Option>>, + ) -> FlagsResponse { + let mut result = HashMap::new(); + let mut error_while_computing_flags = false; + let mut flags_needing_db_properties = Vec::new(); + + // Step 1: Evaluate flags that can be resolved with overrides + for flag in &feature_flags.flags { + if !flag.active || flag.deleted { + continue; + } + + match self + .match_flag_with_overrides( + flag, + &person_property_overrides, + &group_property_overrides, + ) + .await + { + Ok(Some(flag_match)) => { + let flag_value = self.flag_match_to_value(&flag_match); + result.insert(flag.key.clone(), flag_value); + } + Ok(None) => { + flags_needing_db_properties.push(flag.clone()); + } + // We had overrides, but couldn't evaluate the flag + Err(e) => { + error_while_computing_flags = true; + error!( + "Error evaluating feature flag '{}' with overrides for distinct_id '{}': {:?}", + flag.key, self.distinct_id, e + ); + } + } + } + + // Step 2: Fetch and cache properties for remaining flags + if !flags_needing_db_properties.is_empty() { + let group_type_indexes: HashSet = flags_needing_db_properties + .iter() + .filter_map(|flag| flag.get_group_type_index()) + .collect(); + + let database_client = self.database_client.clone(); + let distinct_id = self.distinct_id.clone(); + let team_id = self.team_id; + + match fetch_and_locally_cache_all_properties( + &mut self.properties_cache, + database_client, + distinct_id, + team_id, + &group_type_indexes, + ) + .await + { + Ok(_) => {} + Err(e) => { + error_while_computing_flags = true; + error!("Error fetching properties: {:?}", e); + } + } + + // Step 3: Evaluate remaining flags + for flag in flags_needing_db_properties { + match self.get_match(&flag, None).await { + Ok(flag_match) => { + let flag_value = self.flag_match_to_value(&flag_match); + result.insert(flag.key.clone(), flag_value); + } + Err(e) => { + error_while_computing_flags = true; + error!( + "Error evaluating feature flag '{}' for distinct_id '{}': {:?}", + flag.key, self.distinct_id, e + ); + } + } + } + } + + FlagsResponse { + error_while_computing_flags, + feature_flags: result, + } + } + + async fn match_flag_with_overrides( + &mut self, + flag: &FeatureFlag, + person_property_overrides: &Option>, + group_property_overrides: &Option>>, + ) -> Result, FlagError> { + let flag_property_filters: Vec = flag + .get_conditions() + .iter() + .flat_map(|c| c.properties.clone().unwrap_or_default()) + .collect(); + + let overrides = match flag.get_group_type_index() { + Some(group_type_index) => { + self.get_group_overrides( + group_type_index, + group_property_overrides, + &flag_property_filters, + ) + .await? + } + None => self.get_person_overrides(person_property_overrides, &flag_property_filters), + }; + + match overrides { + Some(props) => self.get_match(flag, Some(props)).await.map(Some), + None => Ok(None), + } + } + + async fn get_group_overrides( + &mut self, + group_type_index: GroupTypeIndex, + group_property_overrides: &Option>>, + flag_property_filters: &[PropertyFilter], + ) -> Result>, FlagError> { + let index_to_type_map = self + .group_type_mapping_cache + .group_type_index_to_group_type_map() + .await?; + + if let Some(group_type) = index_to_type_map.get(&group_type_index) { + if let Some(group_overrides) = group_property_overrides { + if let Some(group_overrides_by_type) = group_overrides.get(group_type) { + return Ok(locally_computable_property_overrides( + &Some(group_overrides_by_type.clone()), + flag_property_filters, + )); + } + } + } + + Ok(None) + } + + fn get_person_overrides( + &self, + person_property_overrides: &Option>, + flag_property_filters: &[PropertyFilter], + ) -> Option> { + person_property_overrides.as_ref().and_then(|overrides| { + locally_computable_property_overrides(&Some(overrides.clone()), flag_property_filters) + }) + } + + fn flag_match_to_value(&self, flag_match: &FeatureFlagMatch) -> FlagValue { + if flag_match.matches { + match &flag_match.variant { + Some(variant) => FlagValue::String(variant.clone()), + None => FlagValue::Boolean(true), + } + } else { + FlagValue::Boolean(false) } } pub async fn get_match( &mut self, - feature_flag: &FeatureFlag, + flag: &FeatureFlag, + property_overrides: Option>, ) -> Result { - if self.hashed_identifier(feature_flag).is_none() { + if self.hashed_identifier(flag).await?.is_empty() { return Ok(FeatureFlagMatch { matches: false, variant: None, @@ -79,26 +374,23 @@ impl FeatureFlagMatcher { // TODO: super groups for early access // TODO: Variant overrides condition sort - for (index, condition) in feature_flag.get_conditions().iter().enumerate() { + for condition in flag.get_conditions().iter() { let (is_match, _evaluation_reason) = self - .is_condition_match(feature_flag, condition, index) + .is_condition_match(flag, condition, property_overrides.clone()) .await?; if is_match { // TODO: this is a bit awkward, we should only handle variants when overrides exist let variant = match condition.variant.clone() { - Some(variant_override) => { - if feature_flag + Some(variant_override) + if flag .get_variants() .iter() - .any(|v| v.key == variant_override) - { - Some(variant_override) - } else { - self.get_matching_variant(feature_flag) - } + .any(|v| v.key == variant_override) => + { + Some(variant_override) } - None => self.get_matching_variant(feature_flag), + _ => self.get_matching_variant(flag).await?, }; return Ok(FeatureFlagMatch { @@ -107,103 +399,127 @@ impl FeatureFlagMatcher { }); } } + Ok(FeatureFlagMatch { matches: false, variant: None, }) } - fn check_rollout(&self, feature_flag: &FeatureFlag, rollout_percentage: f64) -> (bool, String) { - if rollout_percentage == 100.0 - || self.get_hash(feature_flag, "") <= (rollout_percentage / 100.0) - { - (true, "CONDITION_MATCH".to_string()) - } else { - (false, "OUT_OF_ROLLOUT_BOUND".to_string()) - } - } - - // TODO: Making all this mutable just to store a cached value is annoying. Can I refactor this to be non-mutable? - // Leaning a bit more towards a separate cache store for this. - pub async fn is_condition_match( + async fn is_condition_match( &mut self, feature_flag: &FeatureFlag, condition: &FlagGroupType, - _index: usize, + property_overrides: Option>, ) -> Result<(bool, String), FlagError> { let rollout_percentage = condition.rollout_percentage.unwrap_or(100.0); - if let Some(properties) = &condition.properties { - if properties.is_empty() { - return Ok(self.check_rollout(feature_flag, rollout_percentage)); + + if let Some(flag_property_filters) = &condition.properties { + if flag_property_filters.is_empty() { + return self.check_rollout(feature_flag, rollout_percentage).await; } - let target_properties = self.get_target_properties(feature_flag, properties).await?; + let properties_to_check = + // Group-based flag + if let Some(group_type_index) = feature_flag.get_group_type_index() { + if let Some(local_overrides) = locally_computable_property_overrides( + &property_overrides.clone(), + flag_property_filters, + ) { + local_overrides + } else { + self.get_group_properties_from_cache_or_db(group_type_index) + .await? + } + } else { + // Person-based flag + if let Some(person_overrides) = property_overrides { + if let Some(local_overrides) = locally_computable_property_overrides( + &Some(person_overrides), + flag_property_filters, + ) { + local_overrides + } else { + self.get_person_properties_from_cache_or_db().await? + } + } else { + // We hit this block if there are no overrides AND we know it's not a group-based flag + self.get_person_properties_from_cache_or_db().await? + } + }; + + let properties_match = + all_properties_match(flag_property_filters, &properties_to_check); - if !self.all_properties_match(properties, &target_properties) { + if !properties_match { return Ok((false, "NO_CONDITION_MATCH".to_string())); } } - Ok(self.check_rollout(feature_flag, rollout_percentage)) + self.check_rollout(feature_flag, rollout_percentage).await } - async fn get_target_properties( + async fn get_group_properties_from_cache_or_db( &mut self, - feature_flag: &FeatureFlag, - properties: &[PropertyFilter], + group_type_index: GroupTypeIndex, ) -> Result, FlagError> { - self.get_person_properties(feature_flag.team_id, properties) - .await - // TODO handle group properties, will go something like this - // if let Some(group_index) = feature_flag.get_group_type_index() { - // self.get_group_properties(feature_flag.team_id, group_index, properties) - // } else { - // self.get_person_properties(feature_flag.team_id, properties) - // .await - // } + if let Some(properties) = self + .properties_cache + .group_properties + .get(&group_type_index) + .cloned() + { + return Ok(properties); + } + + let database_client = self.database_client.clone(); + let team_id = self.team_id; + let db_properties = + fetch_group_properties_from_db(database_client, team_id, group_type_index).await?; + + self.properties_cache + .group_properties + .insert(group_type_index, db_properties.clone()); + + Ok(db_properties) } - async fn get_person_properties( + async fn get_person_properties_from_cache_or_db( &mut self, - team_id: i32, - properties: &[PropertyFilter], ) -> Result, FlagError> { - if let Some(person_overrides) = &self.person_property_overrides { - // Check if all required properties are present in the overrides - // and none of them are of type "cohort" - let should_prefer_overrides = properties - .iter() - .all(|prop| person_overrides.contains_key(&prop.key) && prop.prop_type != "cohort"); - - if should_prefer_overrides { - // TODO let's count how often this happens - return Ok(person_overrides.clone()); - } + if let Some(properties) = &self.properties_cache.person_properties { + return Ok(properties.clone()); } - // If we don't prefer the overrides (they're either not present, don't contain enough properties to evaluate the condition, - // or contain a cohort property), fall back to getting properties from cache or DB - self.get_person_properties_from_cache_or_db(team_id, self.distinct_id.clone()) - .await - } + let database_client = self.database_client.clone(); + let distinct_id = self.distinct_id.clone(); + let team_id = self.team_id; + let db_properties = + fetch_person_properties_from_db(database_client, distinct_id, team_id).await?; - fn all_properties_match( - &self, - condition_properties: &[PropertyFilter], - target_properties: &HashMap, - ) -> bool { - condition_properties - .iter() - .all(|property| match_property(property, target_properties, false).unwrap_or(false)) + self.properties_cache.person_properties = Some(db_properties.clone()); + + Ok(db_properties) } - pub fn hashed_identifier(&self, feature_flag: &FeatureFlag) -> Option { - if feature_flag.get_group_type_index().is_none() { - // TODO: Use hash key overrides for experience continuity - Some(self.distinct_id.clone()) + async fn hashed_identifier(&mut self, feature_flag: &FeatureFlag) -> Result { + // TODO: Use hash key overrides for experience continuity + + if let Some(group_type_index) = feature_flag.get_group_type_index() { + // Group-based flag + let group_key = self + .group_type_mapping_cache + .group_type_index_to_group_type_map() + .await? + .get(&group_type_index) + .and_then(|group_type_name| self.groups.get(group_type_name)) + .cloned() + .unwrap_or_default(); + + Ok(group_key.to_string()) } else { - // TODO: Handle getting group key - Some("".to_string()) + // Person-based flag + Ok(self.distinct_id.clone()) } } @@ -211,11 +527,13 @@ impl FeatureFlagMatcher { /// Given the same identifier and key, it'll always return the same float. These floats are /// uniformly distributed between 0 and 1, so if we want to show this feature to 20% of traffic /// we can do _hash(key, identifier) < 0.2 - pub fn get_hash(&self, feature_flag: &FeatureFlag, salt: &str) -> f64 { - // check if hashed_identifier is None - let hashed_identifier = self - .hashed_identifier(feature_flag) - .expect("hashed_identifier is None when computing hash"); + async fn get_hash(&mut self, feature_flag: &FeatureFlag, salt: &str) -> Result { + let hashed_identifier = self.hashed_identifier(feature_flag).await?; + if hashed_identifier.is_empty() { + // Return a hash value that will make the flag evaluate to false + // TODO make this cleaner – we should have a way to return a default value + return Ok(0.0); + } let hash_key = format!("{}.{}{}", feature_flag.key, hashed_identifier, salt); let mut hasher = Sha1::new(); hasher.update(hash_key.as_bytes()); @@ -228,101 +546,218 @@ impl FeatureFlagMatcher { .to_string(); let hash_val = u64::from_str_radix(&hex_str, 16).unwrap(); - hash_val as f64 / LONG_SCALE as f64 + Ok(hash_val as f64 / LONG_SCALE as f64) } - /// This function takes a feature flag and returns the key of the variant that should be shown to the user. - pub fn get_matching_variant(&self, feature_flag: &FeatureFlag) -> Option { - let hash = self.get_hash(feature_flag, "variant"); - let mut total_percentage = 0.0; - - for variant in feature_flag.get_variants() { - total_percentage += variant.rollout_percentage / 100.0; - if hash < total_percentage { - return Some(variant.key.clone()); - } + async fn check_rollout( + &mut self, + feature_flag: &FeatureFlag, + rollout_percentage: f64, + ) -> Result<(bool, String), FlagError> { + if rollout_percentage == 100.0 + || self.get_hash(feature_flag, "").await? <= (rollout_percentage / 100.0) + { + Ok((true, "CONDITION_MATCH".to_string())) // TODO enum, I'll implement this when I implement evaluation reasons + } else { + Ok((false, "OUT_OF_ROLLOUT_BOUND".to_string())) // TODO enum, I'll implement this when I implement evaluation reasons } - None } /// This function takes a feature flag and returns the key of the variant that should be shown to the user. - pub async fn get_person_properties_from_cache_or_db( + async fn get_matching_variant( &mut self, - team_id: i32, - distinct_id: String, - ) -> Result, FlagError> { - // TODO: Do we even need to cache here anymore? - // Depends on how often we're calling this function - // to match all flags for a single person - - // TODO which of these properties do we need to cache? - if let Some(cached_props) = self.cached_properties.clone() { - // TODO: Maybe we don't want to copy around all user properties, this will by far be the largest chunk - // of data we're copying around. Can we work with references here? - // Worst case, just use a Rc. - return Ok(cached_props); - } + feature_flag: &FeatureFlag, + ) -> Result, FlagError> { + let hash = self.get_hash(feature_flag, "variant").await?; + let mut cumulative_percentage = 0.0; - if self.database_client.is_none() { - return Err(FlagError::DatabaseUnavailable); + for variant in feature_flag.get_variants() { + cumulative_percentage += variant.rollout_percentage / 100.0; + if hash < cumulative_percentage { + return Ok(Some(variant.key.clone())); + } } + Ok(None) + } +} - let mut conn = self - .database_client - .as_ref() - .expect("client should exist here") - .get_connection() - .await?; +async fn fetch_and_locally_cache_all_properties( + properties_cache: &mut PropertiesCache, + database_client: DatabaseClientArc, + distinct_id: String, + team_id: TeamId, + group_type_indexes: &HashSet, +) -> Result<(), FlagError> { + let mut conn = database_client.as_ref().get_connection().await?; - let query = r#" - SELECT "posthog_person"."properties" - FROM "posthog_person" - INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id") - WHERE ("posthog_persondistinctid"."distinct_id" = $1 + let query = r#" + SELECT + (SELECT "posthog_person"."properties" + FROM "posthog_person" + INNER JOIN "posthog_persondistinctid" + ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id") + WHERE ("posthog_persondistinctid"."distinct_id" = $1 AND "posthog_persondistinctid"."team_id" = $2 - AND "posthog_person"."team_id" = $3) - LIMIT 1; - "#; + AND "posthog_person"."team_id" = $2) + LIMIT 1) as person_properties, + + (SELECT json_object_agg("posthog_group"."group_type_index", "posthog_group"."group_properties") + FROM "posthog_group" + WHERE ("posthog_group"."team_id" = $2 + AND "posthog_group"."group_type_index" = ANY($3))) as group_properties + "#; - let row = sqlx::query_as::<_, Person>(query) - .bind(&distinct_id) - .bind(team_id) - .bind(team_id) - .fetch_optional(&mut *conn) - .await?; + let group_type_indexes_vec: Vec = group_type_indexes.iter().cloned().collect(); - let props = match row { - Some(row) => row.properties.0, - None => HashMap::new(), - }; + let row: (Option, Option) = sqlx::query_as(query) + .bind(&distinct_id) + .bind(team_id) + .bind(&group_type_indexes_vec) + .fetch_optional(&mut *conn) + .await? + .unwrap_or((None, None)); + + if let Some(person_props) = row.0 { + properties_cache.person_properties = Some( + person_props + .as_object() + .unwrap_or(&serde_json::Map::new()) + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(), + ); + } - self.cached_properties = Some(props.clone()); + if let Some(group_props) = row.1 { + let group_props_map: HashMap> = group_props + .as_object() + .unwrap_or(&serde_json::Map::new()) + .iter() + .map(|(k, v)| { + let group_type_index = k.parse().unwrap_or_default(); + let properties: HashMap = v + .as_object() + .unwrap_or(&serde_json::Map::new()) + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + (group_type_index, properties) + }) + .collect(); - Ok(props) + properties_cache.group_properties.extend(group_props_map); } - // async fn get_group_properties_from_cache_or_db( - // &self, - // team_id: i32, - // group_index: usize, - // properties: &Vec, - // ) -> HashMap { - // todo!() - // } + Ok(()) +} + +async fn fetch_person_properties_from_db( + database_client: DatabaseClientArc, + distinct_id: String, + team_id: TeamId, +) -> Result, FlagError> { + let mut conn = database_client.as_ref().get_connection().await?; + + let query = r#" + SELECT "posthog_person"."properties" as person_properties + FROM "posthog_person" + INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id") + WHERE ("posthog_persondistinctid"."distinct_id" = $1 + AND "posthog_persondistinctid"."team_id" = $2 + AND "posthog_person"."team_id" = $2) + LIMIT 1 + "#; + + let row: Option = sqlx::query_scalar(query) + .bind(&distinct_id) + .bind(team_id) + .fetch_optional(&mut *conn) + .await?; + + Ok(row + .and_then(|v| v.as_object().cloned()) + .unwrap_or_default() + .into_iter() + .map(|(k, v)| (k, v.clone())) + .collect()) +} + +async fn fetch_group_properties_from_db( + database_client: DatabaseClientArc, + team_id: TeamId, + group_type_index: GroupTypeIndex, +) -> Result, FlagError> { + let mut conn = database_client.as_ref().get_connection().await?; + + let query = r#" + SELECT "posthog_group"."group_properties" + FROM "posthog_group" + WHERE ("posthog_group"."team_id" = $1 + AND "posthog_group"."group_type_index" = $2) + LIMIT 1 + "#; + + let row: Option = sqlx::query_scalar(query) + .bind(team_id) + .bind(group_type_index) + .fetch_optional(&mut *conn) + .await?; + + Ok(row + .and_then(|v| v.as_object().cloned()) + .unwrap_or_default() + .into_iter() + .map(|(k, v)| (k, v.clone())) + .collect()) +} + +/// Check if all required properties are present in the overrides +/// and none of them are of type "cohort" – if so, return the overrides, +/// otherwise return None, because we can't locally compute cohort properties +fn locally_computable_property_overrides( + property_overrides: &Option>, + property_filters: &[PropertyFilter], +) -> Option> { + property_overrides.as_ref().and_then(|overrides| { + // TODO handle note from Neil: https://github.com/PostHog/posthog/pull/24589#discussion_r1735828561 + // TL;DR – we'll need to handle cohort properties at the DB level, i.e. we'll need to adjust the cohort query + // to account for if a given person is an element of the cohort X, Y, Z, etc + let should_prefer_overrides = property_filters + .iter() + .all(|prop| overrides.contains_key(&prop.key) && prop.prop_type != "cohort"); + + if should_prefer_overrides { + Some(overrides.clone()) + } else { + None + } + }) +} + +/// Check if all properties match the given filters +fn all_properties_match( + flag_condition_properties: &[PropertyFilter], + target_properties: &HashMap, +) -> bool { + flag_condition_properties + .iter() + .all(|property| match_property(property, target_properties, false).unwrap_or(false)) } #[cfg(test)] mod tests { - use serde_json::json; + use std::collections::HashMap; use super::*; use crate::{ - flag_definitions::{FlagFilters, MultivariateFlagOptions, MultivariateFlagVariant}, + flag_definitions::{ + FlagFilters, MultivariateFlagOptions, MultivariateFlagVariant, OperatorType, + }, test_utils::{insert_new_team_in_pg, insert_person_for_team_in_pg, setup_pg_client}, }; - fn create_test_flag(team_id: i32, properties: Vec) -> FeatureFlag { + fn create_test_flag(team_id: TeamId, properties: Vec) -> FeatureFlag { FeatureFlag { id: 1, team_id, @@ -347,20 +782,20 @@ mod tests { #[tokio::test] async fn test_fetch_properties_from_pg_to_match() { - let client = setup_pg_client(None).await; + let database_client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()) + let team = insert_new_team_in_pg(database_client.clone()) .await .expect("Failed to insert team in pg"); let distinct_id = "user_distinct_id".to_string(); - insert_person_for_team_in_pg(client.clone(), team.id, distinct_id.clone(), None) + insert_person_for_team_in_pg(database_client.clone(), team.id, distinct_id.clone(), None) .await .expect("Failed to insert person"); let not_matching_distinct_id = "not_matching_distinct_id".to_string(); insert_person_for_team_in_pg( - client.clone(), + database_client.clone(), team.id, not_matching_distinct_id.clone(), Some(json!({ "email": "a@x.com"})), @@ -368,7 +803,7 @@ mod tests { .await .expect("Failed to insert person"); - let flag = serde_json::from_value(json!( + let flag: FeatureFlag = serde_json::from_value(json!( { "id": 1, "team_id": team.id, @@ -392,30 +827,49 @@ mod tests { )) .unwrap(); - let mut matcher = FeatureFlagMatcher::new(distinct_id, Some(client.clone()), None); - let match_result = matcher.get_match(&flag).await.unwrap(); + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + database_client.clone(), + None, + None, + None, + ); + let match_result = matcher.get_match(&flag, None).await.unwrap(); assert_eq!(match_result.matches, true); assert_eq!(match_result.variant, None); - // property value is different - let mut matcher = - FeatureFlagMatcher::new(not_matching_distinct_id, Some(client.clone()), None); - let match_result = matcher.get_match(&flag).await.unwrap(); + let mut matcher = FeatureFlagMatcher::new( + not_matching_distinct_id.clone(), + team.id, + database_client.clone(), + None, + None, + None, + ); + let match_result = matcher.get_match(&flag, None).await.unwrap(); assert_eq!(match_result.matches, false); assert_eq!(match_result.variant, None); - // person does not exist - let mut matcher = - FeatureFlagMatcher::new("other_distinct_id".to_string(), Some(client.clone()), None); - let match_result = matcher.get_match(&flag).await.unwrap(); + let mut matcher = FeatureFlagMatcher::new( + "other_distinct_id".to_string(), + team.id, + database_client.clone(), + None, + None, + None, + ); + let match_result = matcher.get_match(&flag, None).await.unwrap(); assert_eq!(match_result.matches, false); assert_eq!(match_result.variant, None); } #[tokio::test] async fn test_person_property_overrides() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); let flag = create_test_flag( team.id, @@ -423,7 +877,7 @@ mod tests { key: "email".to_string(), value: json!("override@example.com"), operator: None, - prop_type: "email".to_string(), + prop_type: "person".to_string(), group_type_index: None, }], ); @@ -432,44 +886,182 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), - Some(client.clone()), - Some(overrides), + team.id, + database_client, + None, + None, + None, ); - let match_result = matcher.get_match(&flag).await.unwrap(); - assert_eq!(match_result.matches, true); + let flags = FeatureFlagList { + flags: vec![flag.clone()], + }; + let result = matcher + .evaluate_feature_flags(flags, Some(overrides), None) + .await; + + assert!(!result.error_while_computing_flags); + assert_eq!( + result.feature_flags.get("test_flag"), + Some(&FlagValue::Boolean(true)) + ); } - #[test] - fn test_hashed_identifier() { - let flag = create_test_flag(1, vec![]); + #[tokio::test] + async fn test_group_property_overrides() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + let mut flag = create_test_flag( + team.id, + vec![PropertyFilter { + key: "industry".to_string(), + value: json!("tech"), + operator: None, + prop_type: "group".to_string(), + group_type_index: Some(1), + }], + ); + + flag.filters.aggregation_group_type_index = Some(1); + + let mut cache = GroupTypeMappingCache::new(team.id, database_client.clone()); + cache.group_types_to_indexes = [("organization".to_string(), 1)].into_iter().collect(); + cache.group_indexes_to_types = [(1, "organization".to_string())].into_iter().collect(); - let matcher = FeatureFlagMatcher::new("test_user".to_string(), None, None); + let groups = HashMap::from([("organization".to_string(), json!("org_123"))]); + + let group_overrides = HashMap::from([( + "organization".to_string(), + HashMap::from([ + ("industry".to_string(), json!("tech")), + ("$group_key".to_string(), json!("org_123")), + ]), + )]); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + database_client, + Some(cache), + None, + Some(groups), + ); + + let flags = FeatureFlagList { + flags: vec![flag.clone()], + }; + let result = matcher + .evaluate_feature_flags(flags, None, Some(group_overrides)) + .await; + + assert!(!result.error_while_computing_flags); assert_eq!( - matcher.hashed_identifier(&flag), - Some("test_user".to_string()) + result.feature_flags.get("test_flag"), + Some(&FlagValue::Boolean(true)) ); + } + + #[tokio::test] + async fn test_get_matching_variant_with_cache() { + let flag = create_test_flag_with_variants(1); + let database_client = setup_pg_client(None).await; + + let mut cache = GroupTypeMappingCache::new(1, database_client.clone()); + + let group_types_to_indexes = [("group_type_1".to_string(), 1)].into_iter().collect(); + let group_type_index_to_name = [(1, "group_type_1".to_string())].into_iter().collect(); + + cache.group_types_to_indexes = group_types_to_indexes; + cache.group_indexes_to_types = group_type_index_to_name; - // Test with a group type index (this part of the functionality is not implemented yet) - // let mut group_flag = flag.clone(); - // group_flag.filters.aggregation_group_type_index = Some(1); - // assert_eq!(matcher.hashed_identifier(&group_flag), Some("".to_string())); + let groups = HashMap::from([("group_type_1".to_string(), json!("group_key_1"))]); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + 1, + database_client, + Some(cache), + None, + Some(groups), + ); + let variant = matcher.get_matching_variant(&flag).await.unwrap(); + assert!(variant.is_some(), "No variant was selected"); + assert!( + ["control", "test", "test2"].contains(&variant.unwrap().as_str()), + "Selected variant is not one of the expected options" + ); } - #[test] - fn test_get_matching_variant() { - let flag = FeatureFlag { - id: 1, - team_id: 1, - name: Some("Test Flag".to_string()), - key: "test_flag".to_string(), - filters: FlagFilters { - groups: vec![], - multivariate: Some(MultivariateFlagOptions { - variants: vec![ - MultivariateFlagVariant { - name: Some("Control".to_string()), - key: "control".to_string(), + #[tokio::test] + async fn test_get_matching_variant_with_db() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + let flag = create_test_flag_with_variants(team.id); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + database_client, + None, + None, + None, + ); + + let variant = matcher.get_matching_variant(&flag).await.unwrap(); + assert!(variant.is_some()); + assert!(["control", "test", "test2"].contains(&variant.unwrap().as_str())); + } + + #[tokio::test] + async fn test_is_condition_match_empty_properties() { + let database_client = setup_pg_client(None).await; + let flag = create_test_flag(1, vec![]); + + let condition = FlagGroupType { + variant: None, + properties: Some(vec![]), + rollout_percentage: Some(100.0), + }; + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + 1, + database_client, + None, + None, + None, + ); + let (is_match, reason) = matcher + .is_condition_match(&flag, &condition, None) + .await + .unwrap(); + assert_eq!(is_match, true); + assert_eq!(reason, "CONDITION_MATCH"); + } + + fn create_test_flag_with_variants(team_id: TeamId) -> FeatureFlag { + FeatureFlag { + id: 1, + team_id, + name: Some("Test Flag".to_string()), + key: "test_flag".to_string(), + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: None, + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: Some(MultivariateFlagOptions { + variants: vec![ + MultivariateFlagVariant { + name: Some("Control".to_string()), + key: "control".to_string(), rollout_percentage: 33.0, }, MultivariateFlagVariant { @@ -484,37 +1076,680 @@ mod tests { }, ], }), - aggregation_group_type_index: None, + aggregation_group_type_index: Some(1), payloads: None, super_groups: None, }, deleted: false, active: true, ensure_experience_continuity: false, - }; + } + } - let matcher = FeatureFlagMatcher::new("test_user".to_string(), None, None); - let variant = matcher.get_matching_variant(&flag); - assert!(variant.is_some()); - assert!(["control", "test", "test2"].contains(&variant.unwrap().as_str())); + #[tokio::test] + async fn test_overrides_avoid_db_lookups() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + let flag = create_test_flag( + team.id, + vec![PropertyFilter { + key: "email".to_string(), + value: json!("test@example.com"), + operator: Some(OperatorType::Exact), + prop_type: "person".to_string(), + group_type_index: None, + }], + ); + + let person_property_overrides = + HashMap::from([("email".to_string(), json!("test@example.com"))]); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + database_client.clone(), + None, + None, + None, + ); + + let result = matcher + .evaluate_feature_flags( + FeatureFlagList { + flags: vec![flag.clone()], + }, + Some(person_property_overrides), + None, + ) + .await; + + assert!(!result.error_while_computing_flags); + assert_eq!( + result.feature_flags.get("test_flag"), + Some(&FlagValue::Boolean(true)) + ); + + let cache = &matcher.properties_cache; + assert!(cache.person_properties.is_none()); } #[tokio::test] - async fn test_is_condition_match_empty_properties() { + async fn test_fallback_to_db_when_overrides_insufficient() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + let flag = create_test_flag( + team.id, + vec![ + PropertyFilter { + key: "email".to_string(), + value: json!("test@example.com"), + operator: Some(OperatorType::Exact), + prop_type: "person".to_string(), + group_type_index: None, + }, + PropertyFilter { + key: "age".to_string(), + value: json!(25), + operator: Some(OperatorType::Gte), + prop_type: "person".to_string(), + group_type_index: None, + }, + ], + ); + + let person_property_overrides = Some(HashMap::from([( + "email".to_string(), + json!("test@example.com"), + )])); + + insert_person_for_team_in_pg( + database_client.clone(), + team.id, + "test_user".to_string(), + Some(json!({"email": "test@example.com", "age": 30})), + ) + .await + .unwrap(); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + database_client.clone(), + None, + None, + None, + ); + + let result = matcher + .get_match(&flag, person_property_overrides.clone()) + .await + .unwrap(); + + assert!(result.matches); + + let cache = &matcher.properties_cache; + assert!(cache.person_properties.is_some()); + assert_eq!( + cache.person_properties.as_ref().unwrap().get("age"), + Some(&json!(30)) + ); + } + + #[tokio::test] + async fn test_property_fetching_and_caching() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + let distinct_id = "test_user".to_string(); + insert_person_for_team_in_pg( + database_client.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "test@example.com", "age": 30})), + ) + .await + .unwrap(); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id, + team.id, + database_client.clone(), + None, + None, + None, + ); + + let properties = matcher + .get_person_properties_from_cache_or_db() + .await + .unwrap(); + + assert_eq!(properties.get("email").unwrap(), &json!("test@example.com")); + assert_eq!(properties.get("age").unwrap(), &json!(30)); + + let cached_properties = matcher.properties_cache.person_properties.clone(); + assert!(cached_properties.is_some()); + assert_eq!( + cached_properties.unwrap().get("email").unwrap(), + &json!("test@example.com") + ); + } + + #[tokio::test] + async fn test_overrides_locally_computable() { + let overrides = Some(HashMap::from([ + ("email".to_string(), json!("test@example.com")), + ("age".to_string(), json!(30)), + ])); + + let property_filters = vec![ + PropertyFilter { + key: "email".to_string(), + value: json!("test@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + }, + PropertyFilter { + key: "age".to_string(), + value: json!(25), + operator: Some(OperatorType::Gte), + prop_type: "person".to_string(), + group_type_index: None, + }, + ]; + + let result = locally_computable_property_overrides(&overrides, &property_filters); + assert!(result.is_some()); + + let property_filters_with_cohort = vec![ + PropertyFilter { + key: "email".to_string(), + value: json!("test@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + }, + PropertyFilter { + key: "cohort".to_string(), + value: json!(1), + operator: None, + prop_type: "cohort".to_string(), + group_type_index: None, + }, + ]; + + let result = + locally_computable_property_overrides(&overrides, &property_filters_with_cohort); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_concurrent_flag_evaluation() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + let flag = Arc::new(create_test_flag(team.id, vec![])); + + let mut handles = vec![]; + for i in 0..100 { + let flag_clone = flag.clone(); + let database_client_clone = database_client.clone(); + handles.push(tokio::spawn(async move { + let mut matcher = FeatureFlagMatcher::new( + format!("test_user_{}", i), + team.id, + database_client_clone, + None, + None, + None, + ); + matcher.get_match(&flag_clone, None).await.unwrap() + })); + } + + let results: Vec = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Check that all evaluations completed without errors + assert_eq!(results.len(), 100); + } + + #[tokio::test] + async fn test_property_operators() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + let flag = create_test_flag( + team.id, + vec![ + PropertyFilter { + key: "age".to_string(), + value: json!(25), + operator: Some(OperatorType::Gte), + prop_type: "person".to_string(), + group_type_index: None, + }, + PropertyFilter { + key: "email".to_string(), + value: json!("example@domain.com"), + operator: Some(OperatorType::Icontains), + prop_type: "person".to_string(), + group_type_index: None, + }, + ], + ); + + insert_person_for_team_in_pg( + database_client.clone(), + team.id, + "test_user".to_string(), + Some(json!({"email": "user@example@domain.com", "age": 30})), + ) + .await + .unwrap(); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + database_client.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None).await.unwrap(); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_empty_hashed_identifier() { + let database_client = setup_pg_client(None).await; let flag = create_test_flag(1, vec![]); - let condition = FlagGroupType { - variant: None, - properties: Some(vec![]), - rollout_percentage: Some(100.0), - }; + let mut matcher = + FeatureFlagMatcher::new("".to_string(), 1, database_client, None, None, None); + + let result = matcher.get_match(&flag, None).await.unwrap(); + + assert!(!result.matches); + } + + #[tokio::test] + async fn test_rollout_percentage() { + let database_client = setup_pg_client(None).await; + let mut flag = create_test_flag(1, vec![]); + // Set the rollout percentage to 0% + flag.filters.groups[0].rollout_percentage = Some(0.0); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + 1, + database_client, + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None).await.unwrap(); + + assert!(!result.matches); + + // Now set the rollout percentage to 100% + flag.filters.groups[0].rollout_percentage = Some(100.0); + + let result = matcher.get_match(&flag, None).await.unwrap(); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_uneven_variant_distribution() { + let database_client = setup_pg_client(None).await; + let mut flag = create_test_flag_with_variants(1); + + // Adjust variant rollout percentages to be uneven + flag.filters.multivariate.as_mut().unwrap().variants = vec![ + MultivariateFlagVariant { + name: Some("Control".to_string()), + key: "control".to_string(), + rollout_percentage: 10.0, + }, + MultivariateFlagVariant { + name: Some("Test".to_string()), + key: "test".to_string(), + rollout_percentage: 30.0, + }, + MultivariateFlagVariant { + name: Some("Test2".to_string()), + key: "test2".to_string(), + rollout_percentage: 60.0, + }, + ]; + + // Ensure the flag is person-based by setting aggregation_group_type_index to None + flag.filters.aggregation_group_type_index = None; + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + 1, + database_client, + None, + None, + None, + ); + + let mut control_count = 0; + let mut test_count = 0; + let mut test2_count = 0; + + // Run the test multiple times to simulate distribution + for i in 0..1000 { + matcher.distinct_id = format!("user_{}", i); + let variant = matcher.get_matching_variant(&flag).await.unwrap(); + match variant.as_deref() { + Some("control") => control_count += 1, + Some("test") => test_count += 1, + Some("test2") => test2_count += 1, + _ => (), + } + } + + // Check that the distribution roughly matches the rollout percentages + let total = control_count + test_count + test2_count; + assert!((control_count as f64 / total as f64 - 0.10).abs() < 0.05); + assert!((test_count as f64 / total as f64 - 0.30).abs() < 0.05); + assert!((test2_count as f64 / total as f64 - 0.60).abs() < 0.05); + } + + #[tokio::test] + async fn test_missing_properties_in_db() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + // Insert a person without properties + insert_person_for_team_in_pg( + database_client.clone(), + team.id, + "test_user".to_string(), + None, + ) + .await + .unwrap(); + + let flag = create_test_flag( + team.id, + vec![PropertyFilter { + key: "email".to_string(), + value: json!("test@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + }], + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + database_client.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None).await.unwrap(); + + assert!(!result.matches); + } + + #[tokio::test] + async fn test_malformed_property_data() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + // Insert a person with malformed properties + insert_person_for_team_in_pg( + database_client.clone(), + team.id, + "test_user".to_string(), + Some(json!({"age": "not_a_number"})), + ) + .await + .unwrap(); + + let flag = create_test_flag( + team.id, + vec![PropertyFilter { + key: "age".to_string(), + value: json!(25), + operator: Some(OperatorType::Gte), + prop_type: "person".to_string(), + group_type_index: None, + }], + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + database_client.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None).await.unwrap(); + + // The match should fail due to invalid data type + assert!(!result.matches); + } + + #[tokio::test] + async fn test_property_caching() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + let distinct_id = "test_user".to_string(); + insert_person_for_team_in_pg( + database_client.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "test@example.com", "age": 30})), + ) + .await + .unwrap(); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + database_client.clone(), + None, + None, + None, + ); + + // First access should fetch from the database + let properties = matcher + .get_person_properties_from_cache_or_db() + .await + .unwrap(); + + assert!(matcher.properties_cache.person_properties.is_some()); + + // Second access should use the cache and not error out + let cached_properties = matcher + .get_person_properties_from_cache_or_db() + .await + .unwrap(); + + assert_eq!(properties, cached_properties); + } + + #[tokio::test] + async fn test_get_match_with_insufficient_overrides() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + let flag = create_test_flag( + team.id, + vec![ + PropertyFilter { + key: "email".to_string(), + value: json!("test@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + }, + PropertyFilter { + key: "age".to_string(), + value: json!(25), + operator: Some(OperatorType::Gte), + prop_type: "person".to_string(), + group_type_index: None, + }, + ], + ); + + let person_overrides = Some(HashMap::from([( + "email".to_string(), + json!("test@example.com"), + )])); + + insert_person_for_team_in_pg( + database_client.clone(), + team.id, + "test_user".to_string(), + Some(json!({"email": "test@example.com", "age": 30})), + ) + .await + .unwrap(); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + database_client.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, person_overrides).await.unwrap(); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_evaluation_reasons() { + let database_client = setup_pg_client(None).await; + let flag = create_test_flag(1, vec![]); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + 1, + database_client, + None, + None, + None, + ); - let mut matcher = FeatureFlagMatcher::new("test_user".to_string(), None, None); let (is_match, reason) = matcher - .is_condition_match(&flag, &condition, 0) + .is_condition_match(&flag, &flag.filters.groups[0], None) .await .unwrap(); - assert_eq!(is_match, true); + + assert!(is_match); assert_eq!(reason, "CONDITION_MATCH"); } + + #[tokio::test] + async fn test_complex_conditions() { + let database_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(database_client.clone()) + .await + .unwrap(); + + let flag = FeatureFlag { + id: 1, + team_id: team.id, + name: Some("Complex Flag".to_string()), + key: "complex_flag".to_string(), + filters: FlagFilters { + groups: vec![ + FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "email".to_string(), + value: json!("user1@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + }]), + rollout_percentage: Some(100.0), + variant: None, + }, + FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "age".to_string(), + value: json!(30), + operator: Some(OperatorType::Gte), + prop_type: "person".to_string(), + group_type_index: None, + }]), + rollout_percentage: Some(100.0), + variant: None, + }, + ], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + deleted: false, + active: true, + ensure_experience_continuity: false, + }; + + insert_person_for_team_in_pg( + database_client.clone(), + team.id, + "test_user".to_string(), + Some(json!({"email": "user2@example.com", "age": 35})), + ) + .await + .unwrap(); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + database_client.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None).await.unwrap(); + + assert!(result.matches); + } } diff --git a/rust/feature-flags/src/flag_request.rs b/rust/feature-flags/src/flag_request.rs index d15876a37481b..d41ac5d1e07b9 100644 --- a/rust/feature-flags/src/flag_request.rs +++ b/rust/feature-flags/src/flag_request.rs @@ -39,14 +39,24 @@ impl FlagRequest { #[instrument(skip_all)] pub fn from_bytes(bytes: Bytes) -> Result { tracing::debug!(len = bytes.len(), "decoding new request"); - // TODO: Add base64 decoding - let payload = String::from_utf8(bytes.into()).map_err(|e| { + + let payload = String::from_utf8(bytes.to_vec()).map_err(|e| { tracing::error!("failed to decode body: {}", e); FlagError::RequestDecodingError(String::from("invalid body encoding")) })?; tracing::debug!(json = payload, "decoded event data"); - Ok(serde_json::from_str::(&payload)?) + + // Attempt to parse as JSON, rejecting invalid JSON + match serde_json::from_str::(&payload) { + Ok(request) => Ok(request), + Err(e) => { + tracing::error!("failed to parse JSON: {}", e); + Err(FlagError::RequestDecodingError(String::from( + "invalid JSON", + ))) + } + } } /// Extracts the token from the request and verifies it against the cache. diff --git a/rust/feature-flags/src/property_matching.rs b/rust/feature-flags/src/property_matching.rs index 9f7d9ea173963..f29f93d90690c 100644 --- a/rust/feature-flags/src/property_matching.rs +++ b/rust/feature-flags/src/property_matching.rs @@ -36,7 +36,6 @@ pub fn match_property( ) -> Result { // only looks for matches where key exists in override_property_values // doesn't support operator is_not_set with partial_props - if partial_props && !matching_property_values.contains_key(&property.key) { return Err(FlagMatchingError::MissingProperty(format!( "can't match properties without a value. Missing property: {}", diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index 35606727f3259..83d1c0f66f352 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -1,32 +1,38 @@ use crate::{ - api::{FlagError, FlagValue, FlagsResponse}, + api::{FlagError, FlagsResponse}, database::Client, flag_definitions::FeatureFlagList, - flag_matching::FeatureFlagMatcher, + flag_matching::{FeatureFlagMatcher, GroupTypeMappingCache}, flag_request::FlagRequest, geoip::GeoIpClient, router, }; use axum::{extract::State, http::HeaderMap}; +use base64::{engine::general_purpose, Engine as _}; use bytes::Bytes; -use serde::Deserialize; +use flate2::read::GzDecoder; +use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::sync::Arc; use std::{collections::HashMap, net::IpAddr}; -use tracing::error; +use std::{io::Read, sync::Arc}; -#[derive(Deserialize, Default)] +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)] +#[serde(rename_all = "lowercase")] pub enum Compression { + #[serde(rename = "gzip")] + #[serde(alias = "gzip-js")] + Gzip, + Base64, #[default] + #[serde(other)] Unsupported, - #[serde(rename = "gzip", alias = "gzip-js")] - Gzip, } impl Compression { pub fn as_str(&self) -> &'static str { match self { Compression::Gzip => "gzip", + Compression::Base64 => "base64", Compression::Unsupported => "unsupported", } } @@ -71,24 +77,28 @@ pub async fn process_request(context: RequestContext) -> Result Result { - match headers + let content_type = headers .get("content-type") - .map_or("", |v| v.to_str().unwrap_or("")) - { - "application/json" => FlagRequest::from_bytes(body), + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + let content_encoding = headers + .get("content-encoding") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + let decoded_body = match content_encoding { + "gzip" => decompress_gzip(body)?, + "" => body, + encoding => { + return Err(FlagError::RequestDecodingError(format!( + "unsupported content encoding: {}", + encoding + ))) + } + }; + + match content_type { + "application/json" => FlagRequest::from_bytes(decoded_body), + "application/json; encoding=base64" => { + let decoded = general_purpose::STANDARD + .decode(decoded_body) + .map_err(|e| { + FlagError::RequestDecodingError(format!("Base64 decoding error: {}", e)) + })?; + FlagRequest::from_bytes(Bytes::from(decoded)) + } ct => Err(FlagError::RequestDecodingError(format!( "unsupported content type: {}", ct @@ -148,70 +184,63 @@ fn decode_request(headers: &HeaderMap, body: Bytes) -> Result>, + database_client: Arc, person_property_overrides: Option>, - // group_property_overrides: Option>>, + group_property_overrides: Option>>, + groups: Option>, ) -> FlagsResponse { - let mut matcher = FeatureFlagMatcher::new( + let group_type_mapping_cache = GroupTypeMappingCache::new(team_id, database_client.clone()); + let mut feature_flag_matcher = FeatureFlagMatcher::new( distinct_id.clone(), + team_id, database_client, - person_property_overrides, - // group_property_overrides, + Some(group_type_mapping_cache), + None, + groups, ); - let mut feature_flags = HashMap::new(); - let mut error_while_computing_flags = false; - let feature_flag_list = feature_flags_from_cache_or_pg.flags; - - for flag in feature_flag_list { - if !flag.active || flag.deleted { - continue; - } - - match matcher.get_match(&flag).await { - Ok(flag_match) => { - let flag_value = if flag_match.matches { - match flag_match.variant { - Some(variant) => FlagValue::String(variant), - None => FlagValue::Boolean(true), - } - } else { - FlagValue::Boolean(false) - }; - feature_flags.insert(flag.key.clone(), flag_value); - } - Err(e) => { - error_while_computing_flags = true; - error!( - "Error evaluating feature flag '{}' for distinct_id '{}': {:?}", - flag.key, distinct_id, e - ); - } - } - } + feature_flag_matcher + .evaluate_feature_flags( + feature_flags_from_cache_or_pg, + person_property_overrides, + group_property_overrides, + ) + .await +} - FlagsResponse { - error_while_computing_flags, - feature_flags, - } +// TODO: Make sure this protects against zip bombs, etc. `/capture` does this +// and it's a good idea to do that here as well, probably worth extracting that method into +// /common given that it's used in multiple places +fn decompress_gzip(compressed: Bytes) -> Result { + let mut decoder = GzDecoder::new(&compressed[..]); + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed).map_err(|e| { + FlagError::RequestDecodingError(format!("gzip decompression failed: {}", e)) + })?; + Ok(Bytes::from(decompressed)) } #[cfg(test)] mod tests { use crate::{ + api::FlagValue, config::Config, flag_definitions::{FeatureFlag, FlagFilters, FlagGroupType, OperatorType, PropertyFilter}, - test_utils::setup_pg_client, + test_utils::{insert_new_team_in_pg, setup_pg_client}, }; use super::*; use axum::http::HeaderMap; - use serde_json::json; - use std::net::Ipv4Addr; + use serde_json::{json, Value}; + use std::net::{Ipv4Addr, Ipv6Addr}; fn create_test_geoip_service() -> GeoIpClient { let config = Config::default_test_config(); @@ -340,10 +369,13 @@ mod tests { person_properties.insert("country".to_string(), json!("US")); let result = evaluate_feature_flags( + 1, "user123".to_string(), feature_flag_list, - Some(pg_client), + pg_client, Some(person_properties), + None, + None, ) .await; @@ -367,9 +399,292 @@ mod tests { assert_eq!(request.distinct_id, Some("user123".to_string())); } + #[test] + fn test_decode_request_unsupported_content_encoding() { + let mut headers = HeaderMap::new(); + headers.insert("content-type", "application/json".parse().unwrap()); + headers.insert("content-encoding", "deflate".parse().unwrap()); + let body = Bytes::from_static(b"{\"token\": \"test_token\", \"distinct_id\": \"user123\"}"); + let result = decode_request(&headers, body); + assert!(matches!(result, Err(FlagError::RequestDecodingError(_)))); + } + + #[test] + fn test_decode_request_invalid_base64() { + let mut headers = HeaderMap::new(); + headers.insert( + "content-type", + "application/json; encoding=base64".parse().unwrap(), + ); + let body = Bytes::from_static(b"invalid_base64=="); + let result = decode_request(&headers, body); + assert!(matches!(result, Err(FlagError::RequestDecodingError(_)))); + } + #[test] fn test_compression_as_str() { assert_eq!(Compression::Gzip.as_str(), "gzip"); assert_eq!(Compression::Unsupported.as_str(), "unsupported"); } + + #[test] + fn test_get_person_property_overrides_ipv4() { + let geoip_service = create_test_geoip_service(); + let result = get_person_property_overrides( + true, + Some(HashMap::new()), + &IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), + &geoip_service, + ); + assert!(result.is_some()); + let props = result.unwrap(); + assert!(props.contains_key("$geoip_country_name")); + } + + #[test] + fn test_get_person_property_overrides_ipv6() { + let geoip_service = create_test_geoip_service(); + let result = get_person_property_overrides( + true, + Some(HashMap::new()), + &IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)), + &geoip_service, + ); + assert!(result.is_some()); + let props = result.unwrap(); + assert!(props.contains_key("$geoip_country_name")); + } + + #[test] + fn test_decode_request_unsupported_content_type() { + let mut headers = HeaderMap::new(); + headers.insert("content-type", "text/plain".parse().unwrap()); + let body = Bytes::from_static(b"test"); + let result = decode_request(&headers, body); + assert!(matches!(result, Err(FlagError::RequestDecodingError(_)))); + } + + #[test] + fn test_decode_request_malformed_json() { + let mut headers = HeaderMap::new(); + headers.insert("content-type", "application/json".parse().unwrap()); + let body = Bytes::from_static(b"{invalid json}"); + let result = decode_request(&headers, body); + // If the actual implementation doesn't return a RequestDecodingError, + // we should adjust our expectation. Let's check if it's an error at all: + assert!(result.is_err(), "Expected an error, but got Ok"); + // If you want to check for a specific error type, you might need to adjust + // the FlagError enum or the decode_request function. + } + + #[tokio::test] + async fn test_evaluate_feature_flags_multiple_flags() { + let pg_client = setup_pg_client(None).await; + let flags = vec![ + FeatureFlag { + name: Some("Flag 1".to_string()), + id: 1, + key: "flag_1".to_string(), + active: true, + deleted: false, + team_id: 1, + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + ensure_experience_continuity: false, + }, + FeatureFlag { + name: Some("Flag 2".to_string()), + id: 2, + key: "flag_2".to_string(), + active: true, + deleted: false, + team_id: 1, + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![]), + rollout_percentage: Some(0.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + ensure_experience_continuity: false, + }, + ]; + + let feature_flag_list = FeatureFlagList { flags }; + + let result = evaluate_feature_flags( + 1, + "user123".to_string(), + feature_flag_list, + pg_client, + None, + None, + None, + ) + .await; + + assert!(!result.error_while_computing_flags); + assert_eq!(result.feature_flags["flag_1"], FlagValue::Boolean(true)); + assert_eq!(result.feature_flags["flag_2"], FlagValue::Boolean(false)); + } + + #[test] + fn test_flags_query_params_deserialization() { + let json = r#"{ + "v": "1.0", + "compression": "gzip", + "lib_version": "2.0", + "sent_at": 1234567890 + }"#; + let params: FlagsQueryParams = serde_json::from_str(json).unwrap(); + assert_eq!(params.version, Some("1.0".to_string())); + assert!(matches!(params.compression, Some(Compression::Gzip))); + assert_eq!(params.lib_version, Some("2.0".to_string())); + assert_eq!(params.sent_at, Some(1234567890)); + } + + #[test] + fn test_compression_deserialization() { + assert_eq!( + serde_json::from_str::("\"gzip\"").unwrap(), + Compression::Gzip + ); + assert_eq!( + serde_json::from_str::("\"gzip-js\"").unwrap(), + Compression::Gzip + ); + // If "invalid" is actually deserialized to Unsupported, we should change our expectation + assert_eq!( + serde_json::from_str::("\"invalid\"").unwrap(), + Compression::Unsupported + ); + } + + #[test] + fn test_flag_error_request_decoding() { + let error = FlagError::RequestDecodingError("Test error".to_string()); + assert!(matches!(error, FlagError::RequestDecodingError(_))); + } + + #[tokio::test] + async fn test_evaluate_feature_flags_with_overrides() { + let pg_client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(pg_client.clone()).await.unwrap(); + + let flag = FeatureFlag { + name: Some("Test Flag".to_string()), + id: 1, + key: "test_flag".to_string(), + active: true, + deleted: false, + team_id: team.id, + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "industry".to_string(), + value: json!("tech"), + operator: Some(OperatorType::Exact), + prop_type: "group".to_string(), + group_type_index: Some(0), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: Some(0), + payloads: None, + super_groups: None, + }, + ensure_experience_continuity: false, + }; + let feature_flag_list = FeatureFlagList { flags: vec![flag] }; + + let groups = HashMap::from([("project".to_string(), json!("project_123"))]); + let group_property_overrides = HashMap::from([( + "project".to_string(), + HashMap::from([ + ("industry".to_string(), json!("tech")), + ("$group_key".to_string(), json!("project_123")), + ]), + )]); + + let result = evaluate_feature_flags( + team.id, + "user123".to_string(), + feature_flag_list, + pg_client, + None, + Some(group_property_overrides), + Some(groups), + ) + .await; + + assert!( + !result.error_while_computing_flags, + "Error while computing flags" + ); + assert!( + result.feature_flags.contains_key("test_flag"), + "test_flag not found in result" + ); + + let flag_value = result + .feature_flags + .get("test_flag") + .expect("test_flag not found"); + + assert_eq!( + flag_value, + &FlagValue::Boolean(true), + "Flag value is not true as expected" + ); + } + + #[tokio::test] + async fn test_long_distinct_id() { + let long_id = "a".repeat(1000); + let pg_client = setup_pg_client(None).await; + let flag = FeatureFlag { + name: Some("Test Flag".to_string()), + id: 1, + key: "test_flag".to_string(), + active: true, + deleted: false, + team_id: 1, + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + ensure_experience_continuity: false, + }; + + let feature_flag_list = FeatureFlagList { flags: vec![flag] }; + + let result = + evaluate_feature_flags(1, long_id, feature_flag_list, pg_client, None, None, None) + .await; + + assert!(!result.error_while_computing_flags); + assert_eq!(result.feature_flags["test_flag"], FlagValue::Boolean(true)); + } } diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 1a32e0837cede..14ebac3c16b06 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -9,8 +9,8 @@ use crate::{ #[derive(Clone)] pub struct State { + // TODO add writers when ready pub redis: Arc, - // TODO: Add pgClient when ready pub postgres: Arc, pub geoip: Arc, } diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index 20b33ba5c3543..e3a0a916e641d 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -207,6 +207,30 @@ pub async fn insert_new_team_in_pg(client: Arc) -> Result assert_eq!(res.rows_affected(), 1); + // Insert group type mappings + let group_types = vec![ + ("project", 0), + ("organization", 1), + ("instance", 2), + ("customer", 3), + ("team", 4), + ]; + + for (group_type, group_type_index) in group_types { + let res = sqlx::query( + r#"INSERT INTO posthog_grouptypemapping + (group_type, group_type_index, name_singular, name_plural, team_id) + VALUES + ($1, $2, NULL, NULL, $3)"#, + ) + .bind(group_type) + .bind(group_type_index) + .bind(team.id) + .execute(&mut *conn) + .await?; + + assert_eq!(res.rows_affected(), 1); + } Ok(team) } diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs index 2a4972962019c..9a44de9debed9 100644 --- a/rust/feature-flags/tests/test_flag_matching_consistency.rs +++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs @@ -2,7 +2,7 @@ /// This ensures there are no mismatches between implementations. use feature_flags::flag_matching::{FeatureFlagMatch, FeatureFlagMatcher}; -use feature_flags::test_utils::create_flag_from_json; +use feature_flags::test_utils::{create_flag_from_json, setup_pg_client}; use serde_json::json; #[tokio::test] @@ -105,12 +105,15 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { ]; for i in 0..1000 { + let database_client = setup_pg_client(None).await; + let distinct_id = format!("distinct_id_{}", i); - let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None, None) - .get_match(&flags[0]) - .await - .unwrap(); + let feature_flag_match = + FeatureFlagMatcher::new(distinct_id, 1, database_client, None, None, None) + .get_match(&flags[0], None) + .await + .unwrap(); if results[i] { assert_eq!( @@ -1187,12 +1190,14 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { ]; for i in 0..1000 { + let database_client = setup_pg_client(None).await; let distinct_id = format!("distinct_id_{}", i); - let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None, None) - .get_match(&flags[0]) - .await - .unwrap(); + let feature_flag_match = + FeatureFlagMatcher::new(distinct_id, 1, database_client, None, None, None) + .get_match(&flags[0], None) + .await + .unwrap(); if results[i].is_some() { assert_eq!( diff --git a/rust/feature-flags/tests/test_flags.rs b/rust/feature-flags/tests/test_flags.rs index 706d8fdfed0da..f12f8434aface 100644 --- a/rust/feature-flags/tests/test_flags.rs +++ b/rust/feature-flags/tests/test_flags.rs @@ -189,7 +189,15 @@ async fn it_handles_malformed_json() -> Result<()> { let payload = "{invalid_json}"; let res = server.send_flags_request(payload.to_string()).await; assert_eq!(StatusCode::BAD_REQUEST, res.status()); - assert!(res.text().await?.starts_with("Failed to parse request:")); + + let response_text = res.text().await?; + println!("Response text: {:?}", response_text); + + assert!( + response_text.contains("Failed to decode request: invalid JSON"), + "Unexpected error message: {:?}", + response_text + ); Ok(()) }