Skip to content

Commit

Permalink
Refactor websearch and browse fetch logic
Browse files Browse the repository at this point in the history
  • Loading branch information
flvndvd committed Aug 9, 2024
1 parent fd2b2e6 commit 511e9a8
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 60 deletions.
93 changes: 33 additions & 60 deletions front/lib/api/assistant/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ import {
DEFAULT_TABLES_QUERY_ACTION_NAME,
DEFAULT_WEBSEARCH_ACTION_NAME,
} from "@app/lib/api/assistant/actions/names";
import { fetchBrowseActionsConfigurations } from "@app/lib/api/assistant/configuration/browse";
import { fetchDustAppRunActionsConfigurations } from "@app/lib/api/assistant/configuration/dust_app_run";
import { fetchAgentProcessActionsConfigurations } from "@app/lib/api/assistant/configuration/process";
import { fetchAgentRetrievalActionsConfigurations } from "@app/lib/api/assistant/configuration/retrieval";
import { fetchTableQueryActionsConfigurations } from "@app/lib/api/assistant/configuration/table_query";
import { fetchWebsearchActionsConfigurations } from "@app/lib/api/assistant/configuration/websearch";
import {
getGlobalAgents,
isGlobalAgentId,
Expand Down Expand Up @@ -376,56 +378,37 @@ async function fetchWorkspaceAgentConfigurationsForView(
const configurationIds = agentConfigurations.map((a) => a.id);
const configurationSIds = agentConfigurations.map((a) => a.sId);

function groupByAgentConfigurationId<
T extends { agentConfigurationId: number },
>(list: T[]): Record<number, T[]> {
return _.groupBy(list, "agentConfigurationId");
}

const [websearchConfigs, browseConfigs, agentUserRelations] =
await Promise.all([
variant === "full"
? AgentWebsearchConfiguration.findAll({
where: {
agentConfigurationId: { [Op.in]: configurationIds },
},
}).then(groupByAgentConfigurationId)
: Promise.resolve({} as Record<number, AgentWebsearchConfiguration[]>),
variant === "full"
? AgentBrowseConfiguration.findAll({
where: {
agentConfigurationId: { [Op.in]: configurationIds },
},
}).then(groupByAgentConfigurationId)
: Promise.resolve({} as Record<number, AgentBrowseConfiguration[]>),
user && configurationIds.length > 0
? AgentUserRelation.findAll({
where: {
agentConfiguration: { [Op.in]: configurationSIds },
userId: user.id,
},
}).then((relations) =>
relations.reduce(
(acc, relation) => {
acc[relation.agentConfiguration] = relation;
return acc;
},
{} as Record<string, AgentUserRelation>
)
)
: Promise.resolve({} as Record<string, AgentUserRelation>),
]);

const [
retrievalActionsConfigurationsPerAgent,
processActionsConfigurationsPerAgent,
dustAppRunActionsConfigurationsPerAgent,
tableQueryActionsConfigurationsPerAgent,
websearchActionsConfigurationsPerAgent,
browseActionsConfigurationsPerAgent,
agentUserRelations,
] = await Promise.all([
fetchAgentRetrievalActionsConfigurations({ configurationIds, variant }),
fetchAgentProcessActionsConfigurations({ configurationIds, variant }),
fetchDustAppRunActionsConfigurations({ configurationIds, variant }),
fetchTableQueryActionsConfigurations({ configurationIds, variant }),
fetchWebsearchActionsConfigurations({ configurationIds, variant }),
fetchBrowseActionsConfigurations({ configurationIds, variant }),
user && configurationIds.length > 0
? AgentUserRelation.findAll({
where: {
agentConfiguration: { [Op.in]: configurationSIds },
userId: user.id,
},
}).then((relations) =>
relations.reduce(
(acc, relation) => {
acc[relation.agentConfiguration] = relation;
return acc;
},
{} as Record<string, AgentUserRelation>
)
)
: Promise.resolve({} as Record<string, AgentUserRelation>),
]);

let agentConfigurationTypes: AgentConfigurationType[] = [];
Expand All @@ -445,27 +428,17 @@ async function fetchWorkspaceAgentConfigurationsForView(

actions.push(...dustAppRunActionsConfigurations);

const websearchConfigurations = websearchConfigs[agent.id] ?? [];
for (const websearchConfig of websearchConfigurations) {
actions.push({
id: websearchConfig.id,
sId: websearchConfig.sId,
type: "websearch_configuration",
name: websearchConfig.name || DEFAULT_WEBSEARCH_ACTION_NAME,
description: websearchConfig.description,
});
}
// Websearch configurations.
const websearchActionsConfigurations =
websearchActionsConfigurationsPerAgent.get(agent.id) ?? [];

const browseConfigurations = browseConfigs[agent.id] ?? [];
for (const browseConfig of browseConfigurations) {
actions.push({
id: browseConfig.id,
sId: browseConfig.sId,
type: "browse_configuration",
name: browseConfig.name || DEFAULT_BROWSE_ACTION_NAME,
description: browseConfig.description,
});
}
actions.push(...websearchActionsConfigurations);

// Browse configurations.
const browseActionsConfigurations =
browseActionsConfigurationsPerAgent.get(agent.id) ?? [];

actions.push(...browseActionsConfigurations);

// Table query configurations.
const tableQueryActionsConfigurations =
Expand Down
52 changes: 52 additions & 0 deletions front/lib/api/assistant/configuration/browse.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import type { BrowseConfigurationType, ModelId } from "@dust-tt/types";
import _ from "lodash";
import { Op } from "sequelize";

import { DEFAULT_BROWSE_ACTION_NAME } from "@app/lib/api/assistant/actions/names";
import { AgentBrowseConfiguration } from "@app/lib/models/assistant/actions/browse";

export async function fetchBrowseActionsConfigurations({
configurationIds,
variant,
}: {
configurationIds: ModelId[];
variant: "light" | "full";
}): Promise<Map<ModelId, BrowseConfigurationType[]>> {
if (variant !== "full") {
return new Map();
}

const browseConfigurations = await AgentBrowseConfiguration.findAll({
where: { agentConfigurationId: { [Op.in]: configurationIds } },
});

if (browseConfigurations.length === 0) {
return new Map();
}

const groupedBrowseConfigurations = _.groupBy(
browseConfigurations,
"agentConfigurationId"
);

const actionsByConfigurationId: Map<ModelId, BrowseConfigurationType[]> =
new Map();
for (const [agentConfigurationId, configs] of Object.entries(
groupedBrowseConfigurations
)) {
const actions: BrowseConfigurationType[] = [];
for (const c of configs) {
actions.push({
id: c.id,
sId: c.sId,
type: "browse_configuration",
name: c.name || DEFAULT_BROWSE_ACTION_NAME,
description: c.description,
});
}

actionsByConfigurationId.set(parseInt(agentConfigurationId, 10), actions);
}

return actionsByConfigurationId;
}
52 changes: 52 additions & 0 deletions front/lib/api/assistant/configuration/websearch.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import type { ModelId, WebsearchConfigurationType } from "@dust-tt/types";
import _ from "lodash";
import { Op } from "sequelize";

import { DEFAULT_WEBSEARCH_ACTION_NAME } from "@app/lib/api/assistant/actions/names";
import { AgentWebsearchConfiguration } from "@app/lib/models/assistant/actions/websearch";

export async function fetchWebsearchActionsConfigurations({
configurationIds,
variant,
}: {
configurationIds: ModelId[];
variant: "light" | "full";
}): Promise<Map<ModelId, WebsearchConfigurationType[]>> {
if (variant !== "full") {
return new Map();
}

const websearchConfigurations = await AgentWebsearchConfiguration.findAll({
where: { agentConfigurationId: { [Op.in]: configurationIds } },
});

if (websearchConfigurations.length === 0) {
return new Map();
}

const groupedWebsearchConfigurations = _.groupBy(
websearchConfigurations,
"agentConfigurationId"
);

const actionsByConfigurationId: Map<ModelId, WebsearchConfigurationType[]> =
new Map();
for (const [agentConfigurationId, configs] of Object.entries(
groupedWebsearchConfigurations
)) {
const actions: WebsearchConfigurationType[] = [];
for (const c of configs) {
actions.push({
id: c.id,
sId: c.sId,
type: "websearch_configuration",
name: c.name || DEFAULT_WEBSEARCH_ACTION_NAME,
description: c.description,
});
}

actionsByConfigurationId.set(parseInt(agentConfigurationId, 10), actions);
}

return actionsByConfigurationId;
}

0 comments on commit 511e9a8

Please sign in to comment.