Commit eb31a99
Send an image to the LLM for description.
lublagg committed Dec 17, 2024
1 parent 8aa942c commit eb31a99
Showing 2 changed files with 150 additions and 52 deletions.
162 changes: 110 additions & 52 deletions src/models/assistant-model.ts
@@ -5,7 +5,7 @@ import { DAVAI_SPEAKER, DEBUG_SPEAKER } from "../constants";
import { formatJsonMessage } from "../utils/utils";
import { getTools, initLlmConnection } from "../utils/llm-utils";
import { ChatTranscriptModel } from "./chat-transcript-model";
-import { requestThreadDeletion } from "../utils/openai-utils";
+import { convertBase64ToImage, requestThreadDeletion } from "../utils/openai-utils";

/**
* AssistantModel encapsulates the AI assistant and its interactions with the user.
@@ -34,6 +34,8 @@ export const AssistantModel = types
})
.volatile(() => ({
isLoadingResponse: false,
+    uploadFileAfterRun: false,
+    dataUri: "",
}))
.actions((self) => ({
handleMessageSubmitMockAssistant() {
@@ -95,6 +97,7 @@ export const AssistantModel = types
} catch (err) {
console.error("Failed to handle message submit:", err);
self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Failed to handle message submit", content: formatJsonMessage(err)});
+        self.isLoadingResponse = false;
}
});

@@ -120,60 +123,107 @@ export const AssistantModel = types
content: formatJsonMessage(runState.status),
});

-      const errorStates = ["failed", "cancelled", "incomplete"];
+        const errorStates = ["failed", "cancelled", "incomplete"];

-      while (runState.status !== "completed" && runState.status !== "requires_action" && !errorStates.includes(runState.status)) {
-        yield new Promise((resolve) => setTimeout(resolve, 2000));
-        runState = yield self.apiConnection.beta.threads.runs.retrieve(self.thread.id, currentRunId);
-        self.transcriptStore.addMessage(DEBUG_SPEAKER, {
-          description: "Run state status",
-          content: formatJsonMessage(runState.status),
-        });
-      }
+        while (runState.status !== "completed" && runState.status !== "requires_action" && !errorStates.includes(runState.status)) {
+          yield new Promise((resolve) => setTimeout(resolve, 2000));
+          runState = yield self.apiConnection.beta.threads.runs.retrieve(self.thread.id, currentRunId);
+          self.transcriptStore.addMessage(DEBUG_SPEAKER, {
+            description: "Run state status",
+            content: formatJsonMessage(runState.status),
+          });
+        }

-      if (runState.status === "requires_action") {
-        self.transcriptStore.addMessage(DEBUG_SPEAKER, {
-          description: "Run requires action",
-          content: formatJsonMessage(runState),
-        });
-        yield handleRequiredAction(runState, currentRunId);
-        yield pollRunState(currentRunId);
-      }
-
-      if (runState.status === "completed") {
-        const messages = yield self.apiConnection.beta.threads.messages.list(self.thread.id);
-
-        const lastMessageForRun = messages.data
-          .filter((msg: Message) => msg.run_id === currentRunId && msg.role === "assistant")
-          .pop();
-
-        self.transcriptStore.addMessage(DEBUG_SPEAKER, {
-          description: "Run completed, assistant response",
-          content: formatJsonMessage(lastMessageForRun),
-        });
-
-        const lastMessageContent = lastMessageForRun?.content[0]?.text?.value;
-        if (lastMessageContent) {
-          self.transcriptStore.addMessage(DAVAI_SPEAKER, { content: lastMessageContent });
-        } else {
-          self.transcriptStore.addMessage(DAVAI_SPEAKER, {
-            content: "I'm sorry, I don't have a response for that.",
-          });
-        }
-      }
-
-      if (errorStates.includes(runState.status)) {
-        self.transcriptStore.addMessage(DEBUG_SPEAKER, {
-          description: "Run failed",
-          content: formatJsonMessage(runState),
-        });
-        self.transcriptStore.addMessage(DAVAI_SPEAKER, {
-          content: "I'm sorry, I don't have a response for that.",
-        });
-      }
-      self.isLoadingResponse = false;
-    }
+        if (runState.status === "requires_action") {
+          self.transcriptStore.addMessage(DEBUG_SPEAKER, {
+            description: "Run requires action",
+            content: formatJsonMessage(runState),
+          });
+          yield handleRequiredAction(runState, currentRunId);
+          yield pollRunState(currentRunId);
+        }
+
+        if (runState.status === "completed") {
+          if (self.uploadFileAfterRun && self.dataUri) {
+            const fileId = yield uploadFile();
+            yield sendFileMessage(fileId);
+            self.uploadFileAfterRun = false;
+            self.dataUri = "";
+            startRun();
+          } else {
+            const messages = yield self.apiConnection.beta.threads.messages.list(self.thread.id);
+
+            const lastMessageForRun = messages.data
+              .filter((msg: Message) => msg.run_id === currentRunId && msg.role === "assistant")
+              .pop();
+
+            self.transcriptStore.addMessage(DEBUG_SPEAKER, {
+              description: "Run completed, assistant response",
+              content: formatJsonMessage(lastMessageForRun),
+            });
+
+            const lastMessageContent = lastMessageForRun?.content[0]?.text?.value;
+            if (lastMessageContent) {
+              self.transcriptStore.addMessage(DAVAI_SPEAKER, { content: lastMessageContent });
+            } else {
+              self.transcriptStore.addMessage(DAVAI_SPEAKER, {
+                content: "I'm sorry, I don't have a response for that.",
+              });
+            }
+            self.isLoadingResponse = false;
+          }
+        }
+
+        if (errorStates.includes(runState.status)) {
+          self.transcriptStore.addMessage(DEBUG_SPEAKER, {
+            description: "Run failed",
+            content: formatJsonMessage(runState),
+          });
+          self.transcriptStore.addMessage(DAVAI_SPEAKER, {
+            content: "I'm sorry, I encountered an error. Please try again.",
+          });
+          self.isLoadingResponse = false;
+        }
      });
+    const uploadFile = flow(function* () {
+      try {
+        const fileFromDataUri = yield convertBase64ToImage(self.dataUri);
+        const uploadedFile = yield self.apiConnection?.files.create({
+          file: fileFromDataUri,
+          purpose: "vision"
+        });
+        return uploadedFile.id;
+      }
+      catch (err) {
+        console.error("Failed to upload image:", err);
+        self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Failed to upload image", content: formatJsonMessage(err)});
+      }
+    });
+
+    const sendFileMessage = flow(function* (fileId) {
+      try {
+        const res = yield self.apiConnection.beta.threads.messages.create(self.thread.id, {
+          role: "user",
+          content: [
+            {
+              type: "text",
+              text: "This is an image of a graph. Describe it for the user."
+            },
+            {
+              type: "image_file",
+              image_file: {
+                file_id: fileId
+              }
+            }
+          ]
+        });
+        self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Image uploaded", content: formatJsonMessage(res)});
+      } catch (err) {
+        console.error("Failed to send file message:", err);
+        self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Failed to send file message", content: formatJsonMessage(err)});
+      }
+    });
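Taken together, uploadFile and sendFileMessage implement a two-step pattern: upload the exported graph image to OpenAI with purpose "vision", then attach it to the thread as an image_file content part alongside a text prompt. A minimal standalone sketch of that pattern, assuming an OpenAI v4 client, an existing thread and assistant, and a data URI already exported from CODAP (describeGraphImage and its parameters are illustrative, not part of this commit):

import OpenAI from "openai";
import { convertBase64ToImage } from "../utils/openai-utils";

export async function describeGraphImage(client: OpenAI, threadId: string, assistantId: string, dataUri: string) {
  // Convert the base64 data URI to a File and upload it for vision use.
  const file = await convertBase64ToImage(dataUri);
  const uploaded = await client.files.create({ file, purpose: "vision" });

  // Attach the uploaded image to the thread next to a text prompt.
  await client.beta.threads.messages.create(threadId, {
    role: "user",
    content: [
      { type: "text", text: "This is an image of a graph. Describe it for the user." },
      { type: "image_file", image_file: { file_id: uploaded.id } }
    ]
  });

  // Start a new run and poll until it leaves the in-progress states.
  let run = await client.beta.threads.runs.create(threadId, { assistant_id: assistantId });
  while (run.status === "queued" || run.status === "in_progress") {
    await new Promise((resolve) => setTimeout(resolve, 2000));
    run = await client.beta.threads.runs.retrieve(threadId, run.id);
  }
  return run;
}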

const handleRequiredAction = flow(function* (runState, runId) {
try {
@@ -184,7 +234,14 @@
const { action, resource, values } = JSON.parse(toolCall.function.arguments);
const request = { action, resource, values };
self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Request sent to CODAP", content: formatJsonMessage(request) });
-          const res = yield codapInterface.sendRequest(request);
+          let res = yield codapInterface.sendRequest(request);
+          // note: we will implement a new endpoint in the CODAP api to get the exportDataUri value,
+          // so that it can be retrieved separately from a general get component request
+          if (res.values.exportDataUri) {
+            self.uploadFileAfterRun = true;
+            self.dataUri = res.values.exportDataUri;
+            res = { ...res, values: { ...res.values, exportDataUri: undefined } };
+          }
self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Response from CODAP", content: formatJsonMessage(res) });
return { tool_call_id: toolCall.id, output: JSON.stringify(res) };
} else {
@@ -203,6 +260,7 @@
} catch (err) {
console.error(err);
self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Error taking required action", content: formatJsonMessage(err)});
+        self.isLoadingResponse = false;
}
});

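The exportDataUri handling above defers the image upload until the current run has finished: while a run is still waiting on tool outputs, new messages generally cannot be added to the thread, so the data URI is stashed on the model and only the slimmed-down result is returned to the assistant. A rough sketch of that split, assuming an OpenAI v4 client; splitExportDataUri and submitCodapResult are illustrative names, not part of this commit:

import OpenAI from "openai";

// Illustrative shape of a CODAP plugin response; the real API returns a richer object.
interface CodapResponse {
  values: Record<string, unknown> & { exportDataUri?: string };
}

// Pull the (potentially very large) image data out of the tool output so it can be
// uploaded after the run completes, and keep only the cleaned result for the assistant.
function splitExportDataUri(res: CodapResponse) {
  const { exportDataUri, ...rest } = res.values;
  return { dataUri: exportDataUri, toolResult: { ...res, values: rest } };
}

async function submitCodapResult(client: OpenAI, threadId: string, runId: string, toolCallId: string, res: CodapResponse) {
  const { dataUri, toolResult } = splitExportDataUri(res);
  // Hand the cleaned result back to the waiting run.
  await client.beta.threads.runs.submitToolOutputs(threadId, runId, {
    tool_outputs: [{ tool_call_id: toolCallId, output: JSON.stringify(toolResult) }]
  });
  // The caller stores dataUri (cf. self.dataUri above) and uploads it once the run completes.
  return dataUri;
}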
40 changes: 40 additions & 0 deletions src/utils/openai-utils.ts
@@ -11,6 +11,21 @@ export const newOpenAI = () => {
});
};

+export async function convertBase64ToImage(base64Data: string, filename = "image.png") {
+  const base64 = base64Data.split(",")[1];
+
+  const binary = atob(base64);
+  const binaryLength = binary.length;
+  const arrayBuffer = new Uint8Array(binaryLength);
+  for (let i = 0; i < binaryLength; i++) {
+    arrayBuffer[i] = binary.charCodeAt(i);
+  }
+
+  const blob = new Blob([arrayBuffer], { type: "image/png" });
+  const file = new File([blob], filename, { type: "image/png" });
+  return file;
+}
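A quick usage sketch for the helper above; the data URI is a throwaway placeholder rather than real CODAP output, so it only exercises the mechanics (split, atob, Blob, File):

// Demo bytes only, not a real PNG, but enough to exercise the conversion.
const sampleDataUri = "data:image/png;base64," + btoa("demo bytes, not a real png");

convertBase64ToImage(sampleDataUri, "graph.png").then((file) => {
  console.log(file.name, file.type, file.size); // "graph.png", "image/png", 26
});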

export const openAiTools: AssistantTool[] = [
{
type: "function",
@@ -42,6 +57,31 @@ export const openAiTools: AssistantTool[] = [
}
}
},
+  {
+    type: "function",
+    function: {
+      name: "convert_base64_to_image",
+      description: "Convert a base64 image to a file object",
+      strict: false,
+      parameters: {
+        type: "object",
+        properties: {
+          base64Data: {
+            type: "string",
+            description: "The base64 image data"
+          },
+          filename: {
+            type: "string",
+            description: "The filename to use for the image"
+          }
+        },
+        additionalProperties: false,
+        required: [
+          "base64Data"
+        ]
+      }
+    }
+  }
];
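These function definitions only take effect once they are registered on the assistant; in this repo that presumably happens via getTools/initLlmConnection, which is not shown in this diff. A minimal sketch of attaching them directly with the OpenAI v4 client, using a placeholder assistant id:

import OpenAI from "openai";
import { openAiTools } from "./openai-utils";

// Attach the tool definitions to an existing assistant so it can emit matching
// tool calls during runs. The default id below is a placeholder.
export async function registerTools(client: OpenAI, assistantId = "asst_XXXXXXXX") {
  return client.beta.assistants.update(assistantId, { tools: openAiTools });
}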

export const requestThreadDeletion = async (threadId: string): Promise<Response> => {
