From eb5c21f18e72eeaab121611bc0a1de3137dffea1 Mon Sep 17 00:00:00 2001 From: Aric Lasry Date: Fri, 22 Dec 2023 11:12:50 +0100 Subject: [PATCH] Assistants usage (#2959) * Write side of assistants usage * WIP * Fully functional + UI * Clean up * Clean up after self review * signal agent usage when editing a message * Update front/pages/api/w/[wId]/assistant/agent_configurations/[aId]/usage.ts Co-authored-by: Philippe Rolet * Clean up * Message plural --------- Co-authored-by: Philippe Rolet --- .../components/assistant/AssistantDetails.tsx | 42 ++- front/lib/api/assistant/agent_usage.ts | 254 ++++++++++++++++++ front/lib/api/assistant/conversation.ts | 28 +- front/lib/models/assistant/conversation.ts | 1 + front/lib/swr.ts | 23 ++ .../agent_configurations/[aId]/usage.ts | 77 ++++++ types/src/front/assistant/agent.ts | 8 + 7 files changed, 431 insertions(+), 2 deletions(-) create mode 100644 front/lib/api/assistant/agent_usage.ts create mode 100644 front/pages/api/w/[wId]/assistant/agent_configurations/[aId]/usage.ts diff --git a/front/components/assistant/AssistantDetails.tsx b/front/components/assistant/AssistantDetails.tsx index 7d9f232e045c..12a5d08d4377 100644 --- a/front/components/assistant/AssistantDetails.tsx +++ b/front/components/assistant/AssistantDetails.tsx @@ -11,6 +11,7 @@ import { XMarkIcon, } from "@dust-tt/sparkle"; import { + AgentUsageType, AgentUserListStatus, ConnectorProvider, DatabaseQueryConfigurationType, @@ -33,7 +34,7 @@ import ReactMarkdown from "react-markdown"; import { DeleteAssistantDialog } from "@app/components/assistant/AssistantActions"; import { SendNotificationsContext } from "@app/components/sparkle/Notification"; import { CONNECTOR_CONFIGURATIONS } from "@app/lib/connector_providers"; -import { useApp, useDatabase } from "@app/lib/swr"; +import { useAgentUsage, useApp, useDatabase } from "@app/lib/swr"; import { PostAgentListStatusRequestBody } from "@app/pages/api/w/[wId]/members/me/agent_list_status"; type AssistantDetailsFlow = "personal" | "workspace"; @@ -53,6 +54,10 @@ export function AssistantDetails({ onUpdate: () => void; flow: AssistantDetailsFlow; }) { + const agentUsage = useAgentUsage({ + workspaceId: owner.sId, + agentConfigurationId: assistant.sId, + }); const DescriptionSection = () => (
( +
+
Usage
+ {(() => { + if (isError) { + return "Error loading usage data."; + } else if (isLoading) { + return "Loading usage data..."; + } else if (usage) { + return ( + <> + @{assistant.name} has been used by {usage.userCount} people in{" "} + {usage.messageCount}{" "} + {usage.messageCount > 1 ? <>messages : <>message} over the + last {usage.timePeriodSec / (60 * 60 * 24)} days. + + ); + } + })()} +
+ ); + const ActionSection = () => assistant.action ? ( isDustAppRunConfiguration(assistant.action) ? ( @@ -119,6 +154,11 @@ export function AssistantDetails({ /> +
diff --git a/front/lib/api/assistant/agent_usage.ts b/front/lib/api/assistant/agent_usage.ts new file mode 100644 index 000000000000..46de05bb571c --- /dev/null +++ b/front/lib/api/assistant/agent_usage.ts @@ -0,0 +1,254 @@ +import { AgentUsageType, ModelId } from "@dust-tt/types"; +import { literal, Op } from "sequelize"; +import { v4 as uuidv4 } from "uuid"; + +import { + Conversation as DBConversation, + Mention, + Message, + UserMessage, + Workspace, +} from "@app/lib/models"; +import { redisClient } from "@app/lib/redis"; + +// Ranking of agents is done over a 30 days period. +const rankingTimeframeSec = 60 * 60 * 24 * 30; // 30 days + +function _getKeys({ + workspaceId, + agentConfigurationId, +}: { + workspaceId: string; + agentConfigurationId: string; +}) { + // One sorted set per agent for counting the number of times the agent has been used. + // score is: timestamp of each use of the agent + // value: random unique distinct value. + const agentMessageCountKey = `agent_usage_count_${workspaceId}_${agentConfigurationId}`; + + // One sorted set per agent for counting the number of users that have used the agent. + // score is: timestamp of last usage by a given user + // value: user_id + const agentUserCountKey = `agent_user_count_${workspaceId}_${agentConfigurationId}`; + return { + agentMessageCountKey, + agentUserCountKey, + }; +} + +async function signalInRedis({ + agentConfigurationId, + workspaceId, + userId, + timestamp, + redis, +}: { + agentConfigurationId: string; + workspaceId: string; + userId: string; + timestamp: number; + messageId: ModelId; + redis: Awaited>; +}) { + const { agentMessageCountKey, agentUserCountKey } = _getKeys({ + workspaceId, + agentConfigurationId, + }); + + await redis.zAdd(agentMessageCountKey, { + score: timestamp, + value: uuidv4(), + }); + await redis.expire(agentMessageCountKey, rankingTimeframeSec); + + await redis.zAdd(agentUserCountKey, { + score: timestamp, + value: userId, + }); + await redis.expire(agentUserCountKey, rankingTimeframeSec); +} + +async function populateUsageIfNeeded({ + agentConfigurationId, + workspaceId, + messageId, + redis, +}: { + agentConfigurationId: string; + workspaceId: string; + messageId: ModelId | null; + redis: Awaited>; +}) { + const owner = await Workspace.findOne({ + where: { + sId: workspaceId, + }, + }); + if (!owner) { + throw new Error(`Workspace ${workspaceId} not found`); + } + const { agentMessageCountKey, agentUserCountKey } = _getKeys({ + agentConfigurationId, + workspaceId, + }); + + const existCount = await redis.exists([ + agentMessageCountKey, + agentUserCountKey, + ]); + if (existCount === 0) { + // Sorted sets for this agent usage do not exist, we'll populate them + // by fetching the data from the database. + // We need to ensure that only one process is going through the populate code path + // so we are using redis.incr() to act as a non blocking lock. + const populateLockKey = `agent_usage_populate_${workspaceId}_${agentConfigurationId}`; + const needToPopulate = (await redis.incr(populateLockKey)) === 1; + + // Keeping the lock key around for 10 minutes, which essentially gives 10 minutes + // to create the sorted sets, before running the risk of a race conditions. + // A race condition in creating the sorted sets would result in double counting + // usage of the agent. + const populateTimeoutSec = 60 * 10; // 10 minutes + await redis.expire(populateLockKey, populateTimeoutSec); + if (!needToPopulate) { + return; + } + + // We are safe to populate the sorted sets until the Redis populateLockKey expires. + // Get all mentions for this agent that have a messageId smaller than messageId + // and that happened within the last 30 days. + const mentions = await Mention.findAll({ + where: { + ...{ + agentConfigurationId: agentConfigurationId, + createdAt: { + [Op.gt]: literal(`NOW() - INTERVAL '30 days'`), + }, + }, + ...(messageId ? { messageId: { [Op.lt]: messageId } } : {}), + }, + include: [ + { + model: Message, + required: true, + include: [ + { + model: UserMessage, + as: "userMessage", + required: true, + }, + { + model: DBConversation, + as: "conversation", + required: true, + where: { + workspaceId: owner.id, + }, + }, + ], + }, + ], + }); + for (const mention of mentions) { + // No need to promise.all() here, as one Redis connection can only execute one command + // at a time. + if (mention.message?.userMessage) { + await signalInRedis({ + agentConfigurationId, + workspaceId, + userId: + mention.message.userMessage.userId?.toString() || + mention.message.userMessage.userContextEmail || + mention.message.userMessage.userContextUsername, + timestamp: mention.createdAt.getTime(), + messageId: mention.messageId, + redis, + }); + } + } + } +} + +export async function getAgentUsage({ + workspaceId, + agentConfigurationId, +}: { + workspaceId: string; + agentConfigurationId: string; +}): Promise { + let redis: Awaited> | null = null; + + const { agentMessageCountKey, agentUserCountKey } = _getKeys({ + agentConfigurationId, + workspaceId, + }); + + try { + redis = await redisClient(); + await populateUsageIfNeeded({ + agentConfigurationId, + workspaceId, + messageId: null, + redis, + }); + const now = new Date(); + const thirtyDaysAgo = new Date(now.getTime() - 1000 * rankingTimeframeSec); + const messageCount = await redis.zCount( + agentMessageCountKey, + thirtyDaysAgo.getTime(), + now.getTime() + ); + const userCount = await redis.zCount( + agentUserCountKey, + thirtyDaysAgo.getTime(), + now.getTime() + ); + + return { + messageCount, + userCount, + timePeriodSec: rankingTimeframeSec, + }; + } finally { + if (redis) { + await redis.quit(); + } + } +} + +export async function signalAgentUsage({ + agentConfigurationId, + workspaceId, + userId, + timestamp, + messageId, +}: { + agentConfigurationId: string; + workspaceId: string; + userId: string; + timestamp: number; + messageId: ModelId; +}) { + let redis: Awaited> | null = null; + try { + redis = await redisClient(); + await populateUsageIfNeeded({ + agentConfigurationId, + workspaceId, + messageId, + redis, + }); + await signalInRedis({ + agentConfigurationId, + workspaceId, + userId, + timestamp, + messageId, + redis, + }); + } finally { + if (redis) { + await redis.quit(); + } + } +} diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index bdfd875e0545..68dcf2fe579b 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -41,6 +41,7 @@ import { Op, Transaction } from "sequelize"; import { runActionStreamed } from "@app/lib/actions/server"; import { runAgent } from "@app/lib/api/assistant/agent"; +import { signalAgentUsage } from "@app/lib/api/assistant/agent_usage"; import { getAgentConfiguration } from "@app/lib/api/assistant/configuration"; import { renderConversationForModel } from "@app/lib/api/assistant/generation"; import { Authenticator } from "@app/lib/auth"; @@ -1055,7 +1056,17 @@ export async function* postUserMessage( if (agentMessageRows.length !== agentMessages.length) { throw new Error("Unreachable: agentMessageRows and agentMessages mismatch"); } - + if (agentMessages.length > 0) { + for (const agentMessage of agentMessages) { + void signalAgentUsage({ + userId: user?.id.toString() || context.email || context.username, + agentConfigurationId: agentMessage.configuration.sId, + workspaceId: owner.sId, + messageId: agentMessage.id, + timestamp: agentMessage.created, + }); + } + } yield { type: "user_message_new", created: Date.now(), @@ -1534,6 +1545,21 @@ export async function* editUserMessage( message: userMessage, }; + if (agentMessages.length > 0) { + for (const agentMessage of agentMessages) { + void signalAgentUsage({ + userId: + user?.id.toString() || + message.context.email || + message.context.username, + agentConfigurationId: agentMessage.configuration.sId, + messageId: agentMessage.id, + timestamp: agentMessage.created, + workspaceId: owner.sId, + }); + } + } + for (let i = 0; i < agentMessages.length; i++) { const agentMessage = agentMessages[i]; diff --git a/front/lib/models/assistant/conversation.ts b/front/lib/models/assistant/conversation.ts index 28bfc9d1c16c..179e45e585df 100644 --- a/front/lib/models/assistant/conversation.ts +++ b/front/lib/models/assistant/conversation.ts @@ -659,6 +659,7 @@ export class Mention extends Model< declare userId: ForeignKey | null; declare agentConfigurationId: string | null; // Not a relation as global agents are not in the DB + declare message: NonAttribute; declare user?: NonAttribute; } diff --git a/front/lib/swr.ts b/front/lib/swr.ts index 483101b464a2..1d671146bfb0 100644 --- a/front/lib/swr.ts +++ b/front/lib/swr.ts @@ -20,6 +20,7 @@ import { GetRunsResponseBody } from "@app/pages/api/w/[wId]/apps/[aId]/runs"; import { GetRunBlockResponseBody } from "@app/pages/api/w/[wId]/apps/[aId]/runs/[runId]/blocks/[type]/[name]"; import { GetRunStatusResponseBody } from "@app/pages/api/w/[wId]/apps/[aId]/runs/[runId]/status"; import { GetAgentConfigurationsResponseBody } from "@app/pages/api/w/[wId]/assistant/agent_configurations"; +import { GetAgentUsageResponseBody } from "@app/pages/api/w/[wId]/assistant/agent_configurations/[aId]/usage"; import { GetAgentNamesResponseBody } from "@app/pages/api/w/[wId]/assistant/agent_configurations/names"; import { GetDataSourcesResponseBody } from "@app/pages/api/w/[wId]/data_sources"; import { GetDocumentsResponseBody } from "@app/pages/api/w/[wId]/data_sources/[name]/documents"; @@ -467,6 +468,28 @@ export function useAgentConfigurations({ }; } +export function useAgentUsage({ + workspaceId, + agentConfigurationId, +}: { + workspaceId: string; + agentConfigurationId: string; +}) { + const agentUsageFetcher: Fetcher = fetcher; + + const { data, error, mutate } = useSWR( + `/api/w/${workspaceId}/assistant/agent_configurations/${agentConfigurationId}/usage`, + agentUsageFetcher + ); + + return { + agentUsage: data ? data.agentUsage : null, + isAgentUsageLoading: !error && !data, + isAgentUsageError: error, + mutateAgentUsage: mutate, + }; +} + export function useSlackChannelsLinkedWithAgent({ workspaceId, dataSourceName, diff --git a/front/pages/api/w/[wId]/assistant/agent_configurations/[aId]/usage.ts b/front/pages/api/w/[wId]/assistant/agent_configurations/[aId]/usage.ts new file mode 100644 index 000000000000..cd48c6dae97a --- /dev/null +++ b/front/pages/api/w/[wId]/assistant/agent_configurations/[aId]/usage.ts @@ -0,0 +1,77 @@ +import { AgentUsageType } from "@dust-tt/types"; +import { ReturnedAPIErrorType } from "@dust-tt/types"; +import { NextApiRequest, NextApiResponse } from "next"; + +import { getAgentUsage } from "@app/lib/api/assistant/agent_usage"; +import { getAgentConfiguration } from "@app/lib/api/assistant/configuration"; +import { Authenticator, getSession } from "@app/lib/auth"; +import { apiError, withLogging } from "@app/logger/withlogging"; + +export type GetAgentUsageResponseBody = { + agentUsage: AgentUsageType; +}; + +async function handler( + req: NextApiRequest, + res: NextApiResponse +): Promise { + const session = await getSession(req, res); + const auth = await Authenticator.fromSession( + session, + req.query.wId as string + ); + const owner = auth.workspace(); + if (!owner) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "workspace_not_found", + message: "The workspace you're trying to access was not found.", + }, + }); + } + + if (!auth.isUser()) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "app_auth_error", + message: + "Only the users that are members for the current workspace can access the workspace's assistants.", + }, + }); + } + + switch (req.method) { + case "GET": + const agentConfiguration = await getAgentConfiguration( + auth, + req.query.aId as string + ); + if (!agentConfiguration) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "agent_configuration_not_found", + message: "The Assistant you're trying to access was not found.", + }, + }); + } + const agentUsage = await getAgentUsage({ + agentConfigurationId: agentConfiguration.sId, + workspaceId: owner.sId, + }); + return res.status(200).json({ agentUsage }); + + default: + return apiError(req, res, { + status_code: 405, + api_error: { + type: "method_not_supported_error", + message: "The method passed is not supported, POST is expected.", + }, + }); + } +} + +export default withLogging(handler); diff --git a/types/src/front/assistant/agent.ts b/types/src/front/assistant/agent.ts index 17190862238e..b60af4e97c59 100644 --- a/types/src/front/assistant/agent.ts +++ b/types/src/front/assistant/agent.ts @@ -132,3 +132,11 @@ export type AgentConfigurationType = { // If undefined, no text generation. generation: AgentGenerationConfigurationType | null; }; + +export type AgentUsageType = { + userCount: number; + messageCount: number; + + // userCount and messageCount are over the last `timePeriodSec` seconds + timePeriodSec: number; +};