Skip to content

Commit

Permalink
inject files in prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Fontanier committed Jul 31, 2024
1 parent 7e505d3 commit 91a6c6b
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 13 deletions.
1 change: 1 addition & 0 deletions front/lib/api/assistant/actions/process.ts
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ export class ProcessConfigurationServerRunner extends BaseActionConfigurationSer
};

const prompt = await constructPromptMultiActions(auth, {
conversation,
userMessage,
agentConfiguration,
fallbackPrompt:
Expand Down
1 change: 1 addition & 0 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ export async function* runMultiActionsAgent(

const prompt = await constructPromptMultiActions(auth, {
userMessage,
conversation,
agentConfiguration,
fallbackPrompt,
model,
Expand Down
89 changes: 88 additions & 1 deletion front/lib/api/assistant/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type {
AgentConfigurationType,
AssistantContentMessageTypeModel,
AssistantFunctionCallMessageTypeModel,
ContentFragmentType,
ConversationType,
FunctionCallType,
FunctionMessageTypeModel,
Expand All @@ -24,13 +25,17 @@ import {
Ok,
removeNulls,
} from "@dust-tt/types";
import _ from "lodash";
import moment from "moment-timezone";
import * as readline from "readline";
import type { Readable } from "stream";

import { citationMetaPrompt } from "@app/lib/api/assistant/citations";
import { getAgentConfigurations } from "@app/lib/api/assistant/configuration";
import { visualizationSystemPrompt } from "@app/lib/api/assistant/visualization";
import type { Authenticator } from "@app/lib/auth";
import { renderContentFragmentForModel } from "@app/lib/resources/content_fragment_resource";
import { FileResource } from "@app/lib/resources/file_resource";
import { tokenCountForTexts, tokenSplit } from "@app/lib/tokenization";
import logger from "@app/logger/logger";

Expand Down Expand Up @@ -369,12 +374,14 @@ export async function renderConversationForModelMultiActions({
export async function constructPromptMultiActions(
auth: Authenticator,
{
conversation,
userMessage,
agentConfiguration,
fallbackPrompt,
model,
hasAvailableActions,
}: {
conversation: ConversationType;
userMessage: UserMessageType;
agentConfiguration: AgentConfigurationType;
fallbackPrompt?: string;
Expand Down Expand Up @@ -449,7 +456,10 @@ export async function constructPromptMultiActions(
}

if (agentConfiguration.visualizationEnabled) {
additionalInstructions += visualizationSystemPrompt.trim();
additionalInstructions += await getVisualizationPrompt({
auth,
conversation,
});
}

const providerMetaPrompt = model.metaPrompt;
Expand Down Expand Up @@ -494,3 +504,80 @@ function getTextContentFromMessage(
})
.join("\n");
}

async function getVisualizationPrompt({
auth,
conversation,
}: {
auth: Authenticator;
conversation: ConversationType;
}) {
const readFirstFiveLines = (inputStream: Readable): Promise<string[]> => {
return new Promise((resolve, reject) => {
const rl: readline.Interface = readline.createInterface({
input: inputStream,
crlfDelay: Infinity,
});

let lineCount: number = 0;
const lines: string[] = [];

rl.on("line", (line: string) => {
lines.push(line);
lineCount++;
if (lineCount === 5) {
rl.close();
}
});

rl.on("close", () => {
resolve(lines);
});

rl.on("error", (err: Error) => {
reject(err);
});
});
};

const contentFragmentMessages: Array<ContentFragmentType> = [];
for (const m of conversation.content.flat(1)) {
if (isContentFragmentType(m)) {
contentFragmentMessages.push(m);
}
}
const contentFragmentFileBySid = _.keyBy(
await FileResource.fetchByIds(
auth,
removeNulls(contentFragmentMessages.map((m) => m.fileId))
),
"sId"
);

const contentFragmentTextByMessageId: Record<string, string[]> = {};
for (const m of contentFragmentMessages) {
if (!m.fileId || !m.contentType.startsWith("text/")) {
continue;
}

const file = contentFragmentFileBySid[m.fileId];
if (!file) {
continue;
}
const readStream = file.getReadStream({
auth,
version: "original",
});
contentFragmentTextByMessageId[m.sId] =
await readFirstFiveLines(readStream);
}

return (
`${visualizationSystemPrompt.trim()}\n\nYou have access to the following files:\n` +
contentFragmentMessages
.map((m) => {
return `<file id="${m.fileId}" name="${m.title}" type="${m.contentType}">\n${contentFragmentTextByMessageId[m.sId]?.join("\n")}(truncated...)</file>`;
})
.join("\n")
);
}
29 changes: 17 additions & 12 deletions front/lib/resources/file_resource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import type {
Result,
UserType,
} from "@dust-tt/types";
import { Err, Ok } from "@dust-tt/types";
import { Err, Ok, removeNulls } from "@dust-tt/types";
import type {
Attributes,
CreationAttributes,
Expand Down Expand Up @@ -59,26 +59,31 @@ export class FileResource extends BaseResource<FileModel> {
// TODO(2024-07-01 flav) Remove once we introduce AuthenticatorWithWorkspace.
const owner = auth.workspace();
if (!owner) {
throw new Error("Unexpected unauthenticated call to `getUploadUrl`");
throw new Error("Unexpected unauthenticated call to `fetchById`");
}
const res = await FileResource.fetchByIds(auth, [id]);
return res.length > 0 ? res[0] : null;
}

const fileModelId = getResourceIdFromSId(id);
if (!fileModelId) {
return null;
static async fetchByIds(
auth: Authenticator,
ids: string[]
): Promise<FileResource[]> {
const owner = auth.workspace();
if (!owner) {
throw new Error("Unexpected unauthenticated call to `fetchByIds`");
}

const blob = await this.model.findOne({
const fileModelIds = removeNulls(ids.map((id) => getResourceIdFromSId(id)));

const blobs = await this.model.findAll({
where: {
workspaceId: owner.id,
id: fileModelId,
id: fileModelIds,
},
});
if (!blob) {
return null;
}

// Use `.get` to extract model attributes, omitting Sequelize instance metadata.
return new this(this.model, blob.get());
return blobs.map((blob) => new this(this.model, blob.get()));
}

static async deleteAllForWorkspace(
Expand Down

0 comments on commit 91a6c6b

Please sign in to comment.