Skip to content

Commit

Permalink
Rate limiter for postUserMessageWithPubsub() (#2449)
Browse files Browse the repository at this point in the history
* Rate limiter for postUserMessageWithPubsub()

* No relative import

* Return 1 in case of failure + add monitoring

* Adjust rate limit error message

* Remove remaining distribution and add exeeded count
  • Loading branch information
lasryaric authored Nov 10, 2023
1 parent 8db1408 commit 30518fa
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 1 deletion.
27 changes: 27 additions & 0 deletions front/lib/api/assistant/pubsub.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { GenerationTokensEvent } from "@app/lib/api/assistant/generation";
import { Authenticator } from "@app/lib/auth";
import { APIErrorWithStatusCode } from "@app/lib/error";
import { AgentMessage, Message } from "@app/lib/models";
import { rateLimiter } from "@app/lib/rate_limiter";
import { redisClient } from "@app/lib/redis";
import { Err, Ok, Result } from "@app/lib/result";
import { wakeLock } from "@app/lib/wake_lock";
Expand Down Expand Up @@ -47,6 +48,32 @@ export async function postUserMessageWithPubSub(
context: UserMessageContext;
}
): Promise<Result<UserMessageType, PubSubError>> {
let maxPerTimeframe: number | undefined = undefined;
let timeframeSeconds: number | undefined = undefined;
let rateLimitKey: string | undefined = "";
if (auth.user()?.id) {
maxPerTimeframe = 3;
timeframeSeconds = 120;
rateLimitKey = `postUserMessageUser:${auth.user()?.id}`;
} else {
maxPerTimeframe = 20;
timeframeSeconds = 120;
rateLimitKey = `postUserMessageWorkspace:${auth.workspace()?.id}`;
}

if (
(await rateLimiter(rateLimitKey, maxPerTimeframe, timeframeSeconds)) === 0
) {
return new Err({
status_code: 429,
api_error: {
type: "rate_limit_error",
message: `You have reached the maximum number of ${maxPerTimeframe} messages per ${Math.ceil(
timeframeSeconds / 60
)} minutes of your account. Please try again later.`,
},
});
}
const postMessageEvents = postUserMessage(auth, {
conversation,
content,
Expand Down
3 changes: 2 additions & 1 deletion front/lib/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ export type APIErrorType =
| "subscription_error"
| "stripe_webhook_error"
| "stripe_api_error"
| "stripe_invalid_product_id_error";
| "stripe_invalid_product_id_error"
| "rate_limit_error";

export type APIError = {
type: APIErrorType;
Expand Down
63 changes: 63 additions & 0 deletions front/lib/rate_limiter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import StatsD from "hot-shots";
import { v4 as uuidv4 } from "uuid";

import { redisClient } from "@app/lib/redis";
import logger from "@app/logger/logger";

export const statsDClient = new StatsD();
export async function rateLimiter(
key: string,
maxPerTimeframe: number,
timeframeSeconds: number
): Promise<number> {
let redis: undefined | Awaited<ReturnType<typeof redisClient>> = undefined;
const now = new Date();
const tags = [`rate_limiter:${key}`];
try {
redis = await redisClient();
const redisKey = `rate_limiter:${key}`;

const zcountRes = await redis.zCount(
redisKey,
new Date().getTime() - timeframeSeconds * 1000,
"+inf"
);
const remaining = maxPerTimeframe - zcountRes;
if (remaining > 0) {
await redis.zAdd(redisKey, {
score: new Date().getTime(),
value: uuidv4(),
});
await redis.expire(redisKey, timeframeSeconds * 2);
} else {
statsDClient.increment("ratelimiter.exceeded.count", 1, tags);
}
const totalTimeMs = new Date().getTime() - now.getTime();

statsDClient.distribution(
"ratelimiter.latency.distribution",
totalTimeMs,
tags
);

return remaining;
} catch (e) {
statsDClient.increment("ratelimiter.error.count", 1, tags);
logger.error(
{
key,
maxPerTimeframe,
timeframeSeconds,
error: e,
},
`RateLimiter error`
);

// in case of error on our side, we allow the request.
return 1;
} finally {
if (redis) {
await redis.quit();
}
}
}

0 comments on commit 30518fa

Please sign in to comment.