From 91a6c6b236b8839fc883d045b790d2ab6047c734 Mon Sep 17 00:00:00 2001 From: Henry Fontanier Date: Wed, 31 Jul 2024 18:56:07 +0200 Subject: [PATCH] inject files in prompt --- front/lib/api/assistant/actions/process.ts | 1 + front/lib/api/assistant/agent.ts | 1 + front/lib/api/assistant/generation.ts | 89 +++++++++++++++++++++- front/lib/resources/file_resource.ts | 29 ++++--- 4 files changed, 107 insertions(+), 13 deletions(-) diff --git a/front/lib/api/assistant/actions/process.ts b/front/lib/api/assistant/actions/process.ts index 0308d60df1622..c41ba164ef344 100644 --- a/front/lib/api/assistant/actions/process.ts +++ b/front/lib/api/assistant/actions/process.ts @@ -234,6 +234,7 @@ export class ProcessConfigurationServerRunner extends BaseActionConfigurationSer }; const prompt = await constructPromptMultiActions(auth, { + conversation, userMessage, agentConfiguration, fallbackPrompt: diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index c641cfadd7d96..754a2256ff3e6 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -351,6 +351,7 @@ export async function* runMultiActionsAgent( const prompt = await constructPromptMultiActions(auth, { userMessage, + conversation, agentConfiguration, fallbackPrompt, model, diff --git a/front/lib/api/assistant/generation.ts b/front/lib/api/assistant/generation.ts index fcc9d8dfd9b39..9facea29a0787 100644 --- a/front/lib/api/assistant/generation.ts +++ b/front/lib/api/assistant/generation.ts @@ -2,6 +2,7 @@ import type { AgentConfigurationType, AssistantContentMessageTypeModel, AssistantFunctionCallMessageTypeModel, + ContentFragmentType, ConversationType, FunctionCallType, FunctionMessageTypeModel, @@ -24,13 +25,17 @@ import { Ok, removeNulls, } from "@dust-tt/types"; +import _ from "lodash"; import moment from "moment-timezone"; +import * as readline from "readline"; +import type { Readable } from "stream"; import { citationMetaPrompt } from "@app/lib/api/assistant/citations"; import { getAgentConfigurations } from "@app/lib/api/assistant/configuration"; import { visualizationSystemPrompt } from "@app/lib/api/assistant/visualization"; import type { Authenticator } from "@app/lib/auth"; import { renderContentFragmentForModel } from "@app/lib/resources/content_fragment_resource"; +import { FileResource } from "@app/lib/resources/file_resource"; import { tokenCountForTexts, tokenSplit } from "@app/lib/tokenization"; import logger from "@app/logger/logger"; @@ -369,12 +374,14 @@ export async function renderConversationForModelMultiActions({ export async function constructPromptMultiActions( auth: Authenticator, { + conversation, userMessage, agentConfiguration, fallbackPrompt, model, hasAvailableActions, }: { + conversation: ConversationType; userMessage: UserMessageType; agentConfiguration: AgentConfigurationType; fallbackPrompt?: string; @@ -449,7 +456,10 @@ export async function constructPromptMultiActions( } if (agentConfiguration.visualizationEnabled) { - additionalInstructions += visualizationSystemPrompt.trim(); + additionalInstructions += await getVisualizationPrompt({ + auth, + conversation, + }); } const providerMetaPrompt = model.metaPrompt; @@ -494,3 +504,80 @@ function getTextContentFromMessage( }) .join("\n"); } + +async function getVisualizationPrompt({ + auth, + conversation, +}: { + auth: Authenticator; + conversation: ConversationType; +}) { + const readFirstFiveLines = (inputStream: Readable): Promise => { + return new Promise((resolve, reject) => { + const rl: readline.Interface = readline.createInterface({ + input: inputStream, + crlfDelay: Infinity, + }); + + let lineCount: number = 0; + const lines: string[] = []; + + rl.on("line", (line: string) => { + lines.push(line); + lineCount++; + if (lineCount === 5) { + rl.close(); + } + }); + + rl.on("close", () => { + resolve(lines); + }); + + rl.on("error", (err: Error) => { + reject(err); + }); + }); + }; + + const contentFragmentMessages: Array = []; + for (const m of conversation.content.flat(1)) { + if (isContentFragmentType(m)) { + contentFragmentMessages.push(m); + } + } + const contentFragmentFileBySid = _.keyBy( + await FileResource.fetchByIds( + auth, + removeNulls(contentFragmentMessages.map((m) => m.fileId)) + ), + "sId" + ); + + const contentFragmentTextByMessageId: Record = {}; + for (const m of contentFragmentMessages) { + if (!m.fileId || !m.contentType.startsWith("text/")) { + continue; + } + + const file = contentFragmentFileBySid[m.fileId]; + if (!file) { + continue; + } + const readStream = file.getReadStream({ + auth, + version: "original", + }); + contentFragmentTextByMessageId[m.sId] = + await readFirstFiveLines(readStream); + } + + return ( + `${visualizationSystemPrompt.trim()}\n\nYou have access to the following files:\n` + + contentFragmentMessages + .map((m) => { + return `\n${contentFragmentTextByMessageId[m.sId]?.join("\n")}(truncated...)`; + }) + .join("\n") + ); +} diff --git a/front/lib/resources/file_resource.ts b/front/lib/resources/file_resource.ts index 36f5c4b03feff..1c25cc1d52560 100644 --- a/front/lib/resources/file_resource.ts +++ b/front/lib/resources/file_resource.ts @@ -9,7 +9,7 @@ import type { Result, UserType, } from "@dust-tt/types"; -import { Err, Ok } from "@dust-tt/types"; +import { Err, Ok, removeNulls } from "@dust-tt/types"; import type { Attributes, CreationAttributes, @@ -59,26 +59,31 @@ export class FileResource extends BaseResource { // TODO(2024-07-01 flav) Remove once we introduce AuthenticatorWithWorkspace. const owner = auth.workspace(); if (!owner) { - throw new Error("Unexpected unauthenticated call to `getUploadUrl`"); + throw new Error("Unexpected unauthenticated call to `fetchById`"); } + const res = await FileResource.fetchByIds(auth, [id]); + return res.length > 0 ? res[0] : null; + } - const fileModelId = getResourceIdFromSId(id); - if (!fileModelId) { - return null; + static async fetchByIds( + auth: Authenticator, + ids: string[] + ): Promise { + const owner = auth.workspace(); + if (!owner) { + throw new Error("Unexpected unauthenticated call to `fetchByIds`"); } - const blob = await this.model.findOne({ + const fileModelIds = removeNulls(ids.map((id) => getResourceIdFromSId(id))); + + const blobs = await this.model.findAll({ where: { workspaceId: owner.id, - id: fileModelId, + id: fileModelIds, }, }); - if (!blob) { - return null; - } - // Use `.get` to extract model attributes, omitting Sequelize instance metadata. - return new this(this.model, blob.get()); + return blobs.map((blob) => new this(this.model, blob.get())); } static async deleteAllForWorkspace(