Skip to content

Commit

Permalink
Improve Redis connection hygiene (#6187)
Browse files Browse the repository at this point in the history
* Improve Redis connection hygiene

* ✂️

* Add acquireTimeoutMillis

* ✨

* 📖
  • Loading branch information
flvndvd authored Jul 12, 2024
1 parent d33cf48 commit af3fbe1
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 238 deletions.
157 changes: 77 additions & 80 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ import {
renderConversationForModelMultiActions,
} from "@app/lib/api/assistant/generation";
import { isLegacyAgentConfiguration } from "@app/lib/api/assistant/legacy_agent";
import { getRedisClient } from "@app/lib/api/redis";
import type { Authenticator } from "@app/lib/auth";
import { AgentMessageContent } from "@app/lib/models/assistant/agent_message_content";
import { redisClient } from "@app/lib/redis";
import logger from "@app/logger/logger";

const CANCELLATION_CHECK_INTERVAL = 500;
Expand Down Expand Up @@ -522,7 +522,7 @@ export async function* runMultiActionsAgent(

let shouldYieldCancel = false;
let lastCheckCancellation = Date.now();
const redis = await redisClient();
const redis = await getRedisClient();
let isGeneration = true;
const contentParser = new AgentMessageContentParser(
agentConfiguration,
Expand All @@ -531,34 +531,76 @@ export async function* runMultiActionsAgent(
);

let rawContent = "";
try {
const _checkCancellation = async () => {
try {
const cancelled = await redis.get(
`assistant:generation:cancelled:${agentMessage.sId}`

const _checkCancellation = async () => {
try {
const cancelled = await redis.get(
`assistant:generation:cancelled:${agentMessage.sId}`
);
if (cancelled === "1") {
shouldYieldCancel = true;
await redis.set(
`assistant:generation:cancelled:${agentMessage.sId}`,
0,
{
EX: 3600, // 1 hour
}
);
if (cancelled === "1") {
shouldYieldCancel = true;
await redis.set(
`assistant:generation:cancelled:${agentMessage.sId}`,
0,
{
EX: 3600, // 1 hour
}
);
}
} catch (error) {
logger.error({ error }, "Error checking cancellation");
return false;
}
};
} catch (error) {
logger.error({ error }, "Error checking cancellation");
return false;
}
};

for await (const event of eventStream) {
if (event.type === "function_call") {
isGeneration = false;
}
for await (const event of eventStream) {
if (event.type === "function_call") {
isGeneration = false;
}

if (event.type === "error") {
if (event.type === "error") {
yield* contentParser.flushTokens();
yield {
type: "agent_error",
created: Date.now(),
configurationId: agentConfiguration.sId,
messageId: agentMessage.sId,
error: {
code: "multi_actions_error",
message: `Error running assistant: ${event.content.message}`,
},
} satisfies AgentErrorEvent;
return;
}

const currentTimestamp = Date.now();
if (
currentTimestamp - lastCheckCancellation >=
CANCELLATION_CHECK_INTERVAL
) {
void _checkCancellation(); // Trigger the async function without awaiting
lastCheckCancellation = currentTimestamp;
}

if (shouldYieldCancel) {
yield* contentParser.flushTokens();
yield {
type: "generation_cancel",
created: Date.now(),
configurationId: agentConfiguration.sId,
messageId: agentMessage.sId,
} satisfies GenerationCancelEvent;
return;
}

if (event.type === "tokens" && isGeneration) {
rawContent += event.content.tokens.text;
yield* contentParser.emitTokens(event.content.tokens.text);
}

if (event.type === "block_execution") {
const e = event.content.execution[0][0];
if (e.error) {
yield* contentParser.flushTokens();
yield {
type: "agent_error",
Expand All @@ -567,71 +609,26 @@ export async function* runMultiActionsAgent(
messageId: agentMessage.sId,
error: {
code: "multi_actions_error",
message: `Error running assistant: ${event.content.message}`,
message: `Error running assistant: ${e.error}`,
},
} satisfies AgentErrorEvent;
return;
}

const currentTimestamp = Date.now();
if (
currentTimestamp - lastCheckCancellation >=
CANCELLATION_CHECK_INTERVAL
) {
void _checkCancellation(); // Trigger the async function without awaiting
lastCheckCancellation = currentTimestamp;
}

if (shouldYieldCancel) {
if (event.content.block_name === "OUTPUT" && e.value) {
// Flush early as we know the generation is terminated here.
yield* contentParser.flushTokens();
yield {
type: "generation_cancel",
created: Date.now(),
configurationId: agentConfiguration.sId,
messageId: agentMessage.sId,
} satisfies GenerationCancelEvent;
return;
}

if (event.type === "tokens" && isGeneration) {
rawContent += event.content.tokens.text;
yield* contentParser.emitTokens(event.content.tokens.text);
}

if (event.type === "block_execution") {
const e = event.content.execution[0][0];
if (e.error) {
yield* contentParser.flushTokens();
yield {
type: "agent_error",
created: Date.now(),
configurationId: agentConfiguration.sId,
messageId: agentMessage.sId,
error: {
code: "multi_actions_error",
message: `Error running assistant: ${e.error}`,
},
} satisfies AgentErrorEvent;
return;
const v = e.value as any;
if ("actions" in v) {
output.actions = v.actions;
}

if (event.content.block_name === "OUTPUT" && e.value) {
// Flush early as we know the generation is terminated here.
yield* contentParser.flushTokens();

const v = e.value as any;
if ("actions" in v) {
output.actions = v.actions;
}
if ("generation" in v) {
output.generation = v.generation;
}
break;
if ("generation" in v) {
output.generation = v.generation;
}
break;
}
}
} finally {
await redis.quit();
}

yield* contentParser.flushTokens();
Expand Down
114 changes: 48 additions & 66 deletions front/lib/api/assistant/agent_usage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ import type { AgentConfigurationType } from "@dust-tt/types";
import type { RedisClientType } from "redis";
import { literal, Op, Sequelize } from "sequelize";

import { getRedisClient } from "@app/lib/api/redis";
import type { Authenticator } from "@app/lib/auth";
import {
Conversation,
Mention,
Message,
} from "@app/lib/models/assistant/conversation";
import { Workspace } from "@app/lib/models/workspace";
import { redisClient } from "@app/lib/redis";
import { launchMentionsCountWorkflow } from "@app/temporal/mentions_count_queue/client";

// Ranking of agents is done over a 7 days period.
Expand Down Expand Up @@ -60,40 +60,34 @@ export async function getAgentsUsage({

const agentMessageCountKey = _getUsageKey(workspaceId);

try {
redis = providedRedis ?? (await redisClient());
const agentMessageCountTTL = await redis.ttl(agentMessageCountKey);

// agent mention count doesn't exist or wasn't set to expire
if (
agentMessageCountTTL === TTL_KEY_NOT_EXIST ||
agentMessageCountTTL === TTL_KEY_NOT_SET
) {
await launchMentionsCountWorkflow({ workspaceId });
return [];
// agent mention count is stale
} else if (
agentMessageCountTTL <
MENTION_COUNT_TTL - MENTION_COUNT_UPDATE_PERIOD_SEC
) {
await launchMentionsCountWorkflow({ workspaceId });
}

// Retrieve and parse agents usage
const agentsUsage = await redis.hGetAll(agentMessageCountKey);
return Object.entries(agentsUsage)
.map(([agentId, count]) => ({
agentId,
messageCount: parseInt(count),
timePeriodSec: rankingTimeframeSec,
}))
.sort((a, b) => b.messageCount - a.messageCount)
.slice(0, limit);
} finally {
if (redis && !providedRedis) {
await redis.quit();
}
redis = providedRedis ?? (await getRedisClient());
const agentMessageCountTTL = await redis.ttl(agentMessageCountKey);

// agent mention count doesn't exist or wasn't set to expire
if (
agentMessageCountTTL === TTL_KEY_NOT_EXIST ||
agentMessageCountTTL === TTL_KEY_NOT_SET
) {
await launchMentionsCountWorkflow({ workspaceId });
return [];
// agent mention count is stale
} else if (
agentMessageCountTTL <
MENTION_COUNT_TTL - MENTION_COUNT_UPDATE_PERIOD_SEC
) {
await launchMentionsCountWorkflow({ workspaceId });
}

// Retrieve and parse agents usage
const agentsUsage = await redis.hGetAll(agentMessageCountKey);
return Object.entries(agentsUsage)
.map(([agentId, count]) => ({
agentId,
messageCount: parseInt(count),
timePeriodSec: rankingTimeframeSec,
}))
.sort((a, b) => b.messageCount - a.messageCount)
.slice(0, limit);
}

export async function getAgentUsage(
Expand Down Expand Up @@ -121,25 +115,19 @@ export async function getAgentUsage(

const agentMessageCountKey = _getUsageKey(workspaceId);

try {
redis = providedRedis ?? (await redisClient());

const agentUsage = await redis.hGet(
agentMessageCountKey,
agentConfigurationId
);
return agentUsage
? {
agentId: agentConfigurationId,
messageCount: parseInt(agentUsage, 10),
timePeriodSec: rankingTimeframeSec,
}
: null;
} finally {
if (redis && !providedRedis) {
await redis.quit();
}
}
redis = providedRedis ?? (await getRedisClient());

const agentUsage = await redis.hGet(
agentMessageCountKey,
agentConfigurationId
);
return agentUsage
? {
agentId: agentConfigurationId,
messageCount: parseInt(agentUsage, 10),
timePeriodSec: rankingTimeframeSec,
}
: null;
}

export async function agentMentionsCount(
Expand Down Expand Up @@ -227,18 +215,12 @@ export async function signalAgentUsage({
}) {
let redis: RedisClientType | null = null;

try {
redis = await redisClient();
const agentMessageCountKey = _getUsageKey(workspaceId);
const agentMessageCountTTL = await redis.ttl(agentMessageCountKey);

if (agentMessageCountTTL !== TTL_KEY_NOT_EXIST) {
// We only want to increment if the counts have already been computed
await redis.hIncrBy(agentMessageCountKey, agentConfigurationId, 1);
}
} finally {
if (redis) {
await redis.quit();
}
redis = await getRedisClient();
const agentMessageCountKey = _getUsageKey(workspaceId);
const agentMessageCountTTL = await redis.ttl(agentMessageCountKey);

if (agentMessageCountTTL !== TTL_KEY_NOT_EXIST) {
// We only want to increment if the counts have already been computed
await redis.hIncrBy(agentMessageCountKey, agentConfigurationId, 1);
}
}
Loading

0 comments on commit af3fbe1

Please sign in to comment.