Skip to content

Commit

Permalink
front: cross-step handling of refsoffsets (#5860)
Browse files Browse the repository at this point in the history
* front: cross-step handling of refsoffsets

* lint
  • Loading branch information
spolu authored Jun 25, 2024
1 parent 5ff5383 commit 473946b
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 49 deletions.
5 changes: 5 additions & 0 deletions front/lib/api/assistant/actions/browse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ export class BrowseConfigurationServerRunner extends BaseActionConfigurationServ
});
}

// Browse does not use citations: its citations count for a step is always 0, whatever the other
// actions of the step are.
getCitationsCount(): number {
return 0;
}

// This method is in charge of running the browse and creating an AgentBrowseAction object in
// the database. It does not create any generic model related to the conversation. It is possible
// for an AgentBrowseAction to be stored (once the query params are inferred) but for its execution
Expand Down
5 changes: 5 additions & 0 deletions front/lib/api/assistant/actions/dust_app_run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ export class DustAppRunConfigurationServerRunner extends BaseActionConfiguration
});
}

// DustAppRun does not use citations: its citations count for a step is always 0, whatever the
// other actions of the step are.
getCitationsCount(): number {
return 0;
}

// This method is in charge of running a dust app and creating an AgentDustAppRunAction object in
// the database. It does not create any generic model related to the conversation. It is possible
// for an AgentDustAppRunAction to be stored (once the params are inferred) but for the dust app run
Expand Down
5 changes: 5 additions & 0 deletions front/lib/api/assistant/actions/process.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ export class ProcessConfigurationServerRunner extends BaseActionConfigurationSer
return new Ok(spec);
}

// Process does not use citations: its citations count for a step is always 0, whatever the other
// actions of the step are.
getCitationsCount(): number {
return 0;
}

// This method is in charge of running the retrieval and creating an AgentProcessAction object in
// the database. It does not create any generic model related to the conversation. It is possible
// for an AgentProcessAction to be stored (once the query params are inferred) but for its execution
Expand Down
94 changes: 55 additions & 39 deletions front/lib/api/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import type {
AgentActionConfigurationType,
AgentConfigurationType,
FunctionCallType,
FunctionMessageTypeModel,
ModelConfigurationType,
ModelId,
RetrievalErrorEvent,
RetrievalParamsEvent,
Expand All @@ -21,7 +21,6 @@ import {
cloneBaseConfig,
DustProdActionRegistry,
isDevelopment,
isRetrievalConfiguration,
} from "@dust-tt/types";
import { Ok } from "@dust-tt/types";

Expand All @@ -41,6 +40,8 @@ import { frontSequelize } from "@app/lib/resources/storage";
import { rand } from "@app/lib/utils/seeded_random";
import logger from "@app/logger/logger";

import { getRunnerforActionConfiguration } from "./runners";

/**
* TimeFrame parsing
*/
Expand Down Expand Up @@ -265,58 +266,74 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS
return new Ok(spec);
}

// Returns the number of citations (i.e. the topK) this retrieval action is allowed to use within
// the current step. The topK is shared among all the uses of this same action configuration in
// the step, each use receiving the ceiling of its share. Throws if the step does not contain this
// action configuration at all.
getCitationsCount({
  agentConfiguration,
  stepActions,
}: {
  agentConfiguration: AgentConfigurationType;
  stepActions: AgentActionConfigurationType[];
}): number {
  const { actionConfiguration } = this;

  // Count how many times this very action configuration is used in the step.
  let sameConfigurationUses = 0;
  for (const stepAction of stepActions) {
    if (stepAction.sId === actionConfiguration.sId) {
      sameConfigurationUses += 1;
    }
  }
  if (sameConfigurationUses === 0) {
    throw new Error("Unexpected: found 0 retrieval actions");
  }

  const model = getSupportedModelConfig(agentConfiguration.model);
  const { topK: configuredTopK, query } = actionConfiguration;

  // An explicit topK is used as is. With "auto" we rely on the model recommendations: the
  // exhaustive one when the query is "none", the regular one otherwise.
  let topK: number;
  if (configuredTopK !== "auto") {
    topK = configuredTopK;
  } else if (query === "none") {
    topK = model.recommendedExhaustiveTopK;
  } else {
    topK = model.recommendedTopK;
  }

  return Math.ceil(topK / sameConfigurationUses);
}

