From eb31a99ca6cbac03a4655e8ff37406f3f65ec560 Mon Sep 17 00:00:00 2001
From: lublagg
Date: Tue, 17 Dec 2024 17:00:12 -0500
Subject: [PATCH] Send an image to the LLM for description.

---
 src/models/assistant-model.ts | 162 +++++++++++++++++++++++-----------
 src/utils/openai-utils.ts     |  40 +++++++++
 2 files changed, 150 insertions(+), 52 deletions(-)

diff --git a/src/models/assistant-model.ts b/src/models/assistant-model.ts
index ea7cfd7..d2196c9 100644
--- a/src/models/assistant-model.ts
+++ b/src/models/assistant-model.ts
@@ -5,7 +5,7 @@ import { DAVAI_SPEAKER, DEBUG_SPEAKER } from "../constants";
 import { formatJsonMessage } from "../utils/utils";
 import { getTools, initLlmConnection } from "../utils/llm-utils";
 import { ChatTranscriptModel } from "./chat-transcript-model";
-import { requestThreadDeletion } from "../utils/openai-utils";
+import { convertBase64ToImage, requestThreadDeletion } from "../utils/openai-utils";
 
 /**
  * AssistantModel encapsulates the AI assistant and its interactions with the user.
@@ -34,6 +34,8 @@ export const AssistantModel = types
   })
   .volatile(() => ({
     isLoadingResponse: false,
+    uploadFileAfterRun: false,
+    dataUri: "",
   }))
   .actions((self) => ({
     handleMessageSubmitMockAssistant() {
@@ -95,6 +97,7 @@
       } catch (err) {
         console.error("Failed to handle message submit:", err);
         self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Failed to handle message submit", content: formatJsonMessage(err)});
+        self.isLoadingResponse = false;
       }
     });
 
@@ -120,60 +123,107 @@
         content: formatJsonMessage(runState.status),
       });
 
-      const errorStates = ["failed", "cancelled", "incomplete"];
-
-      while (runState.status !== "completed" && runState.status !== "requires_action" && !errorStates.includes(runState.status)) {
-        yield new Promise((resolve) => setTimeout(resolve, 2000));
-        runState = yield self.apiConnection.beta.threads.runs.retrieve(self.thread.id, currentRunId);
-        self.transcriptStore.addMessage(DEBUG_SPEAKER, {
-          description: "Run state status",
-          content: formatJsonMessage(runState.status),
-        });
-      }
-
-      if (runState.status === "requires_action") {
-        self.transcriptStore.addMessage(DEBUG_SPEAKER, {
-          description: "Run requires action",
-          content: formatJsonMessage(runState),
-        });
-        yield handleRequiredAction(runState, currentRunId);
-        yield pollRunState(currentRunId);
-      }
-
-      if (runState.status === "completed") {
-        const messages = yield self.apiConnection.beta.threads.messages.list(self.thread.id);
-
-        const lastMessageForRun = messages.data
-          .filter((msg: Message) => msg.run_id === currentRunId && msg.role === "assistant")
-          .pop();
-
-        self.transcriptStore.addMessage(DEBUG_SPEAKER, {
-          description: "Run completed, assistant response",
-          content: formatJsonMessage(lastMessageForRun),
-        });
-
-        const lastMessageContent = lastMessageForRun?.content[0]?.text?.value;
-        if (lastMessageContent) {
-          self.transcriptStore.addMessage(DAVAI_SPEAKER, { content: lastMessageContent });
-        } else {
-          self.transcriptStore.addMessage(DAVAI_SPEAKER, {
-            content: "I'm sorry, I don't have a response for that.",
-          });
-        }
-        self.isLoadingResponse = false;
-      }
-
-      if (errorStates.includes(runState.status)) {
-        self.transcriptStore.addMessage(DEBUG_SPEAKER, {
-          description: "Run failed",
-          content: formatJsonMessage(runState),
-        });
-        self.transcriptStore.addMessage(DAVAI_SPEAKER, {
-          content: "I'm sorry, I encountered an error. Please try again.",
-        });
-        self.isLoadingResponse = false;
-      }
-    });
+      const errorStates = ["failed", "cancelled", "incomplete"];
+
+      while (runState.status !== "completed" && runState.status !== "requires_action" && !errorStates.includes(runState.status)) {
+        yield new Promise((resolve) => setTimeout(resolve, 2000));
+        runState = yield self.apiConnection.beta.threads.runs.retrieve(self.thread.id, currentRunId);
+        self.transcriptStore.addMessage(DEBUG_SPEAKER, {
+          description: "Run state status",
+          content: formatJsonMessage(runState.status),
+        });
+      }
+
+      if (runState.status === "requires_action") {
"requires_action") { + self.transcriptStore.addMessage(DEBUG_SPEAKER, { + description: "Run requires action", + content: formatJsonMessage(runState), + }); + yield handleRequiredAction(runState, currentRunId); + yield pollRunState(currentRunId); + } + + if (runState.status === "completed") { + if (self.uploadFileAfterRun && self.dataUri) { + const fileId = yield uploadFile(); + yield sendFileMessage(fileId); + self.uploadFileAfterRun = false; + self.dataUri = ""; + startRun(); + } else { + const messages = yield self.apiConnection.beta.threads.messages.list(self.thread.id); + + const lastMessageForRun = messages.data + .filter((msg: Message) => msg.run_id === currentRunId && msg.role === "assistant") + .pop(); + + self.transcriptStore.addMessage(DEBUG_SPEAKER, { + description: "Run completed, assistant response", + content: formatJsonMessage(lastMessageForRun), + }); + + const lastMessageContent = lastMessageForRun?.content[0]?.text?.value; + if (lastMessageContent) { + self.transcriptStore.addMessage(DAVAI_SPEAKER, { content: lastMessageContent }); + } else { + self.transcriptStore.addMessage(DAVAI_SPEAKER, { + content: "I'm sorry, I don't have a response for that.", + }); + } + self.isLoadingResponse = false; + } + } + + if (errorStates.includes(runState.status)) { + self.transcriptStore.addMessage(DEBUG_SPEAKER, { + description: "Run failed", + content: formatJsonMessage(runState), + }); self.transcriptStore.addMessage(DAVAI_SPEAKER, { - content: "I'm sorry, I don't have a response for that.", - }); - } - self.isLoadingResponse = false; - } + content: "I'm sorry, I encountered an error. Please try again.", + }); + self.isLoadingResponse = false; + } + }); - if (errorStates.includes(runState.status)) { - self.transcriptStore.addMessage(DEBUG_SPEAKER, { - description: "Run failed", - content: formatJsonMessage(runState), - }); - self.transcriptStore.addMessage(DAVAI_SPEAKER, { - content: "I'm sorry, I encountered an error. Please try again.", - }); - self.isLoadingResponse = false; - } - }); + const uploadFile = flow(function* () { + try { + const fileFromDataUri = yield convertBase64ToImage(self.dataUri); + const uploadedFile = yield self.apiConnection?.files.create({ + file: fileFromDataUri, + purpose: "vision" + }); + return uploadedFile.id; + } + catch (err) { + console.error("Failed to upload image:", err); + self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Failed to upload image", content: formatJsonMessage(err)}); + } + }); + + const sendFileMessage = flow(function* (fileId) { + try { + const res = yield self.apiConnection.beta.threads.messages.create(self.thread.id, { + role: "user", + content: [ + { + type: "text", + text: "This is an image of a graph. Describe it for the user." 
+            },
+            {
+              type: "image_file",
+              image_file: {
+                file_id: fileId
+              }
+            }
+          ]
+        });
+        self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Image message sent", content: formatJsonMessage(res)});
+      } catch (err) {
+        console.error("Failed to send file message:", err);
+        self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Failed to send file message", content: formatJsonMessage(err)});
+      }
+    });
 
     const handleRequiredAction = flow(function* (runState, runId) {
       try {
@@ -184,7 +234,14 @@
             const { action, resource, values } = JSON.parse(toolCall.function.arguments);
             const request = { action, resource, values };
             self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Request sent to CODAP", content: formatJsonMessage(request) });
-            const res = yield codapInterface.sendRequest(request);
+            let res = yield codapInterface.sendRequest(request);
+            // Note: we will implement a new endpoint in the CODAP API to get the exportDataUri value,
+            // so that it can be retrieved separately from a general get component request.
+            if (res.values.exportDataUri) {
+              self.uploadFileAfterRun = true;
+              self.dataUri = res.values.exportDataUri;
+              res = { ...res, values: { ...res.values, exportDataUri: undefined } };
+            }
             self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Response from CODAP", content: formatJsonMessage(res) });
             return { tool_call_id: toolCall.id, output: JSON.stringify(res) };
           } else {
@@ -203,6 +260,7 @@
       } catch (err) {
         console.error(err);
         self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Error taking required action", content: formatJsonMessage(err)});
+        self.isLoadingResponse = false;
       }
     });
 
diff --git a/src/utils/openai-utils.ts b/src/utils/openai-utils.ts
index f2714fb..b26916d 100644
--- a/src/utils/openai-utils.ts
+++ b/src/utils/openai-utils.ts
@@ -11,6 +11,21 @@ export const newOpenAI = () => {
   });
 };
 
+export async function convertBase64ToImage(base64Data: string, filename = "image.png") {
+  const base64 = base64Data.split(",")[1];
+
+  const binary = atob(base64);
+  const binaryLength = binary.length;
+  const arrayBuffer = new Uint8Array(binaryLength);
+  for (let i = 0; i < binaryLength; i++) {
+    arrayBuffer[i] = binary.charCodeAt(i);
+  }
+
+  const blob = new Blob([arrayBuffer], { type: "image/png" });
+  const file = new File([blob], filename, { type: "image/png" });
+  return file;
+}
+
 export const openAiTools: AssistantTool[] = [
   {
     type: "function",
@@ -42,6 +57,31 @@ export const openAiTools: AssistantTool[] = [
       }
     }
   },
+  {
+    type: "function",
+    function: {
+      name: "convert_base64_to_image",
+      description: "Convert a base64 image to a file object",
+      strict: false,
+      parameters: {
+        type: "object",
+        properties: {
+          base64Data: {
+            type: "string",
+            description: "The base64 image data"
+          },
+          filename: {
+            type: "string",
+            description: "The filename to use for the image"
+          }
+        },
+        additionalProperties: false,
+        required: [
+          "base64Data"
+        ]
+      }
+    }
+  }
 ];
 
 export const requestThreadDeletion = async (threadId: string): Promise => {
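
For reviewers who want to sanity-check convertBase64ToImage without running a full assistant turn, here is a minimal browser sketch. The throwaway canvas, the smokeTestConversion wrapper, and the "graph.png" name are illustrative stand-ins (not part of the patch) for the data URI that CODAP exports:

    // Adjust the import path to wherever the module resolves in your setup.
    import { convertBase64ToImage } from "./src/utils/openai-utils";

    async function smokeTestConversion() {
      // Draw a trivial 32x32 square on a throwaway canvas and export it as a
      // PNG data URI, standing in for the exportDataUri value CODAP returns.
      const canvas = document.createElement("canvas");
      canvas.width = 32;
      canvas.height = 32;
      canvas.getContext("2d")?.fillRect(0, 0, 32, 32);
      const dataUri = canvas.toDataURL("image/png");

      // Round-trip the data URI through the new helper.
      const file = await convertBase64ToImage(dataUri, "graph.png");

      // Expect a File named "graph.png" of type "image/png" with a non-zero size.
      console.log(file.name, file.type, file.size);
    }

    smokeTestConversion();

The resulting File is what uploadFile passes to self.apiConnection?.files.create({ file, purpose: "vision" }), whose id is then attached to the thread as an image_file content part by sendFileMessage.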