Skip to content

Commit

Permalink
Flav/render image for model (#5930)
Browse files Browse the repository at this point in the history
* Support images in useFileUploaderService

* Display images content fragment in messages

* Add proxy for GCS images

* Support uploading images as content fragment

* ✨

* Add signedUrl method on FileStorage

* Render images for model

* 👕

* 🙈

* Fix image placeholder content

* Address comments from review

* 👕
  • Loading branch information
flvndvd authored Jun 28, 2024
1 parent c37cd39 commit f33f80d
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 66 deletions.
3 changes: 2 additions & 1 deletion front/lib/api/assistant/actions/tables_query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,10 @@ export class TablesQueryConfigurationServerRunner extends BaseActionConfiguratio
const renderedConversationRes =
await renderConversationForModelMultiActions({
conversation,
model: agentConfiguration.model,
model: supportedModel,
prompt: agentConfiguration.instructions ?? "",
allowedTokenCount,
excludeImages: true,
});
if (renderedConversationRes.isErr()) {
yield {
Expand Down
1 change: 1 addition & 0 deletions front/lib/api/assistant/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ export async function generateConversationTitle(
prompt: "", // There is no prompt for title generation.
allowedTokenCount: model.contextSize - MIN_GENERATION_TOKENS,
excludeActions: true,
excludeImages: true,
});

if (modelConversationRes.isErr()) {
Expand Down
81 changes: 26 additions & 55 deletions front/lib/api/assistant/generation.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import type {
AgentConfigurationType,
ContentFragmentMessageTypeModel,
ConversationType,
FunctionCallType,
FunctionMessageTypeModel,
Expand Down Expand Up @@ -28,7 +27,7 @@ import moment from "moment-timezone";
import { citationMetaPrompt } from "@app/lib/api/assistant/citations";
import { getAgentConfigurations } from "@app/lib/api/assistant/configuration";
import type { Authenticator } from "@app/lib/auth";
import { getContentFragmentText } from "@app/lib/resources/content_fragment_resource";
import { renderContentFragmentForModel } from "@app/lib/resources/content_fragment_resource";
import { tokenCountForText, tokenSplit } from "@app/lib/tokenization";
import logger from "@app/logger/logger";

Expand All @@ -42,12 +41,14 @@ export async function renderConversationForModelMultiActions({
prompt,
allowedTokenCount,
excludeActions,
excludeImages,
}: {
conversation: ConversationType;
model: { providerId: string; modelId: string };
model: ModelConfigurationType;
prompt: string;
allowedTokenCount: number;
excludeActions?: boolean;
excludeImages?: boolean;
}): Promise<
Result<
{
Expand All @@ -59,7 +60,6 @@ export async function renderConversationForModelMultiActions({
> {
const now = Date.now();
const messages: ModelMessageTypeMultiActions[] = [];
const closingAttachmentTag = "</attachment>\n";

// Render loop.
// Render all messages and all actions.
Expand Down Expand Up @@ -133,36 +133,15 @@ export async function renderConversationForModelMultiActions({
],
});
} else if (isContentFragmentType(m)) {
try {
const content = await getContentFragmentText({
workspaceId: conversation.owner.sId,
conversationId: conversation.sId,
messageId: m.sId,
});
messages.push({
role: "content_fragment",
name: `inject_${m.contentType}`,
// The closing </attachment> tag will be added in the merging loop because we might
// need to add a "truncated..." mention in the selection loop.
content: [
{
type: "text",
text: `<attachment type="${m.contentType}" title="${m.title}">\n${content}\n`,
},
],
});
} catch (error) {
logger.error(
{
error,
workspaceId: conversation.owner.sId,
conversationId: conversation.sId,
messageId: m.sId,
},
"Failed to retrieve content fragment text"
);
return new Err(new Error("Failed to retrieve content fragment text"));
const res = await renderContentFragmentForModel(m, conversation, model, {
excludeImages: Boolean(excludeImages),
});

if (res.isErr()) {
return new Err(res.error);
}

messages.push(res.value);
} else {
assertNever(m);
}
Expand All @@ -173,15 +152,6 @@ export async function renderConversationForModelMultiActions({
Promise.all(
messages.map((m) => {
let text = `${m.role} ${"name" in m ? m.name : ""} ${getTextContentFromMessage(m)}`;
if (
isContentFragmentMessageTypeModel(m) &&
m.content.every((c) => isTextContent(c))
) {
// Account for the upcoming </attachment> tag for textual attachments,
// as it will be appended during the merging process.
text += closingAttachmentTag;
}

if ("function_calls" in m) {
text += m.function_calls
.map((f) => `${f.name} ${f.arguments}`)
Expand Down Expand Up @@ -226,30 +196,39 @@ export async function renderConversationForModelMultiActions({
// Allow at least tokensMargin tokens in addition to the truncation message.
tokensUsed + approxTruncMsgTokenCount + tokensMargin < allowedTokenCount
) {
const msg = messages[i] as ContentFragmentMessageTypeModel;
const remainingTokens =
allowedTokenCount - tokensUsed - approxTruncMsgTokenCount;

const updatedContent = [];
for (const c of msg.content) {
for (const c of currentMessage.content) {
if (!isTextContent(c)) {
// If there is not enough room and it's an image, we simply ignore it.
continue;
}

const contentRes = await tokenSplit(c.text, model, remainingTokens);
// Remove only if it ends with "</attachment>".
const textWithoutClosingAttachmentTag = c.text.replace(
/<\/attachment>$/,
""
);

const contentRes = await tokenSplit(
textWithoutClosingAttachmentTag,
model,
remainingTokens
);
if (contentRes.isErr()) {
return new Err(contentRes.error);
}

updatedContent.push({
...c,
text: contentRes.value + truncationMessage,
text: `${contentRes.value}${truncationMessage}</attachment>`,
});
}

selected.unshift({
...msg,
...currentMessage,
content: updatedContent,
});

Expand Down Expand Up @@ -285,14 +264,6 @@ export async function renderConversationForModelMultiActions({
);
}

for (const c of cfMessage.content) {
if (isTextContent(c)) {
// We can now close the </attachment> tag, because the message was already properly
// truncated. We also accounted for the closing that above when computing the tokens count.
c.text += closingAttachmentTag;
}
}

userMessage.content = [...cfMessage.content, ...userMessage.content];
// Now we remove the content fragment from the array since it was merged into the upcoming
// user message.
Expand Down
27 changes: 26 additions & 1 deletion front/lib/file_storage/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import { pipeline } from "stream/promises";
import config from "@app/lib/file_storage/config";
import { isGCSNotFoundError } from "@app/lib/file_storage/types";

const DEFAULT_SIGNED_URL_EXPIRATION_DELAY_MS = 5 * 60 * 1000; // 5 minutes.

class FileStorage {
private readonly bucket: Bucket;
private readonly storage: Storage;
Expand Down Expand Up @@ -73,6 +75,27 @@ class FileStorage {
return metadata.contentType;
}

async getSignedUrl(
filename: string,
{
expirationDelay,
promptSaveAs,
}: { expirationDelay: number; promptSaveAs?: string } = {
expirationDelay: DEFAULT_SIGNED_URL_EXPIRATION_DELAY_MS,
}
): Promise<string> {
const gcsFile = this.file(filename);

const signedUrl = await gcsFile.getSignedUrl({
version: "v4",
action: "read",
expires: new Date().getTime() + expirationDelay,
promptSaveAs,
});

return signedUrl.toString();
}

file(filename: string) {
return this.bucket.file(filename);
}
Expand Down Expand Up @@ -103,7 +126,9 @@ class FileStorage {

const bucketInstances = new Map();

const getBucketInstance = (bucketConfig: string) => {
const getBucketInstance: (bucketConfig: string) => FileStorage = (
bucketConfig: string
) => {
if (!bucketInstances.has(bucketConfig)) {
bucketInstances.set(bucketConfig, new FileStorage(bucketConfig));
}
Expand Down
109 changes: 107 additions & 2 deletions front/lib/resources/content_fragment_resource.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import type { ContentFragmentType, ModelId, Result } from "@dust-tt/types";
import { Err, Ok } from "@dust-tt/types";
import type {
ContentFragmentMessageTypeModel,
ContentFragmentType,
ConversationType,
ModelConfigurationType,
ModelId,
Result,
} from "@dust-tt/types";
import { Err, isSupportedImageContentFragmentType, Ok } from "@dust-tt/types";
import type {
Attributes,
CreationAttributes,
Expand All @@ -14,6 +21,7 @@ import { Message } from "@app/lib/models/assistant/conversation";
import { BaseResource } from "@app/lib/resources/base_resource";
import { ContentFragmentModel } from "@app/lib/resources/storage/models/content_fragment";
import type { ReadonlyAttributesType } from "@app/lib/resources/storage/types";
import logger from "@app/logger/logger";

// Attributes are marked as read-only to reflect the stateless nature of our Resource.
// This design will be moved up to BaseResource once we transition away from Sequelize.
Expand Down Expand Up @@ -276,3 +284,100 @@ export async function getContentFragmentText({

return getPrivateUploadBucket().fetchFileContent(filePath);
}

async function getSignedUrlForRawContentFragment({
workspaceId,
conversationId,
messageId,
}: {
workspaceId: string;
conversationId: string;
messageId: string;
}): Promise<string> {
const fileLocation = fileAttachmentLocation({
workspaceId,
conversationId,
messageId,
contentFormat: "raw",
});

return getPrivateUploadBucket().getSignedUrl(fileLocation.filePath);
}

export async function renderContentFragmentForModel(
message: ContentFragmentType,
conversation: ConversationType,
model: ModelConfigurationType,
{
excludeImages,
}: {
excludeImages: boolean;
}
): Promise<Result<ContentFragmentMessageTypeModel, Error>> {
const { contentType, sId, title } = message;

try {
if (isSupportedImageContentFragmentType(contentType)) {
if (excludeImages || !model.supportsVision) {
return new Ok({
role: "content_fragment",
name: `inject_${contentType}`,
content: [
{
type: "text",
text: `<attachment type="${contentType} title="${title}">[Image content interpreted by a vision-enabled model. Description not available in this context.]</attachment>`,
},
],
});
}

const signedUrl = await getSignedUrlForRawContentFragment({
workspaceId: conversation.owner.sId,
conversationId: conversation.sId,
messageId: sId,
});

return new Ok({
role: "content_fragment",
name: `inject_${contentType}`,
content: [
{
type: "image_url",
image_url: {
url: signedUrl,
},
},
],
});
} else {
const content = await getContentFragmentText({
workspaceId: conversation.owner.sId,
conversationId: conversation.sId,
messageId: sId,
});

return new Ok({
role: "content_fragment",
name: `inject_${contentType}`,
content: [
{
type: "text",
text: `<attachment type="${contentType}" title="${title}">\n${content}\n</attachment>`,
},
],
});
}
} catch (error) {
logger.error(
{
error,
workspaceId: conversation.owner.sId,
conversationId: conversation.sId,
messageId: sId,
},
"Failed to retrieve content fragment text"
);

return new Err(new Error("Failed to retrieve content fragment text"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,11 @@ async function handler(
return;
}

// redirect to a signed URL
const [url] = await privateUploadGcs.file(filePath).getSignedUrl({
version: "v4",
action: "read",
// since we redirect, the use is immediate so expiry can be short
expires: Date.now() + 10 * 1000,
// remove special chars
// Redirect to a signed URL.
const [url] = await privateUploadGcs.getSignedUrl(filePath, {
// Since we redirect, the use is immediate so expiry can be short.
expirationDelay: Date.now() + 10 * 1000,
// Remove special chars.
promptSaveAs:
message.title.replace(/[^\w\s.-]/gi, "") +
(contentFormat === "text" ? ".txt" : ""),
Expand Down
Loading

0 comments on commit f33f80d

Please sign in to comment.