// stepTopKAndRefsOffsetForAction returns the references offset and the number of documents an
// action will use as part of the current step. We share topK among multiple instances of a same
// retrieval action so that we don't overflow the context when the model asks for many retrievals
// at the same time. Based on the nature of the retrieval actions (query being `auto` or `none`),
// the topK can vary (exhaustive or not). So we need all the actions of the current step to
// properly split the topK among them and decide which slice of references we will allocate to the
// current action.
static stepTopKAndRefsOffsetForAction({
action,
model,
stepTopKAndRefsOffsetForAction({
agentConfiguration,
stepActionIndex,
stepActions,
refsOffset,
}: {
action: RetrievalConfigurationType;
model: ModelConfigurationType;
agentConfiguration: AgentConfigurationType;
stepActionIndex: number;
stepActions: AgentActionConfigurationType[];
refsOffset: number;
}): { topK: number; refsOffset: number } {
const topKForAction = (
action: RetrievalConfigurationType,
actionCount: number
) => {
let topK = 16;
if (action.topK === "auto") {
if (action.query === "none") {
topK = model.recommendedExhaustiveTopK;
} else {
topK = model.recommendedTopK;
}
} else {
topK = action.topK;
}
// We split the topK among the actions that are uses of the same action configuration.
return Math.ceil(topK / actionCount);
};

const actionCounts: Record<string, number> = {};
stepActions.forEach((a) => {
actionCounts[a.sId] = actionCounts[a.sId] ?? 0;
actionCounts[a.sId]++;
});

let refsOffset = 0;
for (let i = 0; i < stepActionIndex; i++) {
const r = stepActions[i];
if (isRetrievalConfiguration(r)) {
refsOffset += topKForAction(r, actionCounts[stepActions[i].sId]);
}
refsOffset += getRunnerforActionConfiguration(r).getCitationsCount({
agentConfiguration,
stepActions,
});
}

return {
topK: topKForAction(action, actionCounts[action.sId]),
topK: this.getCitationsCount({ agentConfiguration, stepActions }),
refsOffset,
};
}
Expand All @@ -340,9 +357,11 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS
{
stepActionIndex,
stepActions,
citationsRefsOffset,
}: {
stepActionIndex: number;
stepActions: AgentActionConfigurationType[];
citationsRefsOffset: number;
}
): AsyncGenerator<
RetrievalParamsEvent | RetrievalSuccessEvent | RetrievalErrorEvent,
Expand Down Expand Up @@ -391,15 +410,12 @@ export class RetrievalConfigurationServerRunner extends BaseActionConfigurationS
}
}

const { model } = agentConfiguration;

const { topK, refsOffset } =
RetrievalConfigurationServerRunner.stepTopKAndRefsOffsetForAction({
action: actionConfiguration,
model: getSupportedModelConfig(model),
stepActionIndex,
stepActions,
});
const { topK, refsOffset } = this.stepTopKAndRefsOffsetForAction({
agentConfiguration,
stepActionIndex,
stepActions,
refsOffset: citationsRefsOffset,
});

// Create the AgentRetrievalAction object in the database and yield an event for the generation of
// the params. We store the action here as the params have been generated, if an error occurs
Expand Down
5 changes: 5 additions & 0 deletions front/lib/api/assistant/actions/tables_query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ export class TablesQueryConfigurationServerRunner extends BaseActionConfiguratio
return new Ok(spec);
}

// TablesQuery does not use citations: its citations count for a step is always 0, whatever the
// other actions of the step are.
getCitationsCount(): number {
return 0;
}

