Skip to content

Commit

Permalink
front: refactor refs allocation + enforce guardrails on actions per s…
Browse files Browse the repository at this point in the history
…tep (#5810)

* front: refactor refs allocation + enforce guardrails on actions per step

* lint
  • Loading branch information
spolu authored Jun 24, 2024
1 parent 19e7fc5 commit 5c58962
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 22 deletions.
97 changes: 84 additions & 13 deletions front/lib/api/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import type {
AgentActionConfigurationType,
FunctionCallType,
FunctionMessageTypeModel,
ModelConfigurationType,
ModelId,
RetrievalErrorEvent,
RetrievalParamsEvent,
Expand All @@ -18,6 +20,7 @@ import {
BaseAction,
cloneBaseConfig,
DustProdActionRegistry,
isRetrievalConfiguration,
} from "@dust-tt/types";
import { Ok } from "@dust-tt/types";

Expand Down Expand Up @@ -264,6 +267,62 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS
return new Ok(spec);
}

// stepTopKAndRefsOffsetForAction returns the references offset and the number of documents an
// action will use as part of the current step. We share topK among multiple instances of a same
// retrieval action so that we don't overflow the context when the model asks for many retrievals
// at the same time. Based on the nature of the retrieval actions (query being `auto` or `null`),
// the topK can vary (exhaustive or not). So we need all the actions of the current step to
// properly split the topK among them and decide which slice of references we will allocate to the
// current action.
static stepTopKAndRefsOffsetForAction({
action,
model,
stepActionIndex,
stepActions,
}: {
action: RetrievalConfigurationType;
model: ModelConfigurationType;
stepActionIndex: number;
stepActions: AgentActionConfigurationType[];
}): { topK: number; refsOffset: number } {
const topKForAction = (
action: RetrievalConfigurationType,
actionCount: number
) => {
let topK = 16;
if (action.topK === "auto") {
if (action.query === "none") {
topK = model.recommendedExhaustiveTopK;
} else {
topK = model.recommendedTopK;
}
} else {
topK = action.topK;
}
// We split the topK among the actions that are uses of the same action configuration.
return Math.ceil(topK / actionCount);
};

const actionCounts: Record<string, number> = {};
stepActions.forEach((a) => {
actionCounts[a.sId] = actionCounts[a.sId] ?? 0;
actionCounts[a.sId]++;
});

let refsOffset = 0;
for (let i = 0; i < stepActionIndex; i++) {
const r = stepActions[i];
if (isRetrievalConfiguration(r)) {
refsOffset += topKForAction(r, actionCounts[stepActions[i].sId]);
}
}

return {
topK: topKForAction(action, actionCounts[action.sId]),
refsOffset,
};
}

// This method is in charge of running the retrieval and creating an AgentRetrievalAction object in
// the database (along with the RetrievalDocument and RetrievalDocumentChunk objects). It does not
// create any generic model related to the conversation. It is possible for an AgentRetrievalAction
Expand All @@ -280,7 +339,13 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS
functionCallId,
step,
}: BaseActionRunParams,
{ refsOffset = 0 }: { refsOffset?: number }
{
stepActionIndex,
stepActions,
}: {
stepActionIndex: number;
stepActions: AgentActionConfigurationType[];
}
): AsyncGenerator<
RetrievalParamsEvent | RetrievalSuccessEvent | RetrievalErrorEvent,
void
Expand Down Expand Up @@ -330,17 +395,13 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS

const { model } = agentConfiguration;

let topK = 16;
if (actionConfiguration.topK === "auto") {
const supportedModel = getSupportedModelConfig(model);
if (actionConfiguration.query === "none") {
topK = supportedModel.recommendedExhaustiveTopK;
} else {
topK = supportedModel.recommendedTopK;
}
} else {
topK = actionConfiguration.topK;
}
const { topK, refsOffset } =
RetrievalConfigurationServerRunner.stepTopKAndRefsOffsetForAction({
action: actionConfiguration,
model: getSupportedModelConfig(model),
stepActionIndex,
stepActions,
});

// Create the AgentRetrievalAction object in the database and yield an event for the generation of
// the params. We store the action here as the params have been generated, if an error occurs
Expand Down Expand Up @@ -546,7 +607,17 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS
token_count: number;
}[];

const refs = getRefs().slice(refsOffset, refsOffset + v.length);
if (refsOffset + topK > getRefs().length) {
// This is a stream dropping error since the guardrails put in place should prevent this
// from ever happeaning (max 16 actions per step and sharing of topK among actions of
// the same type).
throw new Error(
"The retrieval actions exhausted the total number of references available: " +
`refsOffset=${refsOffset} topK=${topK}`
);
}

const refs = getRefs().slice(refsOffset, refsOffset + topK);

documents = v.map((d, i) => {
const reference = refs[i % refs.length];
Expand Down
24 changes: 15 additions & 9 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import { redisClient } from "@app/lib/redis";
import logger from "@app/logger/logger";

const CANCELLATION_CHECK_INTERVAL = 500;
const MAX_ACTIONS_PER_STEP = 16;

// This interface is used to execute an agent. It is not in charge of creating the AgentMessage,
// nor updating it (responsability of the caller based on the emitted events).
Expand Down Expand Up @@ -168,13 +169,15 @@ export async function* runMultiActionsAgentLoop(
"[ASSISTANT_TRACE] Action inputs generation"
);

// We received the actions to run, but will enforce a limit on the number of actions (16)
// which is very high. Over that the latency will just be too high. This is a guardrail
// against the model outputing something unreasonable.
event.actions = event.actions.slice(0, MAX_ACTIONS_PER_STEP);

yield event;

const actionIndexByType: Record<string, number> = {};
const eventStreamGenerators = event.actions.map(
({ action, inputs, functionCallId, specification }) => {
const index = actionIndexByType[action.type] ?? 0;
actionIndexByType[action.type] = index + 1;
({ action, inputs, functionCallId, specification }, index) => {
return runAction(auth, {
configuration: configuration,
actionConfiguration: action,
Expand All @@ -185,7 +188,8 @@ export async function* runMultiActionsAgentLoop(
specification,
functionCallId,
step: i,
indexForType: index,
stepActionIndex: index,
stepActions: event.actions.map((a) => a.action),
});
}
);
Expand Down Expand Up @@ -715,7 +719,8 @@ async function* runAction(
specification,
functionCallId,
step,
indexForType,
stepActionIndex,
stepActions,
}: {
configuration: AgentConfigurationType;
actionConfiguration: AgentActionConfigurationType;
Expand All @@ -726,7 +731,8 @@ async function* runAction(
specification: AgentActionSpecification | null;
functionCallId: string | null;
step: number;
indexForType: number;
stepActionIndex: number;
stepActions: AgentActionConfigurationType[];
}
): AsyncGenerator<
AgentActionSpecificEvent | AgentErrorEvent | AgentActionSuccessEvent,
Expand All @@ -748,8 +754,8 @@ async function* runAction(
step,
},
{
// We allocate 32 refs per retrieval action.
refsOffset: indexForType * 32,
stepActionIndex,
stepActions,
}
);

Expand Down

0 comments on commit 5c58962

Please sign in to comment.