diff --git a/front/lib/api/assistant/actions/retrieval.ts b/front/lib/api/assistant/actions/retrieval.ts index 2513b425da0d..e98460a7ccc4 100644 --- a/front/lib/api/assistant/actions/retrieval.ts +++ b/front/lib/api/assistant/actions/retrieval.ts @@ -1,6 +1,8 @@ import type { + AgentActionConfigurationType, FunctionCallType, FunctionMessageTypeModel, + ModelConfigurationType, ModelId, RetrievalErrorEvent, RetrievalParamsEvent, @@ -18,6 +20,7 @@ import { BaseAction, cloneBaseConfig, DustProdActionRegistry, + isRetrievalConfiguration, } from "@dust-tt/types"; import { Ok } from "@dust-tt/types"; @@ -264,6 +267,62 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS return new Ok(spec); } + // stepTopKAndRefsOffsetForAction returns the references offset and the number of documents an + // action will use as part of the current step. We share topK among multiple instances of a same + // retrieval action so that we don't overflow the context when the model asks for many retrievals + // at the same time. Based on the nature of the retrieval actions (topK being `auto` and query + // being `none` or not), the topK can vary (exhaustive or not). So we need all the actions of the + // current step to properly split the topK among them and decide which slice of references we will + // allocate to the current action. 
+ static stepTopKAndRefsOffsetForAction({ + action, + model, + stepActionIndex, + stepActions, + }: { + action: RetrievalConfigurationType; + model: ModelConfigurationType; + stepActionIndex: number; + stepActions: AgentActionConfigurationType[]; + }): { topK: number; refsOffset: number } { + const topKForAction = ( + action: RetrievalConfigurationType, + actionCount: number + ) => { + let topK = 16; + if (action.topK === "auto") { + if (action.query === "none") { + topK = model.recommendedExhaustiveTopK; + } else { + topK = model.recommendedTopK; + } + } else { + topK = action.topK; + } + // We split the topK among the actions that use the same action configuration. + return Math.ceil(topK / actionCount); + }; + + const actionCounts: Record<string, number> = {}; + stepActions.forEach((a) => { + actionCounts[a.sId] = actionCounts[a.sId] ?? 0; + actionCounts[a.sId]++; + }); + + let refsOffset = 0; + for (let i = 0; i < stepActionIndex; i++) { + const r = stepActions[i]; + if (isRetrievalConfiguration(r)) { + refsOffset += topKForAction(r, actionCounts[stepActions[i].sId]); + } + } + + return { + topK: topKForAction(action, actionCounts[action.sId]), + refsOffset, + }; + } + + // This method is in charge of running the retrieval and creating an AgentRetrievalAction object in + // the database (along with the RetrievalDocument and RetrievalDocumentChunk objects). It does not + // create any generic model related to the conversation. 
It is possible for an AgentRetrievalAction @@ -280,7 +339,13 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS functionCallId, step, }: BaseActionRunParams, - { refsOffset = 0 }: { refsOffset?: number } + { + stepActionIndex, + stepActions, + }: { + stepActionIndex: number; + stepActions: AgentActionConfigurationType[]; + } ): AsyncGenerator< RetrievalParamsEvent | RetrievalSuccessEvent | RetrievalErrorEvent, void @@ -330,17 +395,13 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS const { model } = agentConfiguration; - let topK = 16; - if (actionConfiguration.topK === "auto") { - const supportedModel = getSupportedModelConfig(model); - if (actionConfiguration.query === "none") { - topK = supportedModel.recommendedExhaustiveTopK; - } else { - topK = supportedModel.recommendedTopK; - } - } else { - topK = actionConfiguration.topK; - } + const { topK, refsOffset } = + RetrievalConfigurationServerRunner.stepTopKAndRefsOffsetForAction({ + action: actionConfiguration, + model: getSupportedModelConfig(model), + stepActionIndex, + stepActions, + }); // Create the AgentRetrievalAction object in the database and yield an event for the generation of // the params. We store the action here as the params have been generated, if an error occurs @@ -546,7 +607,17 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS token_count: number; }[]; - const refs = getRefs().slice(refsOffset, refsOffset + v.length); + if (refsOffset + topK > getRefs().length) { + // This is a stream dropping error since the guardrails put in place should prevent this + // from ever happening (max 16 actions per step and sharing of topK among actions of + // the same type). 
+ throw new Error( + "The retrieval actions exhausted the total number of references available: " + + `refsOffset=${refsOffset} topK=${topK}` + ); + } + + const refs = getRefs().slice(refsOffset, refsOffset + topK); documents = v.map((d, i) => { const reference = refs[i % refs.length]; diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index cde9d3e0a4e2..ec3e62db1611 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -48,6 +48,7 @@ import { redisClient } from "@app/lib/redis"; import logger from "@app/logger/logger"; const CANCELLATION_CHECK_INTERVAL = 500; +const MAX_ACTIONS_PER_STEP = 16; // This interface is used to execute an agent. It is not in charge of creating the AgentMessage, // nor updating it (responsability of the caller based on the emitted events). @@ -168,13 +169,15 @@ export async function* runMultiActionsAgentLoop( "[ASSISTANT_TRACE] Action inputs generation" ); + // We received the actions to run, but will enforce a limit on the number of actions (16) + // which is very high. Over that the latency will just be too high. This is a guardrail + // against the model outputting something unreasonable. + event.actions = event.actions.slice(0, MAX_ACTIONS_PER_STEP); + yield event; - const actionIndexByType: Record<string, number> = {}; const eventStreamGenerators = event.actions.map( - ({ action, inputs, functionCallId, specification }) => { - const index = actionIndexByType[action.type] ?? 
0; - actionIndexByType[action.type] = index + 1; + ({ action, inputs, functionCallId, specification }, index) => { return runAction(auth, { configuration: configuration, actionConfiguration: action, @@ -185,7 +188,8 @@ export async function* runMultiActionsAgentLoop( specification, functionCallId, step: i, - indexForType: index, + stepActionIndex: index, + stepActions: event.actions.map((a) => a.action), }); } ); @@ -715,7 +719,8 @@ async function* runAction( specification, functionCallId, step, - indexForType, + stepActionIndex, + stepActions, }: { configuration: AgentConfigurationType; actionConfiguration: AgentActionConfigurationType; @@ -726,7 +731,8 @@ async function* runAction( specification: AgentActionSpecification | null; functionCallId: string | null; step: number; - indexForType: number; + stepActionIndex: number; + stepActions: AgentActionConfigurationType[]; } ): AsyncGenerator< AgentActionSpecificEvent | AgentErrorEvent | AgentActionSuccessEvent, @@ -748,8 +754,8 @@ async function* runAction( step, }, { - // We allocate 32 refs per retrieval action. - refsOffset: indexForType * 32, + stepActionIndex, + stepActions, } );