From af3fbe1955b500fe83a415ec26ca22e1384d390f Mon Sep 17 00:00:00 2001 From: Flavien David Date: Fri, 12 Jul 2024 14:17:23 +0200 Subject: [PATCH] Improve Redis connection hygiene (#6187) * Improve Redis connection hygiene * :scissors: * Add acquireTimeoutMillis * :sparkles: * :book: --- front/lib/api/assistant/agent.ts | 157 +++++++++--------- front/lib/api/assistant/agent_usage.ts | 114 ++++++------- front/lib/api/assistant/pubsub.ts | 109 ++++++------ front/lib/api/assistant/recent_authors.ts | 7 +- front/lib/api/redis.ts | 40 +++++ front/lib/redis.ts | 28 ---- .../assistant/agent_configurations/index.ts | 4 +- .../mentions_count_queue/activities.ts | 4 +- 8 files changed, 225 insertions(+), 238 deletions(-) create mode 100644 front/lib/api/redis.ts delete mode 100644 front/lib/redis.ts diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index 97d8bc76dd81..0119c9884196 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -43,9 +43,9 @@ import { renderConversationForModelMultiActions, } from "@app/lib/api/assistant/generation"; import { isLegacyAgentConfiguration } from "@app/lib/api/assistant/legacy_agent"; +import { getRedisClient } from "@app/lib/api/redis"; import type { Authenticator } from "@app/lib/auth"; import { AgentMessageContent } from "@app/lib/models/assistant/agent_message_content"; -import { redisClient } from "@app/lib/redis"; import logger from "@app/logger/logger"; const CANCELLATION_CHECK_INTERVAL = 500; @@ -522,7 +522,7 @@ export async function* runMultiActionsAgent( let shouldYieldCancel = false; let lastCheckCancellation = Date.now(); - const redis = await redisClient(); + const redis = await getRedisClient(); let isGeneration = true; const contentParser = new AgentMessageContentParser( agentConfiguration, @@ -531,34 +531,76 @@ export async function* runMultiActionsAgent( ); let rawContent = ""; - try { - const _checkCancellation = async () => { - try { - const cancelled = await redis.get( - `assistant:generation:cancelled:${agentMessage.sId}` + + const _checkCancellation = async () => { + try { + const cancelled = await redis.get( + `assistant:generation:cancelled:${agentMessage.sId}` + ); + if (cancelled === "1") { + shouldYieldCancel = true; + await redis.set( + `assistant:generation:cancelled:${agentMessage.sId}`, + 0, + { + EX: 3600, // 1 hour + } ); - if (cancelled === "1") { - shouldYieldCancel = true; - await redis.set( - `assistant:generation:cancelled:${agentMessage.sId}`, - 0, - { - EX: 3600, // 1 hour - } - ); - } - } catch (error) { - logger.error({ error }, "Error checking cancellation"); - return false; } - }; + } catch (error) { + logger.error({ error }, "Error checking cancellation"); + return false; + } + }; - for await (const event of eventStream) { - if (event.type === "function_call") { - isGeneration = false; - } + for await (const event of eventStream) { + if (event.type === "function_call") { + isGeneration = false; + } - if (event.type === "error") { + if (event.type === "error") { + yield* contentParser.flushTokens(); + yield { + type: "agent_error", + created: Date.now(), + configurationId: agentConfiguration.sId, + messageId: agentMessage.sId, + error: { + code: "multi_actions_error", + message: `Error running assistant: ${event.content.message}`, + }, + } satisfies AgentErrorEvent; + return; + } + + const currentTimestamp = Date.now(); + if ( + currentTimestamp - lastCheckCancellation >= + CANCELLATION_CHECK_INTERVAL + ) { + void _checkCancellation(); // Trigger the async function without awaiting + lastCheckCancellation = currentTimestamp; + } + + if (shouldYieldCancel) { + yield* contentParser.flushTokens(); + yield { + type: "generation_cancel", + created: Date.now(), + configurationId: agentConfiguration.sId, + messageId: agentMessage.sId, + } satisfies GenerationCancelEvent; + return; + } + + if (event.type === "tokens" && isGeneration) { + rawContent += event.content.tokens.text; + yield* contentParser.emitTokens(event.content.tokens.text); + } + + if (event.type === "block_execution") { + const e = event.content.execution[0][0]; + if (e.error) { yield* contentParser.flushTokens(); yield { type: "agent_error", @@ -567,71 +609,26 @@ export async function* runMultiActionsAgent( messageId: agentMessage.sId, error: { code: "multi_actions_error", - message: `Error running assistant: ${event.content.message}`, + message: `Error running assistant: ${e.error}`, }, } satisfies AgentErrorEvent; return; } - const currentTimestamp = Date.now(); - if ( - currentTimestamp - lastCheckCancellation >= - CANCELLATION_CHECK_INTERVAL - ) { - void _checkCancellation(); // Trigger the async function without awaiting - lastCheckCancellation = currentTimestamp; - } - - if (shouldYieldCancel) { + if (event.content.block_name === "OUTPUT" && e.value) { + // Flush early as we know the generation is terminated here. yield* contentParser.flushTokens(); - yield { - type: "generation_cancel", - created: Date.now(), - configurationId: agentConfiguration.sId, - messageId: agentMessage.sId, - } satisfies GenerationCancelEvent; - return; - } - if (event.type === "tokens" && isGeneration) { - rawContent += event.content.tokens.text; - yield* contentParser.emitTokens(event.content.tokens.text); - } - - if (event.type === "block_execution") { - const e = event.content.execution[0][0]; - if (e.error) { - yield* contentParser.flushTokens(); - yield { - type: "agent_error", - created: Date.now(), - configurationId: agentConfiguration.sId, - messageId: agentMessage.sId, - error: { - code: "multi_actions_error", - message: `Error running assistant: ${e.error}`, - }, - } satisfies AgentErrorEvent; - return; + const v = e.value as any; + if ("actions" in v) { + output.actions = v.actions; } - - if (event.content.block_name === "OUTPUT" && e.value) { - // Flush early as we know the generation is terminated here. - yield* contentParser.flushTokens(); - - const v = e.value as any; - if ("actions" in v) { - output.actions = v.actions; - } - if ("generation" in v) { - output.generation = v.generation; - } - break; + if ("generation" in v) { + output.generation = v.generation; } + break; } } - } finally { - await redis.quit(); } yield* contentParser.flushTokens(); diff --git a/front/lib/api/assistant/agent_usage.ts b/front/lib/api/assistant/agent_usage.ts index 044c482b8a49..c0f896034c9e 100644 --- a/front/lib/api/assistant/agent_usage.ts +++ b/front/lib/api/assistant/agent_usage.ts @@ -2,6 +2,7 @@ import type { AgentConfigurationType } from "@dust-tt/types"; import type { RedisClientType } from "redis"; import { literal, Op, Sequelize } from "sequelize"; +import { getRedisClient } from "@app/lib/api/redis"; import type { Authenticator } from "@app/lib/auth"; import { Conversation, @@ -9,7 +10,6 @@ import { Message, } from "@app/lib/models/assistant/conversation"; import { Workspace } from "@app/lib/models/workspace"; -import { redisClient } from "@app/lib/redis"; import { launchMentionsCountWorkflow } from "@app/temporal/mentions_count_queue/client"; // Ranking of agents is done over a 7 days period. @@ -60,40 +60,34 @@ export async function getAgentsUsage({ const agentMessageCountKey = _getUsageKey(workspaceId); - try { - redis = providedRedis ?? (await redisClient()); - const agentMessageCountTTL = await redis.ttl(agentMessageCountKey); - - // agent mention count doesn't exist or wasn't set to expire - if ( - agentMessageCountTTL === TTL_KEY_NOT_EXIST || - agentMessageCountTTL === TTL_KEY_NOT_SET - ) { - await launchMentionsCountWorkflow({ workspaceId }); - return []; - // agent mention count is stale - } else if ( - agentMessageCountTTL < - MENTION_COUNT_TTL - MENTION_COUNT_UPDATE_PERIOD_SEC - ) { - await launchMentionsCountWorkflow({ workspaceId }); - } - - // Retrieve and parse agents usage - const agentsUsage = await redis.hGetAll(agentMessageCountKey); - return Object.entries(agentsUsage) - .map(([agentId, count]) => ({ - agentId, - messageCount: parseInt(count), - timePeriodSec: rankingTimeframeSec, - })) - .sort((a, b) => b.messageCount - a.messageCount) - .slice(0, limit); - } finally { - if (redis && !providedRedis) { - await redis.quit(); - } + redis = providedRedis ?? (await getRedisClient()); + const agentMessageCountTTL = await redis.ttl(agentMessageCountKey); + + // agent mention count doesn't exist or wasn't set to expire + if ( + agentMessageCountTTL === TTL_KEY_NOT_EXIST || + agentMessageCountTTL === TTL_KEY_NOT_SET + ) { + await launchMentionsCountWorkflow({ workspaceId }); + return []; + // agent mention count is stale + } else if ( + agentMessageCountTTL < + MENTION_COUNT_TTL - MENTION_COUNT_UPDATE_PERIOD_SEC + ) { + await launchMentionsCountWorkflow({ workspaceId }); } + + // Retrieve and parse agents usage + const agentsUsage = await redis.hGetAll(agentMessageCountKey); + return Object.entries(agentsUsage) + .map(([agentId, count]) => ({ + agentId, + messageCount: parseInt(count), + timePeriodSec: rankingTimeframeSec, + })) + .sort((a, b) => b.messageCount - a.messageCount) + .slice(0, limit); } export async function getAgentUsage( @@ -121,25 +115,19 @@ export async function getAgentUsage( const agentMessageCountKey = _getUsageKey(workspaceId); - try { - redis = providedRedis ?? (await redisClient()); - - const agentUsage = await redis.hGet( - agentMessageCountKey, - agentConfigurationId - ); - return agentUsage - ? { - agentId: agentConfigurationId, - messageCount: parseInt(agentUsage, 10), - timePeriodSec: rankingTimeframeSec, - } - : null; - } finally { - if (redis && !providedRedis) { - await redis.quit(); - } - } + redis = providedRedis ?? (await getRedisClient()); + + const agentUsage = await redis.hGet( + agentMessageCountKey, + agentConfigurationId + ); + return agentUsage + ? { + agentId: agentConfigurationId, + messageCount: parseInt(agentUsage, 10), + timePeriodSec: rankingTimeframeSec, + } + : null; } export async function agentMentionsCount( @@ -227,18 +215,12 @@ export async function signalAgentUsage({ }) { let redis: RedisClientType | null = null; - try { - redis = await redisClient(); - const agentMessageCountKey = _getUsageKey(workspaceId); - const agentMessageCountTTL = await redis.ttl(agentMessageCountKey); - - if (agentMessageCountTTL !== TTL_KEY_NOT_EXIST) { - // We only want to increment if the counts have already been computed - await redis.hIncrBy(agentMessageCountKey, agentConfigurationId, 1); - } - } finally { - if (redis) { - await redis.quit(); - } + redis = await getRedisClient(); + const agentMessageCountKey = _getUsageKey(workspaceId); + const agentMessageCountTTL = await redis.ttl(agentMessageCountKey); + + if (agentMessageCountTTL !== TTL_KEY_NOT_EXIST) { + // We only want to increment if the counts have already been computed + await redis.hIncrBy(agentMessageCountKey, agentConfigurationId, 1); } } diff --git a/front/lib/api/assistant/pubsub.ts b/front/lib/api/assistant/pubsub.ts index 40a843421cae..e3dd3502c70e 100644 --- a/front/lib/api/assistant/pubsub.ts +++ b/front/lib/api/assistant/pubsub.ts @@ -24,10 +24,11 @@ import type { } from "@dust-tt/types"; import { assertNever, Err, Ok } from "@dust-tt/types"; import type { RedisClientType } from "redis"; +import { commandOptions } from "redis"; +import { getRedisClient } from "@app/lib/api/redis"; import type { Authenticator } from "@app/lib/auth"; import { AgentMessage, Message } from "@app/lib/models/assistant/conversation"; -import { redisClient } from "@app/lib/redis"; import { wakeLock } from "@app/lib/wake_lock"; import logger from "@app/logger/logger"; @@ -151,7 +152,7 @@ async function handleUserMessageEvents( > > = new Promise((resolve) => { void wakeLock(async () => { - const redis = await redisClient(); + const redis = await getRedisClient(); let didResolve = false; let userMessage: UserMessageType | undefined = undefined; @@ -270,7 +271,6 @@ async function handleUserMessageEvents( "Error Posting message" ); } finally { - await redis.quit(); if (!didResolve) { resolve( new Err({ @@ -302,7 +302,7 @@ export async function retryAgentMessageWithPubSub( const promise: Promise> = new Promise( (resolve) => { void wakeLock(async () => { - const redis = await redisClient(); + const redis = await getRedisClient(); let didResolve = false; try { for await (const event of retryAgentMessage(auth, { @@ -378,7 +378,6 @@ export async function retryAgentMessageWithPubSub( "Error Posting message" ); } finally { - await redis.quit(); if (!didResolve) { resolve( new Err({ @@ -408,40 +407,39 @@ export async function* getConversationEvents( }, void > { - const redis = await redisClient(); + const redis = await getRedisClient(); const pubsubChannel = getConversationChannelId(conversationId); - try { - while (true) { - const events = await redis.xRead( - { key: pubsubChannel, id: lastEventId ? lastEventId : "0-0" }, - { COUNT: 32, BLOCK: 60 * 1000 } - ); - if (!events) { - return; - } - for (const event of events) { - for (const message of event.messages) { - const payloadStr = message.message["payload"]; - const messageId = message.id; - const payload = JSON.parse(payloadStr); - lastEventId = messageId; - yield { - eventId: messageId, - data: payload, - }; - } + while (true) { + // Use an isolated connection to avoid blocking the main connection. + const events = await redis.xRead( + commandOptions({ isolated: true }), + { key: pubsubChannel, id: lastEventId ? lastEventId : "0-0" }, + { COUNT: 32, BLOCK: 60 * 1000 } + ); + if (!events) { + return; + } + + for (const event of events) { + for (const message of event.messages) { + const payloadStr = message.message["payload"]; + const messageId = message.id; + const payload = JSON.parse(payloadStr); + lastEventId = messageId; + yield { + eventId: messageId, + data: payload, + }; } } - } finally { - await redis.quit(); } } export async function cancelMessageGenerationEvent( messageIds: string[] ): Promise { - const redis = await redisClient(); + const redis = await getRedisClient(); try { const tasks = messageIds.map((messageId) => { @@ -473,8 +471,6 @@ export async function cancelMessageGenerationEvent( await Promise.all(tasks); } catch (e) { logger.error({ error: e }, "Error cancelling message generation"); - } finally { - await redis.quit(); } } @@ -494,38 +490,37 @@ export async function* getMessagesEvents( void > { const pubsubChannel = getMessageChannelId(messageId); - const redis = await redisClient(); + const redis = await getRedisClient(); - try { - while (true) { - const events = await redis.xRead( - { key: pubsubChannel, id: lastEventId ? lastEventId : "0-0" }, - { COUNT: 32, BLOCK: 60 * 1000 } - ); - if (!events) { - return; - } - for (const event of events) { - for (const message of event.messages) { - const payloadStr = message.message["payload"]; - const messageId = message.id; - const payload = JSON.parse(payloadStr); - lastEventId = messageId; + while (true) { + // Use an isolated connection to avoid blocking the main connection. + const events = await redis.xRead( + commandOptions({ isolated: true }), + { key: pubsubChannel, id: lastEventId ? lastEventId : "0-0" }, + { COUNT: 32, BLOCK: 60 * 1000 } + ); + if (!events) { + return; + } - // If the payload is an end-of-stream event, we stop the generator. - if (payload.type === "end-of-stream") { - return; - } + for (const event of events) { + for (const message of event.messages) { + const payloadStr = message.message["payload"]; + const messageId = message.id; + const payload = JSON.parse(payloadStr); + lastEventId = messageId; - yield { - eventId: messageId, - data: payload, - }; + // If the payload is an end-of-stream event, we stop the generator. + if (payload.type === "end-of-stream") { + return; } + + yield { + eventId: messageId, + data: payload, + }; } } - } finally { - await redis.quit(); } } diff --git a/front/lib/api/assistant/recent_authors.ts b/front/lib/api/assistant/recent_authors.ts index 726f52ff1d39..f3abc463a99b 100644 --- a/front/lib/api/assistant/recent_authors.ts +++ b/front/lib/api/assistant/recent_authors.ts @@ -6,12 +6,12 @@ import type { import { removeNulls } from "@dust-tt/types"; import { Sequelize } from "sequelize"; +import { runOnRedis } from "@app/lib/api/redis"; import { renderUserType } from "@app/lib/api/user"; import { getGlobalAgentAuthorName } from "@app/lib/assistant"; import type { Authenticator } from "@app/lib/auth"; import { AgentConfiguration } from "@app/lib/models/assistant/agent"; import { User } from "@app/lib/models/user"; -import { safeRedisClient } from "@app/lib/redis"; // We keep the most recent authorIds for 3 days. const recentAuthorIdsKeyTTL = 60 * 60 * 24 * 3; // 3 days. @@ -66,7 +66,8 @@ async function setAuthorIdsWithVersionInRedis( agentId, workspaceId, }); - await safeRedisClient(async (redis) => { + + await runOnRedis(async (redis) => { // Add pairs to the sorted set, only if the version is greater than the one stored. await redis.zAdd(agentRecentAuthorIdsKey, authorIdsWithScore, { GT: true }); // Set the expiry for the sorted set to manage its lifecycle. @@ -149,7 +150,7 @@ export async function getAgentsRecentAuthors({ agentId, workspaceId, }); - let recentAuthorIds = await safeRedisClient(async (redis) => + let recentAuthorIds = await runOnRedis(async (redis) => redis.zRange(agentRecentAuthorIdsKey, 0, 2, { REV: true }) ); if (recentAuthorIds.length === 0) { diff --git a/front/lib/api/redis.ts b/front/lib/api/redis.ts new file mode 100644 index 000000000000..08a7d1b4238c --- /dev/null +++ b/front/lib/api/redis.ts @@ -0,0 +1,40 @@ +import type { RedisClientType } from "redis"; +import { createClient } from "redis"; + +import logger from "@app/logger/logger"; + +let client: RedisClientType; + +export async function getRedisClient(): Promise { + if (!client) { + const { REDIS_URI } = process.env; + if (!REDIS_URI) { + throw new Error("REDIS_URI is not defined"); + } + + client = createClient({ + url: REDIS_URI, + isolationPoolOptions: { + acquireTimeoutMillis: 10000, // 10 seconds. + // We support up to 200 concurrent connections for streaming. + max: 200, + }, + }); + client.on("error", (err) => logger.info({ err }, "Redis Client Error")); + client.on("ready", () => logger.info({}, "Redis Client Ready")); + client.on("connect", () => logger.info({}, "Redis Client Connected")); + client.on("end", () => logger.info({}, "Redis Client End")); + + await client.connect(); + } + + return client; +} + +export async function runOnRedis( + fn: (client: RedisClientType) => PromiseLike +): Promise { + const client = await getRedisClient(); + + return fn(client); +} diff --git a/front/lib/redis.ts b/front/lib/redis.ts deleted file mode 100644 index ab0563459778..000000000000 --- a/front/lib/redis.ts +++ /dev/null @@ -1,28 +0,0 @@ -import type { RedisClientType } from "redis"; -import { createClient } from "redis"; - -export async function redisClient(): Promise { - const { REDIS_URI } = process.env; - if (!REDIS_URI) { - throw new Error("REDIS_URI is not defined"); - } - const client: RedisClientType = createClient({ - url: REDIS_URI, - }); - client.on("error", (err) => console.log("Redis Client Error", err)); - - await client.connect(); - - return client; -} - -export async function safeRedisClient( - fn: (client: RedisClientType) => PromiseLike -): Promise { - const client = await redisClient(); - try { - return await fn(client); - } finally { - await client.quit(); - } -} diff --git a/front/pages/api/w/[wId]/assistant/agent_configurations/index.ts b/front/pages/api/w/[wId]/assistant/agent_configurations/index.ts index d4a9f1e33a35..3265924137aa 100644 --- a/front/pages/api/w/[wId]/assistant/agent_configurations/index.ts +++ b/front/pages/api/w/[wId]/assistant/agent_configurations/index.ts @@ -27,9 +27,9 @@ import { unsafeHardDeleteAgentConfiguration, } from "@app/lib/api/assistant/configuration"; import { getAgentsRecentAuthors } from "@app/lib/api/assistant/recent_authors"; +import { runOnRedis } from "@app/lib/api/redis"; import { withSessionAuthenticationForWorkspace } from "@app/lib/api/wrappers"; import type { Authenticator } from "@app/lib/auth"; -import { safeRedisClient } from "@app/lib/redis"; import { ServerSideTracking } from "@app/lib/tracking/server"; import { apiError } from "@app/logger/withlogging"; @@ -107,7 +107,7 @@ async function handler( sort, }); if (withUsage === "true") { - const mentionCounts = await safeRedisClient(async (redis) => { + const mentionCounts = await runOnRedis(async (redis) => { return getAgentsUsage({ providedRedis: redis, workspaceId: owner.sId, diff --git a/front/temporal/mentions_count_queue/activities.ts b/front/temporal/mentions_count_queue/activities.ts index 3efa8b87cff9..43be5041c8cd 100644 --- a/front/temporal/mentions_count_queue/activities.ts +++ b/front/temporal/mentions_count_queue/activities.ts @@ -2,8 +2,8 @@ import { agentMentionsCount, storeCountsInRedis, } from "@app/lib/api/assistant/agent_usage"; +import { runOnRedis } from "@app/lib/api/redis"; import { Workspace } from "@app/lib/models/workspace"; -import { safeRedisClient } from "@app/lib/redis"; export async function mentionsCountActivity(workspaceId: string) { const owner = await Workspace.findOne({ where: { sId: workspaceId } }); @@ -12,7 +12,7 @@ export async function mentionsCountActivity(workspaceId: string) { } const agentMessageCounts = await agentMentionsCount(owner.id); - await safeRedisClient((redis) => + await runOnRedis((redis) => storeCountsInRedis(workspaceId, agentMessageCounts, redis) ); }