async *run(
auth: Authenticator,
{
Expand Down
9 changes: 9 additions & 0 deletions front/lib/api/assistant/actions/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ export abstract class BaseActionConfigurationServerRunner<
{ name, description }: { name: string | null; description: string | null }
): Promise<Result<AgentActionSpecification, Error>>;

// Computes the maximum number of citations this action can use as part of the current step.
// - `agentConfiguration`: configuration of the agent the action belongs to.
// - `stepActions`: configurations of all the actions of the current step.
// Runners whose action does not use citations return 0.
abstract getCitationsCount({
agentConfiguration,
stepActions,
}: {
agentConfiguration: AgentConfigurationType;
stepActions: AgentActionConfigurationType[];
}): number;

// Action execution.
abstract run(
auth: Authenticator,
Expand Down
80 changes: 79 additions & 1 deletion front/lib/api/assistant/actions/websearch.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import type {
AgentActionConfigurationType,
AgentActionSpecification,
AgentConfigurationType,
FunctionCallType,
FunctionMessageTypeModel,
ModelId,
Expand Down Expand Up @@ -28,6 +30,10 @@ import type { Authenticator } from "@app/lib/auth";
import { AgentWebsearchAction } from "@app/lib/models/assistant/actions/websearch";
import logger from "@app/logger/logger";

import { getRunnerforActionConfiguration } from "./runners";

// Total number of results (and hence citations) shared among the uses of a same websearch action
// configuration within a step.
const WEBSEARCH_ACTION_NUM_RESULTS = 16;

interface WebsearchActionBlob {
id: ModelId; // AgentWebsearchAction
agentMessageId: ModelId;
Expand Down Expand Up @@ -113,6 +119,59 @@ export class WebsearchConfigurationServerRunner extends BaseActionConfigurationS
});
}

// Returns the number of citations this websearch action is allowed to use within the current
// step. The overall results budget (WEBSEARCH_ACTION_NUM_RESULTS) is shared among all the uses
// of this same action configuration in the step, each use receiving the ceiling of its share.
// Throws if the step does not contain this action configuration at all.
getCitationsCount({
  stepActions,
}: {
  stepActions: AgentActionConfigurationType[];
}): number {
  const { actionConfiguration } = this;

  // Count how many times this very action configuration is used in the step.
  const sameConfigurationUses = stepActions.reduce(
    (uses, stepAction) =>
      stepAction.sId === actionConfiguration.sId ? uses + 1 : uses,
    0
  );

  if (sameConfigurationUses !== 0) {
    return Math.ceil(WEBSEARCH_ACTION_NUM_RESULTS / sameConfigurationUses);
  }
  throw new Error("Unexpected: found 0 websearch actions");
}

// stepNumResultsAndRefsOffsetForAction returns the references offset and the number of search
// results an action will use as part of the current step. We share the number of results among
// multiple instances of a same websearch action from the same step so that we don't overflow the
// context when the model asks for many web searches at the same time.
//
// - `refsOffset` (input) is the offset to start from; the citations counts of all the actions
//   located before `stepActionIndex` in `stepActions` are added to it to produce the returned
//   `refsOffset`.
// - `numResults` is the citations count of this action for the step.
stepNumResultsAndRefsOffsetForAction({
  agentConfiguration,
  stepActionIndex,
  stepActions,
  refsOffset,
}: {
  agentConfiguration: AgentConfigurationType;
  stepActionIndex: number;
  stepActions: AgentActionConfigurationType[];
  refsOffset: number;
}): { numResults: number; refsOffset: number } {
  // Skip the citations reserved by every action that comes before this one in the step. Each
  // runner knows how many citations its own action uses.
  let actionRefsOffset = refsOffset;
  for (let i = 0; i < stepActionIndex; i++) {
    actionRefsOffset += getRunnerforActionConfiguration(
      stepActions[i]
    ).getCitationsCount({
      agentConfiguration,
      stepActions,
    });
  }

  return {
    numResults: this.getCitationsCount({ stepActions }),
    refsOffset: actionRefsOffset,
  };
}

// This method is in charge of running the websearch and creating an AgentWebsearchAction object in
// the database. It does not create any generic model related to the conversation. It is possible
// for an AgentWebsearchAction to be stored (once the query params are inferred) but for its execution
Expand All @@ -127,7 +186,16 @@ export class WebsearchConfigurationServerRunner extends BaseActionConfigurationS
rawInputs,
functionCallId,
step,
}: BaseActionRunParams
}: BaseActionRunParams,
{
stepActionIndex,
stepActions,
citationsRefsOffset,
}: {
stepActionIndex: number;
stepActions: AgentActionConfigurationType[];
citationsRefsOffset: number;
}
): AsyncGenerator<
WebsearchParamsEvent | WebsearchSuccessEvent | WebsearchErrorEvent,
void
Expand Down Expand Up @@ -158,6 +226,14 @@ export class WebsearchConfigurationServerRunner extends BaseActionConfigurationS
return;
}

const { numResults /*, refsOffset*/ } =
this.stepNumResultsAndRefsOffsetForAction({
agentConfiguration,
stepActionIndex,
stepActions,
refsOffset: citationsRefsOffset,
});

// Create the AgentWebsearchAction object in the database and yield an event for the generation of
// the params. We store the action here as the params have been generated, if an error occurs
// later on, the action won't have outputs but the error will be stored on the parent agent
Expand Down Expand Up @@ -194,6 +270,8 @@ export class WebsearchConfigurationServerRunner extends BaseActionConfigurationS
DustProdActionRegistry["assistant-v2-websearch"].config
);

config.SEARCH.num = numResults;

// Execute the websearch action.
const websearchRes = await runActionStreamed(
auth,
Expand Down
Loading

0 comments on commit 473946b

Please sign in to comment.