diff --git a/connectors/src/api/slack_channel_link_with_agent.ts b/connectors/src/api/slack_channel_link_with_agent.ts deleted file mode 100644 index c26fea12f4a6..000000000000 --- a/connectors/src/api/slack_channel_link_with_agent.ts +++ /dev/null @@ -1,86 +0,0 @@ -import { Request, Response } from "express"; -import { isLeft } from "fp-ts/lib/Either"; -import * as t from "io-ts"; -import * as reporter from "io-ts-reporters"; - -import { APIErrorWithStatusCode } from "@connectors/lib/error"; -import { SlackChannel } from "@connectors/lib/models"; -import { apiError, withLogging } from "@connectors/logger/withlogging"; - -const LinkSlackChannelWithAgentReqBodySchema = t.type({ - connectorId: t.string, - agentConfigurationId: t.string, -}); - -type LinkSlackChannelWithAgentReqBody = t.TypeOf< - typeof LinkSlackChannelWithAgentReqBodySchema ->; - -type LinkSlackChannelWithAgentResBody = - | { success: true } - | APIErrorWithStatusCode; - -const _linkSlackChannelWithAgentHandler = async ( - req: Request< - { - slackChannelId: string; - }, - LinkSlackChannelWithAgentResBody, - LinkSlackChannelWithAgentReqBody - >, - res: Response -) => { - if (!req.params.slackChannelId) { - return apiError(req, res, { - api_error: { - type: "invalid_request_error", - message: "Missing required parameters. Required: slackChannelId", - }, - status_code: 400, - }); - } - - const { slackChannelId } = req.params; - - const bodyValidation = LinkSlackChannelWithAgentReqBodySchema.decode( - req.body - ); - - if (isLeft(bodyValidation)) { - const pathError = reporter.formatValidationErrors(bodyValidation.left); - return apiError(req, res, { - api_error: { - type: "invalid_request_error", - message: `Invalid request body: ${pathError}`, - }, - status_code: 400, - }); - } - - const { connectorId, agentConfigurationId } = bodyValidation.right; - - const slackChannel = await SlackChannel.findOne({ - where: { connectorId, slackChannelId }, - }); - - if (!slackChannel) { - return apiError(req, res, { - api_error: { - type: "slack_channel_not_found", - message: `Slack channel not found for connectorId ${connectorId} and slackChannelId ${slackChannelId}`, - }, - status_code: 404, - }); - } - - slackChannel.agentConfigurationId = agentConfigurationId; - await slackChannel.save(); - - res.status(200).send({ - success: true, - }); -}; - -export const linkSlackChannelWithAgentHandler = withLogging( - _linkSlackChannelWithAgentHandler -); diff --git a/connectors/src/api/slack_channels_linked_with_agent.ts b/connectors/src/api/slack_channels_linked_with_agent.ts new file mode 100644 index 000000000000..02a455e00e02 --- /dev/null +++ b/connectors/src/api/slack_channels_linked_with_agent.ts @@ -0,0 +1,160 @@ +import { Request, Response } from "express"; +import { isLeft } from "fp-ts/lib/Either"; +import * as t from "io-ts"; +import * as reporter from "io-ts-reporters"; +import { Op } from "sequelize"; + +import { APIErrorWithStatusCode } from "@connectors/lib/error"; +import { sequelize_conn, SlackChannel } from "@connectors/lib/models"; +import { apiError, withLogging } from "@connectors/logger/withlogging"; + +const PatchSlackChannelsLinkedWithAgentReqBodySchema = t.type({ + agent_configuration_id: t.string, + slack_channel_ids: t.array(t.string), + connector_id: t.string, +}); + +type PatchSlackChannelsLinkedWithAgentReqBody = t.TypeOf< + typeof PatchSlackChannelsLinkedWithAgentReqBodySchema +>; + +type PatchSlackChannelsLinkedWithAgentResBody = + | { success: true } + | APIErrorWithStatusCode; + +const _patchSlackChannelsLinkedWithAgentHandler = async ( + req: Request< + Record, + PatchSlackChannelsLinkedWithAgentResBody, + PatchSlackChannelsLinkedWithAgentReqBody + >, + res: Response +) => { + const bodyValidation = PatchSlackChannelsLinkedWithAgentReqBodySchema.decode( + req.body + ); + + if (isLeft(bodyValidation)) { + const pathError = reporter.formatValidationErrors(bodyValidation.left); + return apiError(req, res, { + api_error: { + type: "invalid_request_error", + message: `Invalid request body: ${pathError}`, + }, + status_code: 400, + }); + } + + const { + connector_id: connectorId, + agent_configuration_id: agentConfigurationId, + slack_channel_ids: slackChannelIds, + } = bodyValidation.right; + + const slackChannels = await SlackChannel.findAll({ + where: { + slackChannelId: slackChannelIds, + connectorId, + }, + }); + + const foundSlackChannelIds = new Set( + slackChannels.map((c) => c.slackChannelId) + ); + + const missingSlackChannelIds = Array.from( + new Set(slackChannelIds.filter((id) => !foundSlackChannelIds.has(id))) + ); + + if (missingSlackChannelIds.length) { + return apiError(req, res, { + api_error: { + type: "not_found", + message: `Slack channel(s) not found: ${missingSlackChannelIds.join( + ", " + )}`, + }, + status_code: 404, + }); + } + + await sequelize_conn.transaction(async (t) => { + await SlackChannel.update( + { agentConfigurationId: null }, + { + where: { + connectorId, + agentConfigurationId, + }, + transaction: t, + } + ); + await Promise.all( + slackChannels.map((slackChannel) => + slackChannel.update({ agentConfigurationId }, { transaction: t }) + ) + ); + }); + + res.status(200).json({ + success: true, + }); +}; + +export const patchSlackChannelsLinkedWithAgentHandler = withLogging( + _patchSlackChannelsLinkedWithAgentHandler +); + +type GetSlackChannelsLinkedWithAgentResBody = + | { + slackChannels: { + slackChannelId: string; + slackChannelName: string; + agentConfigurationId: string; + }[]; + } + | APIErrorWithStatusCode; + +const _getSlackChannelsLinkedWithAgentHandler = async ( + req: Request< + Record, + { slackChannelIds: string[] } | APIErrorWithStatusCode, + undefined + >, + res: Response +) => { + const { connector_id: connectorId } = req.query; + + if (!connectorId || typeof connectorId !== "string") { + return apiError(req, res, { + api_error: { + type: "invalid_request_error", + message: `Missing required parameters: connector_id`, + }, + status_code: 400, + }); + } + + const slackChannels = await SlackChannel.findAll({ + where: { + connectorId, + agentConfigurationId: { + [Op.not]: null, + }, + }, + }); + + res.status(200).json({ + slackChannels: slackChannels.map((c) => ({ + slackChannelId: c.slackChannelId, + slackChannelName: c.slackChannelName, + // We know that agentConfigurationId is not null because of the where clause above + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + agentConfigurationId: c.agentConfigurationId!, + })), + }); +}; + +export const getSlackChannelsLinkedWithAgentHandler = withLogging( + _getSlackChannelsLinkedWithAgentHandler +); diff --git a/connectors/src/api_server.ts b/connectors/src/api_server.ts index f3f2cdb68d26..c3036b07ed44 100644 --- a/connectors/src/api_server.ts +++ b/connectors/src/api_server.ts @@ -9,8 +9,14 @@ import { createConnectorAPIHandler } from "@connectors/api/create_connector"; import { deleteConnectorAPIHandler } from "@connectors/api/delete_connector"; import { getConnectorAPIHandler } from "@connectors/api/get_connector"; import { getConnectorPermissionsAPIHandler } from "@connectors/api/get_connector_permissions"; +import { getResourcesParentsAPIHandler } from "@connectors/api/get_resources_parents"; +import { getResourcesTitlesAPIHandler } from "@connectors/api/get_resources_titles"; import { resumeConnectorAPIHandler } from "@connectors/api/resume_connector"; import { setConnectorPermissionsAPIHandler } from "@connectors/api/set_connector_permissions"; +import { + getSlackChannelsLinkedWithAgentHandler, + patchSlackChannelsLinkedWithAgentHandler, +} from "@connectors/api/slack_channels_linked_with_agent"; import { stopConnectorAPIHandler } from "@connectors/api/stop_connector"; import { syncConnectorAPIHandler } from "@connectors/api/sync_connector"; import { getConnectorUpdateAPIHandler } from "@connectors/api/update_connector"; @@ -20,10 +26,6 @@ import { webhookSlackAPIHandler } from "@connectors/api/webhooks/webhook_slack"; import logger from "@connectors/logger/logger"; import { authMiddleware } from "@connectors/middleware/auth"; -import { getResourcesParentsAPIHandler } from "./api/get_resources_parents"; -import { getResourcesTitlesAPIHandler } from "./api/get_resources_titles"; -import { linkSlackChannelWithAgentHandler } from "./api/slack_channel_link_with_agent"; - export function startServer(port: number) { const app = express(); @@ -72,9 +74,13 @@ export function startServer(port: number) { setConnectorPermissionsAPIHandler ); - app.post( - "/slack/channels/:slackChannelId/link_with_agent", - linkSlackChannelWithAgentHandler + app.patch( + "/slack/channels/linked_with_agent", + patchSlackChannelsLinkedWithAgentHandler + ); + app.get( + "/slack/channels/linked_with_agent", + getSlackChannelsLinkedWithAgentHandler ); app.post("/webhooks/:webhook_secret/slack", webhookSlackAPIHandler); diff --git a/connectors/src/connectors/slack/bot.ts b/connectors/src/connectors/slack/bot.ts index 35f580bc04e6..484fc923ff83 100644 --- a/connectors/src/connectors/slack/bot.ts +++ b/connectors/src/connectors/slack/bot.ts @@ -11,6 +11,7 @@ import { import { Connector, ModelId, + SlackChannel, SlackChatBotMessage, SlackConfiguration, } from "@connectors/lib/models"; @@ -258,10 +259,38 @@ async function botAnswerMessage( ); } } - } - - if (mentions.length === 0) { - mentions.push({ assistantId: "dust", assistantName: "dust" }); + } else { + // If no mention is found, we look at channel-based routing rules. + const channel = await SlackChannel.findOne({ + where: { + connectorId: connector.id, + slackChannelId: slackChannel, + }, + }); + if (channel && channel.agentConfigurationId) { + const agentConfigurationsRes = await dustAPI.getAgentConfigurations(); + if (agentConfigurationsRes.isErr()) { + return new Err(new Error(agentConfigurationsRes.error.message)); + } + const agentConfigurations = agentConfigurationsRes.value; + const agentConfiguration = agentConfigurations.find( + (ac) => ac.sId === channel.agentConfigurationId + ); + if (!agentConfiguration) { + return new Err( + new Error( + `Failed to find agent configuration ${channel.agentConfigurationId}` + ) + ); + } + mentions.push({ + assistantId: channel.agentConfigurationId, + assistantName: agentConfiguration.name, + }); + } else { + // If no mention is found and no channel-based routing rule is found, we use the default assistant. + mentions.push({ assistantId: "dust", assistantName: "dust" }); + } } const messageReqBody = { @@ -388,7 +417,6 @@ async function botAnswerMessage( thread_ts: slackMessageTs, }); return new Ok(event); - break; } default: // Nothing to do on unsupported events diff --git a/front/components/DataSourceResourceSelectorTree.tsx b/front/components/DataSourceResourceSelectorTree.tsx index bbd03ca80ea9..ff16292aff92 100644 --- a/front/components/DataSourceResourceSelectorTree.tsx +++ b/front/components/DataSourceResourceSelectorTree.tsx @@ -9,7 +9,10 @@ import { import { CircleStackIcon, FolderIcon } from "@heroicons/react/20/solid"; import { useState } from "react"; -import { ConnectorResourceType } from "@app/lib/connectors_api"; +import { + ConnectorPermission, + ConnectorResourceType, +} from "@app/lib/connectors_api"; import { useConnectorPermissions } from "@app/lib/swr"; import { classNames } from "@app/lib/utils"; import { DataSourceType } from "@app/types/data_source"; @@ -23,6 +26,7 @@ export default function DataSourceResourceSelectorTree({ onSelectChange, parentsById, fullySelected, + filterPermission = "read", }: { owner: WorkspaceType; dataSource: DataSourceType; @@ -34,6 +38,7 @@ export default function DataSourceResourceSelectorTree({ ) => void; parentsById: Record>; fullySelected: boolean; + filterPermission?: ConnectorPermission; }) { return (
@@ -48,6 +53,7 @@ export default function DataSourceResourceSelectorTree({ parents={[]} isChecked={false} fullySelected={fullySelected} + filterPermission={filterPermission} />
); @@ -87,6 +93,7 @@ function DataSourceResourceSelectorChildren({ parentsById, parents, fullySelected, + filterPermission, }: { owner: WorkspaceType; dataSource: DataSourceType; @@ -101,13 +108,14 @@ function DataSourceResourceSelectorChildren({ ) => void; parentsById: Record>; fullySelected: boolean; + filterPermission: ConnectorPermission; }) { const { resources, isResourcesLoading, isResourcesError } = useConnectorPermissions({ owner: owner, dataSource, parentId, - filterPermission: "read", + filterPermission, disabled: dataSource.connectorId === null, }); @@ -223,6 +231,7 @@ function DataSourceResourceSelectorChildren({ parentsById={parentsById} parents={[...parents, r.internalId]} fullySelected={fullySelected} + filterPermission={filterPermission} /> )} diff --git a/front/components/assistant_builder/AssistantBuilder.tsx b/front/components/assistant_builder/AssistantBuilder.tsx index fa56eb717e90..27eb4b4ae050 100644 --- a/front/components/assistant_builder/AssistantBuilder.tsx +++ b/front/components/assistant_builder/AssistantBuilder.tsx @@ -4,9 +4,14 @@ import { Avatar, Button, Collapsible, + ContextItem, DropdownMenu, Input, + Modal, + PageHeader, PencilSquareIcon, + PlusIcon, + SlackLogo, TrashIcon, } from "@dust-tt/sparkle"; import * as t from "io-ts"; @@ -14,6 +19,7 @@ import { useRouter } from "next/router"; import { useCallback, useEffect, useState } from "react"; import React from "react"; import ReactTextareaAutosize from "react-textarea-autosize"; +import { mutate } from "swr"; import { AvatarPicker } from "@app/components/assistant_builder/AssistantBuilderAvatarPicker"; import AssistantBuilderDataSourceModal from "@app/components/assistant_builder/AssistantBuilderDataSourceModal"; @@ -37,14 +43,20 @@ import { GPT_4_32K_MODEL_CONFIG, SupportedModel, } from "@app/lib/assistant"; +import { CONNECTOR_CONFIGURATIONS } from "@app/lib/connector_providers"; import { ConnectorProvider } from "@app/lib/connectors_api"; -import { useAgentConfigurations } from "@app/lib/swr"; +import { + useAgentConfigurations, + useSlackChannelsLinkedWithAgent, +} from "@app/lib/swr"; import { classNames } from "@app/lib/utils"; import { PostOrPatchAgentConfigurationRequestBodySchema } from "@app/pages/api/w/[wId]/assistant/agent_configurations"; import { TimeframeUnit } from "@app/types/assistant/actions/retrieval"; import { DataSourceType } from "@app/types/data_source"; import { UserType, WorkspaceType } from "@app/types/user"; +import DataSourceResourceSelectorTree from "../DataSourceResourceSelectorTree"; + const usedModelConfigs = [ GPT_4_32K_MODEL_CONFIG, GPT_3_5_TURBO_16K_MODEL_CONFIG, @@ -173,6 +185,10 @@ export default function AssistantBuilder({ }: AssistantBuilderProps) { const router = useRouter(); + const slackDataSource = dataSources.find( + (ds) => ds.connectorProvider === "slack" + ); + const [builderState, setBuilderState] = useState({ ...DEFAULT_ASSISTANT_STATE, generationSettings: { @@ -258,9 +274,36 @@ export default function AssistantBuilder({ } }, [initialBuilderState]); - const removeLeadingAt = (handle: string) => { - return handle.startsWith("@") ? handle.slice(1) : handle; - }; + // This state stores the slack channels that should have the current agent as default. + const [selectedSlackChannels, setSelectedSlackChannels] = useState< + { + channelId: string; + channelName: string; + }[] + >([]); + + // Retrieve all the slack channels that are linked with an agent. + const { slackChannels: slackChannelsLinkedWithAgent } = + useSlackChannelsLinkedWithAgent({ + workspaceId: owner.sId, + dataSourceName: slackDataSource?.name ?? undefined, + }); + + // This effect is used to initially set the selectedSlackChannels state using the data retrieved from the API. + useEffect(() => { + if (slackChannelsLinkedWithAgent && agentConfigurationId && !edited) { + setSelectedSlackChannels( + slackChannelsLinkedWithAgent + .filter( + (channel) => channel.agentConfigurationId === agentConfigurationId + ) + .map((channel) => ({ + channelId: channel.slackChannelId, + channelName: channel.slackChannelName, + })) + ); + } + }, [slackChannelsLinkedWithAgent, agentConfigurationId, edited]); const assistantHandleIsValid = useCallback((handle: string) => { return /^[a-zA-Z0-9_-]{1,20}$/.test(removeLeadingAt(handle)); @@ -461,7 +504,45 @@ export default function AssistantBuilder({ throw new Error("An error occurred while saving the configuration."); } - return res.json(); + const newAgentConfiguration = await res.json(); + const agentConfigurationSid = newAgentConfiguration.agentConfiguration.sId; + + // PATCH the linked slack channels if either: + // - there were already linked channels + // - there are newly selected channels + // If the user selected channels that were already routed to a different assistant, the current behavior is to + // unlink them from the previous assistant and link them to the this one. + if ( + selectedSlackChannels.length || + slackChannelsLinkedWithAgent.filter( + (channel) => channel.agentConfigurationId === agentConfigurationId + ).length + ) { + const slackLinkRes = await fetch( + `/api/w/${owner.sId}/assistant/agent_configurations/${agentConfigurationSid}/linked_slack_channels`, + { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + slack_channel_ids: selectedSlackChannels.map( + ({ channelId }) => channelId + ), + }), + } + ); + + if (!slackLinkRes.ok) { + throw new Error("An error occurred while linking Slack channels."); + } + + await mutate( + `/api/w/${owner.sId}/data_sources/${slackDataSource?.name}/managed/slack/channels_linked_with_agent` + ); + } + + return newAgentConfiguration; }; const handleDeleteAgent = async () => { @@ -763,7 +844,7 @@ export default function AssistantBuilder({ -
+
+ {slackDataSource && ( + { + setEdited(true); + setSelectedSlackChannels(channels); + }} + existingSelection={selectedSlackChannels} + /> + )}
{agentConfigurationId && ( -
+
+ {existingSelection.length ? ( + <> +
+ Your assistant will answer by default when @dust is mentioned in the + following channels: +
+ + {existingSelection.map(({ channelId, channelName }) => { + return ( + } + action={ + +