
Commit

chore: refactor RedisLimiter to update async out of the hot path (#24818)
frankh authored Sep 6, 2024
1 parent 70c96fd commit 1107dd7
Showing 4 changed files with 90 additions and 145 deletions.
187 changes: 73 additions & 114 deletions rust/capture/src/limiters/redis.rs
@@ -1,5 +1,11 @@
use metrics::gauge;
use std::{collections::HashSet, ops::Sub, sync::Arc};
use std::time::Duration as StdDuration;
use std::{collections::HashSet, sync::Arc};
use time::{Duration, OffsetDateTime};
use tokio::sync::RwLock;
use tokio::task;
use tokio::time::interval;
use tracing::instrument;

use crate::redis::Client;

@@ -17,18 +23,12 @@ use crate::redis::Client;
/// 2. Capture should cope with redis being _totally down_, and fail open
/// 3. We should not hit redis for every single request
///
/// The solution here is to read from the cache until a time interval is hit, and then fetch new
/// data. The write requires taking a lock that stalls all readers, though so long as redis reads
/// stay fast we're ok.
/// The solution here is to read from the cache and update the set in a background thread.
/// We have to lock out all readers briefly while we swap in the updated set, but we don't take
/// the lock until we already have the response from redis, so it is held only very briefly.
///
/// Some small delay between an account being limited and the limit taking effect is acceptable.
/// However, ideally we should not allow requests from some pods but 429 from others.
use thiserror::Error;
use time::{Duration, OffsetDateTime};
use tokio::sync::RwLock;
use tracing::instrument;

// todo: fetch from env
const QUOTA_LIMITER_CACHE_KEY: &str = "@posthog/quota-limits/";

#[derive(Debug)]
@@ -46,19 +46,12 @@ impl QuotaResource {
}
}

#[derive(Error, Debug)]
pub enum LimiterError {
#[error("updater already running - there can only be one")]
UpdaterRunning,
}

#[derive(Clone)]
pub struct RedisLimiter {
limited: Arc<RwLock<HashSet<String>>>,
redis: Arc<dyn Client + Send + Sync>,
redis_key_prefix: String,
key: String,
interval: Duration,
updated: Arc<RwLock<OffsetDateTime>>,
}

impl RedisLimiter {
@@ -74,98 +67,67 @@ impl RedisLimiter {
interval: Duration,
redis: Arc<dyn Client + Send + Sync>,
redis_key_prefix: Option<String>,
resource: QuotaResource,
) -> anyhow::Result<RedisLimiter> {
let limited = Arc::new(RwLock::new(HashSet::new()));
let key_prefix = redis_key_prefix.unwrap_or_default();

// Force an update immediately if we have any reasonable interval
let updated = OffsetDateTime::from_unix_timestamp(0)?;
let updated = Arc::new(RwLock::new(updated));

Ok(RedisLimiter {
let limiter = RedisLimiter {
interval,
limited,
updated,
redis,
redis_key_prefix: redis_key_prefix.unwrap_or_default(),
})
redis: redis.clone(),
key: format!("{key_prefix}{QUOTA_LIMITER_CACHE_KEY}{}", resource.as_str()),
};

// Spawn a background task to periodically fetch data from Redis
limiter.spawn_background_update();

Ok(limiter)
}

fn spawn_background_update(&self) {
let limited = Arc::clone(&self.limited);
let redis = Arc::clone(&self.redis);
let interval_duration = StdDuration::from_nanos(self.interval.whole_nanoseconds() as u64);
let key = self.key.clone();

// Spawn a task to periodically update the cache from Redis
task::spawn(async move {
let mut interval = interval(interval_duration);
loop {
match RedisLimiter::fetch_limited(&redis, &key).await {
Ok(set) => {
let set = HashSet::from_iter(set.iter().cloned());
gauge!("capture_billing_limits_loaded_tokens",).set(set.len() as f64);

let mut limited_lock = limited.write().await;
*limited_lock = set;
}
Err(e) => {
tracing::error!("Failed to update cache from Redis: {:?}", e);
}
}

interval.tick().await;
}
});
}

#[instrument(skip_all)]
async fn fetch_limited(
client: &Arc<dyn Client + Send + Sync>,
key_prefix: &str,
resource: &QuotaResource,
key: &String,
) -> anyhow::Result<Vec<String>> {
let now = OffsetDateTime::now_utc().unix_timestamp();
let key = format!("{key_prefix}{QUOTA_LIMITER_CACHE_KEY}{}", resource.as_str());
client
.zrangebyscore(key, now.to_string(), String::from("+Inf"))
.zrangebyscore(key.to_string(), now.to_string(), String::from("+Inf"))
.await
}

#[instrument(skip_all, fields(key = key))]
pub async fn is_limited(&self, key: &str, resource: QuotaResource) -> bool {
// hold the read lock to clone it, very briefly. clone is ok because it's very small 🤏
// rwlock can have many readers, but one writer. the writer will wait in a queue with all
// the readers, so we want to hold read locks for the smallest time possible to avoid
// writers waiting for too long. and vice versa.
let updated = {
let updated = self.updated.read().await;
*updated
};

let now = OffsetDateTime::now_utc();
let since_update = now.sub(updated);

// If an update is due, fetch the set from redis + cache it until the next update is due.
// Otherwise, return a value from the cache
//
// This update will block readers! Keep it fast.
if since_update > self.interval {
// open the update lock to change the update, and prevent anyone else from doing so
let mut updated = self.updated.write().await;
*updated = OffsetDateTime::now_utc();

let span = tracing::debug_span!("updating billing cache from redis");
let _span = span.enter();

// a few requests might end up in here concurrently, but I don't think a few extra will
// be a big problem. If it is, we can rework the concurrency a bit.
// On prod atm we call this around 15 times per second at peak times, and it usually
// completes in <1ms.

let set = Self::fetch_limited(&self.redis, &self.redis_key_prefix, &resource).await;

tracing::debug!("fetched set from redis, caching");

if let Ok(set) = set {
let set = HashSet::from_iter(set.iter().cloned());
gauge!(
"capture_billing_limits_loaded_tokens",
"resource" => resource.as_str(),
)
.set(set.len() as f64);

let mut limited = self.limited.write().await;
*limited = set;

tracing::debug!("updated cache from redis");

limited.contains(key)
} else {
tracing::error!("failed to fetch from redis in time, failing open");
// If we fail to fetch the set, something really wrong is happening. To avoid
// dropping events that we don't mean to drop, fail open and accept data. Better
// than angry customers :)
//
// TODO: Consider backing off our redis checks
false
}
} else {
let l = self.limited.read().await;

l.contains(key)
}
#[instrument(skip_all, fields(value = value))]
pub async fn is_limited(&self, value: &str) -> bool {
let limited = self.limited.read().await;
limited.contains(value)
}
}

@@ -185,15 +147,12 @@ mod tests {
.zrangebyscore_ret("@posthog/quota-limits/events", vec![String::from("banana")]);
let client = Arc::new(client);

let limiter = RedisLimiter::new(Duration::microseconds(1), client, None)
let limiter = RedisLimiter::new(Duration::seconds(1), client, None, QuotaResource::Events)
.expect("Failed to create billing limiter");
tokio::time::sleep(std::time::Duration::from_millis(30)).await;

assert!(
!limiter
.is_limited("not_limited", QuotaResource::Events)
.await,
);
assert!(limiter.is_limited("banana", QuotaResource::Events).await);
assert!(!limiter.is_limited("not_limited").await);
assert!(limiter.is_limited("banana").await);
}

#[tokio::test]
@@ -205,27 +164,27 @@
let client = Arc::new(client);

// Default lookup without prefix fails
let limiter = RedisLimiter::new(Duration::microseconds(1), client.clone(), None)
.expect("Failed to create billing limiter");
assert!(!limiter.is_limited("banana", QuotaResource::Events).await);
let limiter = RedisLimiter::new(
Duration::seconds(1),
client.clone(),
None,
QuotaResource::Events,
)
.expect("Failed to create billing limiter");
tokio::time::sleep(std::time::Duration::from_millis(30)).await;
assert!(!limiter.is_limited("banana").await);

// Limiter using the correct prefix
let prefixed_limiter = RedisLimiter::new(
Duration::microseconds(1),
client,
Some("prefix//".to_string()),
QuotaResource::Events,
)
.expect("Failed to create billing limiter");
tokio::time::sleep(std::time::Duration::from_millis(30)).await;

assert!(
!prefixed_limiter
.is_limited("not_limited", QuotaResource::Events)
.await,
);
assert!(
prefixed_limiter
.is_limited("banana", QuotaResource::Events)
.await
);
assert!(!prefixed_limiter.is_limited("not_limited").await);
assert!(prefixed_limiter.is_limited("banana").await);
}
}
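
The refactored limiter is now self-contained: the constructor spawns the background refresh task itself, and is_limited is just a read-locked set lookup, as the new doc comment at the top of redis.rs describes. Below is a minimal sketch (not part of the commit) of how the new API is exercised, assuming the crate is consumed as `capture` and reusing the MockRedisClient helper exactly as the tests above do; the test name and the short sleep that lets the background task finish its first fetch are illustrative, following the pattern of those tests.

use std::sync::Arc;
use time::Duration;

use capture::limiters::redis::{QuotaResource, RedisLimiter};
use capture::redis::MockRedisClient;

#[tokio::test]
async fn banana_is_limited() -> anyhow::Result<()> {
    // Pretend Redis reports the token "banana" as over quota for events.
    let client = MockRedisClient::new()
        .zrangebyscore_ret("@posthog/quota-limits/events", vec![String::from("banana")]);
    let client = Arc::new(client);

    // The constructor spawns the background refresh task (spawn_background_update above),
    // bound to a single QuotaResource and Redis key.
    let limiter = RedisLimiter::new(Duration::seconds(1), client, None, QuotaResource::Events)?;

    // Give the background task a moment to complete its first fetch, as the tests above do.
    tokio::time::sleep(std::time::Duration::from_millis(30)).await;

    // The hot path is now a read-locked HashSet lookup; no Redis round trip per request.
    assert!(limiter.is_limited("banana").await);
    assert!(!limiter.is_limited("some_other_token").await);
    Ok(())
}

Binding the QuotaResource at construction time is also what lets handle_common and is_limited drop their per-call resource arguments in the v0_endpoint.rs changes below.
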
7 changes: 6 additions & 1 deletion rust/capture/src/server.rs
@@ -6,10 +6,11 @@ use health::{ComponentStatus, HealthRegistry};
use time::Duration;
use tokio::net::TcpListener;

use crate::config::CaptureMode;
use crate::config::Config;

use crate::limiters::overflow::OverflowLimiter;
use crate::limiters::redis::RedisLimiter;
use crate::limiters::redis::{QuotaResource, RedisLimiter};
use crate::redis::RedisClient;
use crate::router;
use crate::sinks::kafka::KafkaSink;
@@ -28,6 +29,10 @@
Duration::seconds(5),
redis_client.clone(),
config.redis_key_prefix,
match config.capture_mode {
CaptureMode::Events => QuotaResource::Events,
CaptureMode::Recordings => QuotaResource::Recordings,
},
)
.expect("failed to create billing limiter");

31 changes: 3 additions & 28 deletions rust/capture/src/v0_endpoint.rs
@@ -13,7 +13,6 @@ use serde_json::json;
use serde_json::Value;
use tracing::instrument;

use crate::limiters::redis::QuotaResource;
use crate::prometheus::report_dropped_events;
use crate::v0_request::{Compression, ProcessingContext, RawRequest};
use crate::{
@@ -29,15 +28,13 @@
///
/// Because it must accommodate several shapes, it is inefficient in places. A v1
/// endpoint should be created, that only accepts the BatchedRequest payload shape.
#[allow(clippy::too_many_arguments)]
async fn handle_common(
state: &State<router::State>,
InsecureClientIp(ip): &InsecureClientIp,
meta: &EventQuery,
headers: &HeaderMap,
method: &Method,
path: &MatchedPath,
quota_resource: QuotaResource,
body: Bytes,
) -> Result<(ProcessingContext, Vec<RawEvent>), CaptureError> {
let user_agent = headers
@@ -119,7 +116,7 @@ async fn handle_common(

let billing_limited = state
.billing_limiter
.is_limited(context.token.as_str(), quota_resource)
.is_limited(context.token.as_str())
.await;

if billing_limited {
@@ -157,18 +154,7 @@ pub async fn event(
path: MatchedPath,
body: Bytes,
) -> Result<Json<CaptureResponse>, CaptureError> {
match handle_common(
&state,
&ip,
&meta,
&headers,
&method,
&path,
QuotaResource::Events,
body,
)
.await
{
match handle_common(&state, &ip, &meta, &headers, &method, &path, body).await {
Err(CaptureError::BillingLimit) => {
// for v0 we want to just return ok 🙃
// this is because the clients are pretty dumb and will just retry over and over and
@@ -227,18 +213,7 @@ pub async fn recording(
path: MatchedPath,
body: Bytes,
) -> Result<Json<CaptureResponse>, CaptureError> {
match handle_common(
&state,
&ip,
&meta,
&headers,
&method,
&path,
QuotaResource::Recordings,
body,
)
.await
{
match handle_common(&state, &ip, &meta, &headers, &method, &path, body).await {
Err(CaptureError::BillingLimit) => Ok(Json(CaptureResponse {
status: CaptureResponseCode::Ok,
quota_limited: Some(vec!["recordings".to_string()]),
10 changes: 8 additions & 2 deletions rust/capture/tests/django_compat.rs
@@ -6,6 +6,7 @@ use base64::engine::general_purpose;
use base64::Engine;
use capture::api::{CaptureError, CaptureResponse, CaptureResponseCode, DataType, ProcessedEvent};
use capture::config::CaptureMode;
use capture::limiters::redis::QuotaResource;
use capture::limiters::redis::RedisLimiter;
use capture::redis::MockRedisClient;
use capture::router::router;
@@ -101,8 +102,13 @@ async fn it_matches_django_capture_behaviour() -> anyhow::Result<()> {
let timesource = FixedTime { time: case.now };

let redis = Arc::new(MockRedisClient::new());
let billing_limiter = RedisLimiter::new(Duration::weeks(1), redis.clone(), None)
.expect("failed to create billing limiter");
let billing_limiter = RedisLimiter::new(
Duration::weeks(1),
redis.clone(),
None,
QuotaResource::Events,
)
.expect("failed to create billing limiter");

let app = router(
timesource,

