Skip to content

Commit

Permalink
refactor: use tool() and organize by file (#721)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcus Schiesser <[email protected]>
  • Loading branch information
jeremyphilemon and marcusschiesser authored Jan 22, 2025
1 parent 371242d commit 3f9d379
Show file tree
Hide file tree
Showing 5 changed files with 412 additions and 345 deletions.
359 changes: 14 additions & 345 deletions app/(chat)/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,30 @@ import {
type Message,
convertToCoreMessages,
createDataStreamResponse,
experimental_generateImage,
streamObject,
streamText,
} from 'ai';
import { z } from 'zod';

import { auth } from '@/app/(auth)/auth';
import { customModel, imageGenerationModel } from '@/lib/ai';
import { customModel } from '@/lib/ai';
import { models } from '@/lib/ai/models';
import {
codePrompt,
systemPrompt,
updateDocumentPrompt,
} from '@/lib/ai/prompts';
import { systemPrompt } from '@/lib/ai/prompts';
import {
deleteChatById,
getChatById,
getDocumentById,
saveChat,
saveDocument,
saveMessages,
saveSuggestions,
} from '@/lib/db/queries';
import type { Suggestion } from '@/lib/db/schema';
import {
generateUUID,
getMostRecentUserMessage,
sanitizeResponseMessages,
} from '@/lib/utils';

import { generateTitleFromUserMessage } from '../../actions';
import { createDocument } from '@/lib/ai/tools/create-document';
import { updateDocument } from '@/lib/ai/tools/update-document';
import { requestSuggestions } from '@/lib/ai/tools/request-suggestions';
import { getWeather } from '@/lib/ai/tools/get-weather';

export const maxDuration = 60;

Expand All @@ -49,7 +42,6 @@ const blocksTools: AllowedTools[] = [
];

const weatherTools: AllowedTools[] = ['getWeather'];

const allTools: AllowedTools[] = [...blocksTools, ...weatherTools];

export async function POST(request: Request) {
Expand Down Expand Up @@ -108,337 +100,14 @@ export async function POST(request: Request) {
maxSteps: 5,
experimental_activeTools: allTools,
tools: {
getWeather: {
description: 'Get the current weather at a location',
parameters: z.object({
latitude: z.number(),
longitude: z.number(),
}),
execute: async ({ latitude, longitude }) => {
const response = await fetch(
`https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}&current=temperature_2m&hourly=temperature_2m&daily=sunrise,sunset&timezone=auto`,
);

const weatherData = await response.json();
return weatherData;
},
},
createDocument: {
description:
'Create a document for a writing or content creation activities like image generation. This tool will call other functions that will generate the contents of the document based on the title and kind.',
parameters: z.object({
title: z.string(),
kind: z.enum(['text', 'code', 'image']),
}),
execute: async ({ title, kind }) => {
const id = generateUUID();
let draftText = '';

dataStream.writeData({
type: 'id',
content: id,
});

dataStream.writeData({
type: 'title',
content: title,
});

dataStream.writeData({
type: 'kind',
content: kind,
});

dataStream.writeData({
type: 'clear',
content: '',
});

if (kind === 'text') {
const { fullStream } = streamText({
model: customModel(model.apiIdentifier),
system:
'Write about the given topic. Markdown is supported. Use headings wherever appropriate.',
prompt: title,
});

for await (const delta of fullStream) {
const { type } = delta;

if (type === 'text-delta') {
const { textDelta } = delta;

draftText += textDelta;
dataStream.writeData({
type: 'text-delta',
content: textDelta,
});
}
}

dataStream.writeData({ type: 'finish', content: '' });
} else if (kind === 'code') {
const { fullStream } = streamObject({
model: customModel(model.apiIdentifier),
system: codePrompt,
prompt: title,
schema: z.object({
code: z.string(),
}),
});

for await (const delta of fullStream) {
const { type } = delta;

if (type === 'object') {
const { object } = delta;
const { code } = object;

if (code) {
dataStream.writeData({
type: 'code-delta',
content: code ?? '',
});

draftText = code;
}
}
}

dataStream.writeData({ type: 'finish', content: '' });
} else if (kind === 'image') {
const { image } = await experimental_generateImage({
model: imageGenerationModel,
prompt: title,
n: 1,
});

draftText = image.base64;

dataStream.writeData({
type: 'image-delta',
content: image.base64,
});

dataStream.writeData({ type: 'finish', content: '' });
}

if (session.user?.id) {
await saveDocument({
id,
title,
kind,
content: draftText,
userId: session.user.id,
});
}

return {
id,
title,
kind,
content:
'A document was created and is now visible to the user.',
};
},
},
updateDocument: {
description: 'Update a document with the given description.',
parameters: z.object({
id: z.string().describe('The ID of the document to update'),
description: z
.string()
.describe('The description of changes that need to be made'),
}),
execute: async ({ id, description }) => {
const document = await getDocumentById({ id });

if (!document) {
return {
error: 'Document not found',
};
}

const { content: currentContent } = document;
let draftText = '';

dataStream.writeData({
type: 'clear',
content: document.title,
});

if (document.kind === 'text') {
const { fullStream } = streamText({
model: customModel(model.apiIdentifier),
system: updateDocumentPrompt(currentContent, 'text'),
prompt: description,
experimental_providerMetadata: {
openai: {
prediction: {
type: 'content',
content: currentContent,
},
},
},
});

for await (const delta of fullStream) {
const { type } = delta;

if (type === 'text-delta') {
const { textDelta } = delta;

draftText += textDelta;
dataStream.writeData({
type: 'text-delta',
content: textDelta,
});
}
}

dataStream.writeData({ type: 'finish', content: '' });
} else if (document.kind === 'code') {
const { fullStream } = streamObject({
model: customModel(model.apiIdentifier),
system: updateDocumentPrompt(currentContent, 'code'),
prompt: description,
schema: z.object({
code: z.string(),
}),
});

for await (const delta of fullStream) {
const { type } = delta;

if (type === 'object') {
const { object } = delta;
const { code } = object;

if (code) {
dataStream.writeData({
type: 'code-delta',
content: code ?? '',
});

draftText = code;
}
}
}

dataStream.writeData({ type: 'finish', content: '' });
} else if (document.kind === 'image') {
const { image } = await experimental_generateImage({
model: imageGenerationModel,
prompt: description,
n: 1,
});

draftText = image.base64;

dataStream.writeData({
type: 'image-delta',
content: image.base64,
});

dataStream.writeData({ type: 'finish', content: '' });
}

if (session.user?.id) {
await saveDocument({
id,
title: document.title,
content: draftText,
kind: document.kind,
userId: session.user.id,
});
}

return {
id,
title: document.title,
kind: document.kind,
content: 'The document has been updated successfully.',
};
},
},
requestSuggestions: {
description: 'Request suggestions for a document',
parameters: z.object({
documentId: z
.string()
.describe('The ID of the document to request edits'),
}),
execute: async ({ documentId }) => {
const document = await getDocumentById({ id: documentId });

if (!document || !document.content) {
return {
error: 'Document not found',
};
}

const suggestions: Array<
Omit<Suggestion, 'userId' | 'createdAt' | 'documentCreatedAt'>
> = [];

const { elementStream } = streamObject({
model: customModel(model.apiIdentifier),
system:
'You are a help writing assistant. Given a piece of writing, please offer suggestions to improve the piece of writing and describe the change. It is very important for the edits to contain full sentences instead of just words. Max 5 suggestions.',
prompt: document.content,
output: 'array',
schema: z.object({
originalSentence: z
.string()
.describe('The original sentence'),
suggestedSentence: z
.string()
.describe('The suggested sentence'),
description: z
.string()
.describe('The description of the suggestion'),
}),
});

for await (const element of elementStream) {
const suggestion = {
originalText: element.originalSentence,
suggestedText: element.suggestedSentence,
description: element.description,
id: generateUUID(),
documentId: documentId,
isResolved: false,
};

dataStream.writeData({
type: 'suggestion',
content: suggestion,
});

suggestions.push(suggestion);
}

if (session.user?.id) {
const userId = session.user.id;

await saveSuggestions({
suggestions: suggestions.map((suggestion) => ({
...suggestion,
userId,
createdAt: new Date(),
documentCreatedAt: document.createdAt,
})),
});
}

return {
id: documentId,
title: document.title,
kind: document.kind,
message: 'Suggestions have been added to the document',
};
},
},
getWeather,
createDocument: createDocument({ session, dataStream, model }),
updateDocument: updateDocument({ session, dataStream, model }),
requestSuggestions: requestSuggestions({
session,
dataStream,
model,
}),
},
onFinish: async ({ response }) => {
if (session.user?.id) {
Expand Down
Loading

0 comments on commit 3f9d379

Please sign in to comment.