From 419418d44a80e5946d2aacdadf2e04c0ec320216 Mon Sep 17 00:00:00 2001 From: Jacob Cable Date: Thu, 12 Dec 2024 12:36:35 +0000 Subject: [PATCH 1/5] feat(js/plugins/ollama): add support for structured responses --- docs/plugins/ollama.md | 56 ++++ js/plugins/ollama/package.json | 5 +- js/plugins/ollama/src/index.ts | 114 ++++++-- js/plugins/ollama/src/types.ts | 16 +- js/plugins/ollama/tests/embeddings_test.ts | 55 ++-- .../ollama/tests/streaming_live_test.ts | 162 +++++++++++ .../tests/structured_response_live_test.ts | 129 +++++++++ .../ollama/tests/structured_response_test.ts | 258 ++++++++++++++++++ js/pnpm-lock.yaml | 13 +- 9 files changed, 750 insertions(+), 58 deletions(-) create mode 100644 js/plugins/ollama/tests/streaming_live_test.ts create mode 100644 js/plugins/ollama/tests/structured_response_live_test.ts create mode 100644 js/plugins/ollama/tests/structured_response_test.ts diff --git a/docs/plugins/ollama.md b/docs/plugins/ollama.md index 5b6ecd385..acee53bd4 100644 --- a/docs/plugins/ollama.md +++ b/docs/plugins/ollama.md @@ -151,3 +151,59 @@ async function getEmbedding() { getEmbedding().then((e) => console.log(e)) ``` + +## Structured Output +The Ollama plugin supports structured output through JSON schema validation. This allows you to specify the exact structure of the response you expect from the model. +### Basic Usage +Here's a simple example using Zod to define a schema and get structured output: +```ts +import { genkit, z } from 'genkit'; +import { ollama } from 'genkitx-ollama'; + +// Define your schema using Zod +const CountrySchema = z.object({ + name: z.string(), + capital: z.string(), + languages: z.array(z.string()), +}); + +// Initialize Genkit with Ollama plugin +const ai = genkit({ + plugins: [ + ollama({ + models: [{ name: 'your-model', type: 'chat' }], + serverAddress: 'http://localhost:11434', + }), + ], +}); + +// ... +// ... 
+ +// Use structured output in your request +const llmResponse = await ai.generate({ + model: 'ollama/your-model', + messages: [ + { + role: 'system', + content: [ + { + text: 'You are a helpful assistant that provides information about countries in a structured format.', + }, + ], + }, + { + role: 'user', + content: [{ text: 'Tell me about Canada.' }], + }, + ], + output: { + format: 'json', + schema: CountrySchema, + }, +}); + +// Now output should be typed correctly: +const {name, capital, languages} = llmResponse.output; + +``` \ No newline at end of file diff --git a/js/plugins/ollama/package.json b/js/plugins/ollama/package.json index 1c00e7d53..290f6f92a 100644 --- a/js/plugins/ollama/package.json +++ b/js/plugins/ollama/package.json @@ -34,7 +34,7 @@ "devDependencies": { "@types/node": "^20.11.16", "npm-run-all": "^4.1.5", - "ollama": "^0.5.9", + "ollama": "^0.5.11", "rimraf": "^6.0.1", "tsup": "^8.3.5", "tsx": "^4.19.2", @@ -48,5 +48,8 @@ "types": "./lib/index.d.ts", "default": "./lib/index.js" } + }, + "dependencies": { + "zod-to-json-schema": "^3.22.4" } } diff --git a/js/plugins/ollama/src/index.ts b/js/plugins/ollama/src/index.ts index 1a934a700..124f947d3 100644 --- a/js/plugins/ollama/src/index.ts +++ b/js/plugins/ollama/src/index.ts @@ -14,19 +14,24 @@ * limitations under the License.
*/ -import { Genkit } from 'genkit'; +import { Genkit, GenkitError } from 'genkit'; +import { extractJson } from 'genkit/extract'; import { logger } from 'genkit/logging'; import { GenerateRequest, GenerateResponseData, GenerationCommonConfigSchema, - getBasicUsageStats, MessageData, + getBasicUsageStats, } from 'genkit/model'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; +import { ErrorResponse } from 'ollama'; import { defineOllamaEmbedder } from './embeddings.js'; import { ApiType, + ChatResponse, + GenerateResponse, + Message, ModelDefinition, OllamaPluginParams, RequestHeaders, @@ -143,8 +148,10 @@ function ollamaModel( let textResponse = ''; for await (const chunk of readChunks(reader)) { const chunkText = textDecoder.decode(chunk); - const json = JSON.parse(chunkText); - const message = parseMessage(json, type); + const json = extractJson(chunkText) as + | ChatResponse + | GenerateResponse; + const message = parseMessage(json, type, input); streamingCallback({ index: 0, content: message.content, @@ -161,10 +168,8 @@ function ollamaModel( }; } else { const txtBody = await res.text(); - const json = JSON.parse(txtBody); - logger.debug(txtBody, 'ollama raw response'); - - message = parseMessage(json, type); + const json = extractJson(txtBody) as ChatResponse | GenerateResponse; + message = parseMessage(json, type, input); } return { @@ -176,29 +181,82 @@ function ollamaModel( ); } -function parseMessage(response: any, type: ApiType): MessageData { - if (response.error) { +function parseMessage( + response: ChatResponse | GenerateResponse, + type: ApiType, + input: GenerateRequest +): MessageData { + function isErrorResponse(resp: any): resp is ErrorResponse { + return 'error' in resp && typeof resp.error === 'string'; + } + + if (isErrorResponse(response)) { throw new Error(response.error); } - if (type === 'chat') { + + function isChatResponse(resp: any): resp is ChatResponse { + return 'message' in resp; + } + + function 
isGenerateResponse(resp: any): resp is GenerateResponse { + return 'response' in resp; + } + + // Handle JSON format if requested + if (input.output?.format === 'json' && input.output.schema) { + let rawContent; + if (isChatResponse(response)) { + try { + // Parse the content string into an object + const parsedContent = extractJson(response.message.content); + // Validate against the schema + rawContent = parsedContent; + } catch (e) { + throw new GenkitError({ + message: 'Failed to parse structured response from Ollama model', + status: 'FAILED_PRECONDITION', + }); + } + } else if (isGenerateResponse(response)) { + try { + const parsedContent = extractJson(response.response); + rawContent = parsedContent; + } catch (e) { + throw new GenkitError({ + message: 'Failed to parse structured response from Ollama model', + status: 'FAILED_PRECONDITION', + }); + } + } else { + throw new Error('Invalid response format'); + } + + return { + role: + type === 'chat' && isChatResponse(response) + ? 
toGenkitRole(response.message.role) + : 'model', + content: [{ text: JSON.stringify(rawContent) }], + }; + } + + // Handle regular output + if (isChatResponse(response)) { return { role: toGenkitRole(response.message.role), - content: [ - { - text: response.message.content, - }, - ], + content: [{ text: response.message.content }], }; - } else { + } else if (isGenerateResponse(response)) { return { role: 'model', - content: [ - { - text: response.response, - }, - ], + content: [{ text: response.response }], }; } + + throw new GenkitError({ + message: 'Invalid response format from Ollama model', + status: 'FAILED_PRECONDITION', + }); } function toOllamaRequest( @@ -213,6 +271,12 @@ function toOllamaRequest( options, stream, }; + + // Add format and schema if specified in output + if (input.output?.format === 'json' && input.output.schema) { + request.format = input.output.schema; + } + if (type === 'chat') { const messages: Message[] = []; input.messages.forEach((m) => { @@ -279,9 +343,3 @@ function getSystemMessage(input: GenerateRequest): string { .map((m) => m.content.map((c) => c.text).join()) .join(); } - -interface Message { - role: string; - content: string; - images?: string[]; -} diff --git a/js/plugins/ollama/src/types.ts b/js/plugins/ollama/src/types.ts index 0a3cd002f..dd074a786 100644 --- a/js/plugins/ollama/src/types.ts +++ b/js/plugins/ollama/src/types.ts @@ -15,11 +15,17 @@ */ import { GenerateRequest, z } from 'genkit'; -import { EmbedRequest } from 'ollama'; +import { + ChatResponse, + EmbedRequest, + GenerateResponse, + Message as OllamaMessage, +} from 'ollama'; + // Define possible API types export type ApiType = 'chat' | 'generate'; -// Standard model definition +// Standard model definition - removed format export interface ModelDefinition { name: string; type?: ApiType; @@ -86,4 +92,8 @@ export interface RequestHeaderFunction { // Union type for request headers, supporting both static and dynamic options export type RequestHeaders = Record | 
RequestHeaderFunction; -export type OllamaRole = 'assistant' | 'tool' | 'system' | 'user'; +// Use Ollama's Message type +export type { OllamaMessage as Message }; + +// Export response types from Ollama +export type { ChatResponse, GenerateResponse }; diff --git a/js/plugins/ollama/tests/embeddings_test.ts b/js/plugins/ollama/tests/embeddings_test.ts index 7a9b98d6d..861d1ec00 100644 --- a/js/plugins/ollama/tests/embeddings_test.ts +++ b/js/plugins/ollama/tests/embeddings_test.ts @@ -15,31 +15,12 @@ */ import { Genkit, genkit } from 'genkit'; import assert from 'node:assert'; -import { beforeEach, describe, it } from 'node:test'; +import { after, before, beforeEach, describe, it } from 'node:test'; import { defineOllamaEmbedder } from '../src/embeddings.js'; import { ollama } from '../src/index.js'; import { OllamaPluginParams } from '../src/types.js'; -// Mock fetch to simulate API responses -global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { - const url = typeof input === 'string' ? input : input.toString(); - if (url.includes('/api/embed')) { - if (options?.body && JSON.stringify(options.body).includes('fail')) { - return { - ok: false, - statusText: 'Internal Server Error', - json: async () => ({}), - } as Response; - } - return { - ok: true, - json: async () => ({ - embeddings: [[0.1, 0.2, 0.3]], // Example embedding values - }), - } as Response; - } - throw new Error('Unknown API endpoint'); -}; +let originalFetch: typeof fetch; describe('defineOllamaEmbedder', () => { const options: OllamaPluginParams = { @@ -48,6 +29,38 @@ describe('defineOllamaEmbedder', () => { }; let ai: Genkit; + + before(() => { + // Store original fetch + originalFetch = global.fetch; + + // Mock fetch to simulate API responses + global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { + const url = typeof input === 'string' ? 
input : input.toString(); + if (url.includes('/api/embed')) { + if (options?.body && JSON.stringify(options.body).includes('fail')) { + return { + ok: false, + statusText: 'Internal Server Error', + json: async () => ({}), + } as Response; + } + return { + ok: true, + json: async () => ({ + embeddings: [[0.1, 0.2, 0.3]], // Example embedding values + }), + } as Response; + } + throw new Error('Unknown API endpoint'); + }; + }); + + after(() => { + // Restore original fetch + global.fetch = originalFetch; + }); + beforeEach(() => { ai = genkit({ plugins: [ diff --git a/js/plugins/ollama/tests/streaming_live_test.ts b/js/plugins/ollama/tests/streaming_live_test.ts new file mode 100644 index 000000000..142ea0105 --- /dev/null +++ b/js/plugins/ollama/tests/streaming_live_test.ts @@ -0,0 +1,162 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import { genkit } from 'genkit'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { ollama } from '../src/index.js'; +import { OllamaPluginParams } from '../src/types.js'; + +// Utility function to parse command-line arguments +function parseArgs() { + const args = process.argv.slice(2); + const serverAddress = + args.find((arg) => arg.startsWith('--server-address='))?.split('=')[1] || + 'http://localhost:11434'; + const modelName = + args.find((arg) => arg.startsWith('--model-name='))?.split('=')[1] || + 'phi3.5:latest'; + return { serverAddress, modelName }; +} + +const { serverAddress, modelName } = parseArgs(); + +describe('Ollama Streaming - Live Tests', () => { + const options: OllamaPluginParams = { + serverAddress, + models: [{ name: modelName, type: 'chat' }], + }; + + it('should stream responses in chat mode', async () => { + const ai = genkit({ + plugins: [ollama(options)], + }); + + const streamedChunks: string[] = []; + let streamingCallCount = 0; + + const response = await ai.generate({ + model: `ollama/${modelName}`, + messages: [ + { + role: 'user', + content: [{ text: 'Count from 1 to 5 slowly.' 
}], + }, + ], + streamingCallback: (chunk) => { + streamingCallCount++; + if (chunk.content[0].text) { + streamedChunks.push(chunk.content[0].text); + } + }, + }); + + // Verify that streaming occurred + assert.ok( + streamingCallCount > 1, + 'Streaming callback should be called multiple times' + ); + assert.ok(streamedChunks.length > 1, 'Should receive multiple chunks'); + + // Verify final response matches the accumulated streamed content + const finalText = response.message?.content[0].text; + const streamedText = streamedChunks.join(''); + assert.strictEqual( + finalText, + streamedText, + 'Final text should match accumulated streamed content' + ); + }); + + it('should stream responses with generate mode', async () => { + const ai = genkit({ + plugins: [ollama(options)], + }); + + const streamedChunks: string[] = []; + let streamingCallCount = 0; + + const response = await ai.generate({ + model: `ollama/${modelName}`, + messages: [ + { + role: 'user', + content: [{ text: 'Write a short story about a cat.' 
}], + }, + ], + streamingCallback: (chunk) => { + streamingCallCount++; + if (chunk.content[0].text) { + streamedChunks.push(chunk.content[0].text); + } + }, + }); + + // Verify that streaming occurred + assert.ok( + streamingCallCount > 1, + 'Streaming callback should be called multiple times' + ); + assert.ok(streamedChunks.length > 1, 'Should receive multiple chunks'); + + // Verify final response matches the accumulated streamed content + const finalText = response.message?.content[0].text; + const streamedText = streamedChunks.join(''); + assert.strictEqual( + finalText, + streamedText, + 'Final text should match accumulated streamed content' + ); + }); + + it('should handle errors during streaming', async () => { + const ai = genkit({ + plugins: [ollama(options)], + }); + + const streamedChunks: string[] = []; + + await assert.rejects( + async () => { + await ai.generate({ + model: `ollama/nonexistent-model`, + messages: [ + { + role: 'user', + content: [{ text: 'This should fail.' }], + }, + ], + streamingCallback: (chunk) => { + if (chunk.content[0].text) { + streamedChunks.push(chunk.content[0].text); + } + }, + }); + }, + (error) => { + assert(error instanceof Error); + // Check if error message indicates model not found or similar + return true; + } + ); + + // Verify no content was streamed before error + assert.strictEqual( + streamedChunks.length, + 0, + 'No content should be streamed on error' + ); + }); +}); diff --git a/js/plugins/ollama/tests/structured_response_live_test.ts b/js/plugins/ollama/tests/structured_response_live_test.ts new file mode 100644 index 000000000..2c4ddab60 --- /dev/null +++ b/js/plugins/ollama/tests/structured_response_live_test.ts @@ -0,0 +1,129 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { genkit, z } from 'genkit'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { ollama } from '../src/index.js'; +import { OllamaPluginParams } from '../src/types.js'; + +// Utility function to parse command-line arguments +function parseArgs() { + const args = process.argv.slice(2); + const serverAddress = + args.find((arg) => arg.startsWith('--server-address='))?.split('=')[1] || + 'http://localhost:11434'; + const modelName = + args.find((arg) => arg.startsWith('--model-name='))?.split('=')[1] || + 'phi3.5:latest'; + return { serverAddress, modelName }; +} + +const { serverAddress, modelName } = parseArgs(); + +// Define a schema for testing +const CountrySchema = z.object({ + name: z.string(), + capital: z.string(), + languages: z.array(z.string()), +}); + +describe('Ollama Structured Output - Live Tests', () => { + const options: OllamaPluginParams = { + serverAddress, + models: [{ name: modelName, type: 'chat' }], + }; + + it('should handle structured output in chat mode', async () => { + const ai = genkit({ + plugins: [ollama(options)], + }); + + const response = await ai.generate({ + model: `ollama/${modelName}`, + messages: [ + { + role: 'system', + content: [ + { + text: 'You are a helpful assistant that provides information about countries in a structured format.', + }, + ], + }, + { + role: 'user', + content: [{ text: 'Tell me about Canada.' 
}], + }, + ], + output: { + format: 'json', + schema: CountrySchema, + }, + }); + + assert.notEqual(response.message, undefined); + assert.notEqual(response.message!.content, undefined); + assert.notEqual(response.message!.content[0].text, undefined); + + const content = response.output!; + + // Validate the structure + assert(typeof content.name === 'string'); + assert(typeof content.capital === 'string'); + assert(Array.isArray(content.languages)); + content.languages.forEach((lang) => { + assert(typeof lang === 'string'); + }); + + // Log the actual response for inspection + console.log('Structured Response:', content); + }); + + it('should handle multiple requests with different schemas', async () => { + const ai = genkit({ + plugins: [ollama(options)], + }); + + const PersonSchema = z.object({ + name: z.string(), + age: z.number(), + occupation: z.string(), + }); + + const response = await ai.generate({ + model: `ollama/${modelName}`, + messages: [ + { + role: 'user', + content: [{ text: 'Create a profile for a fictional person.' }], + }, + ], + output: { + format: 'json', + schema: PersonSchema, + }, + }); + + assert.notEqual(response.message, undefined); + const content = response.output!; + + // Validate the structure + assert(typeof content.name === 'string'); + assert(typeof content.age === 'number'); + assert(typeof content.occupation === 'string'); + + console.log('Person Response:', content); + }); +}); diff --git a/js/plugins/ollama/tests/structured_response_test.ts b/js/plugins/ollama/tests/structured_response_test.ts new file mode 100644 index 000000000..828e14581 --- /dev/null +++ b/js/plugins/ollama/tests/structured_response_test.ts @@ -0,0 +1,258 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { Genkit, genkit, z } from 'genkit'; +import assert from 'node:assert'; +import { after, before, beforeEach, describe, it } from 'node:test'; + +import { ollama } from '../src/index.js'; + +// Define a sample schema for testing +const CountrySchema = z.object({ + name: z.string(), + capital: z.string(), + languages: z.array(z.string()), +}); + +// Mock response data +const sampleCountryData = { + name: 'Canada', + capital: 'Ottawa', + languages: ['English', 'French'], +}; + +describe('Ollama Structured Output', () => { + let ai: Genkit; + let originalFetch: typeof fetch; + + before(() => { + // Store the original fetch + originalFetch = global.fetch; + + // Set up mock fetch + global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { + const url = typeof input === 'string' ? input : input.toString(); + const requestBody = options?.body + ? 
JSON.parse(options.body as string) + : {}; + + const mockResponse = (data: any) => { + return { + ok: true, + body: { + getReader: () => ({ + read: async () => ({ done: true, value: undefined }), + }), + }, + text: async () => JSON.stringify(data), + } as Response; + }; + + if (url.includes('/api/chat')) { + // Check if format is specified in the request + if (requestBody.format) { + return mockResponse({ + model: 'llama2', + created_at: '2024-03-14T12:00:00Z', + message: { + role: 'assistant', + content: JSON.stringify(sampleCountryData), + }, + done: true, + }); + } + // Regular chat response + return mockResponse({ + model: 'llama2', + created_at: '2024-03-14T12:00:00Z', + message: { + role: 'assistant', + content: 'Some regular text response', + }, + done: true, + }); + } + + if (url.includes('/api/generate')) { + if (requestBody.format) { + return mockResponse({ + model: 'llama2', + created_at: '2024-03-14T12:00:00Z', + response: JSON.stringify(sampleCountryData), + done: true, + }); + } + return mockResponse({ + model: 'llama2', + created_at: '2024-03-14T12:00:00Z', + response: 'Some regular text response', + done: true, + }); + } + + throw new Error('Unknown API endpoint'); + }; + }); + + after(() => { + // Restore the original fetch + global.fetch = originalFetch; + }); + + beforeEach(() => { + ai = genkit({ + plugins: [ + ollama({ + serverAddress: 'http://localhost:11434', + models: [ + { + name: 'chat-model', + type: 'chat', + }, + { + name: 'generate-model', + type: 'generate', + }, + ], + }), + ], + }); + }); + + it('should handle structured output in chat mode', async () => { + const response = await ai.generate({ + model: 'ollama/chat-model', + messages: [ + { + role: 'user', + content: [{ text: 'Tell me about Canada' }], + }, + ], + output: { + format: 'json', + schema: CountrySchema, + }, + }); + + assert.notEqual(response.message, undefined); + assert.notEqual(response.message!.content, undefined); + 
assert.notEqual(response.message!.content[0].text, undefined); + + const content = JSON.parse(response.message!.content[0].text!); + assert.deepStrictEqual(content, sampleCountryData); + }); + + it('should handle structured output in generate mode', async () => { + const response = await ai.generate({ + model: 'ollama/generate-model', + messages: [ + { + role: 'user', + content: [{ text: 'Tell me about Canada' }], + }, + ], + output: { + format: 'json', + schema: CountrySchema, + }, + }); + + assert.notEqual(response.message, undefined); + assert.notEqual(response.message!.content, undefined); + assert.notEqual(response.message!.content[0].text, undefined); + + const content = JSON.parse(response.message!.content[0].text!); + assert.deepStrictEqual(content, sampleCountryData); + }); + + it('should handle schema validation errors', async () => { + // Override fetch for this specific test + global.fetch = async () => + ({ + ok: true, + body: { + getReader: () => ({ + read: async () => ({ done: true, value: undefined }), + }), + }, + text: async () => + JSON.stringify({ + model: 'llama2', + created_at: '2024-03-14T12:00:00Z', + message: { + role: 'assistant', + content: JSON.stringify({ invalid: 'data' }), + }, + done: true, + }), + }) as Response; + + await assert.rejects( + async () => { + await ai.generate({ + model: 'ollama/chat-model', + messages: [ + { + role: 'user', + content: [{ text: 'Tell me about Canada' }], + }, + ], + output: { + format: 'json', + schema: CountrySchema, + }, + }); + }, + (error) => { + assert(error instanceof Error); + assert(error.message.includes('Required')); + return true; + } + ); + }); + + it('should handle API errors gracefully', async () => { + // Override fetch for this specific test + global.fetch = async () => + ({ + ok: false, + statusText: 'Internal Server Error', + text: async () => 'Internal Server Error', + }) as Response; + + await assert.rejects( + async () => { + await ai.generate({ + model: 'ollama/chat-model', + 
messages: [ + { + role: 'user', + content: [{ text: 'Tell me about Canada' }], + }, + ], + output: { + format: 'json', + schema: CountrySchema, + }, + }); + }, + (error) => { + assert(error instanceof Error); + assert.strictEqual(error.message, 'Response has no body'); + return true; + } + ); + }); +}); diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index c1711a784..ccf3d7285 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -683,6 +683,9 @@ importers: genkit: specifier: workspace:* version: link:../../genkit + zod-to-json-schema: + specifier: ^3.22.4 + version: 3.22.5(zod@3.23.8) devDependencies: '@types/node': specifier: ^20.11.16 @@ -691,8 +694,8 @@ importers: specifier: ^4.1.5 version: 4.1.5 ollama: - specifier: ^0.5.9 - version: 0.5.9 + specifier: ^0.5.11 + version: 0.5.11 rimraf: specifier: ^6.0.1 version: 6.0.1 @@ -5413,8 +5416,8 @@ packages: resolution: {integrity: sha512-byy+U7gp+FVwmyzKPYhW2h5l3crpmGsxl7X2s8y43IgxvG4g3QZ6CffDtsNQy1WsmZpQbO+ybo0AlW7TY6DcBQ==} engines: {node: '>= 0.4'} - ollama@0.5.9: - resolution: {integrity: sha512-F/KZuDRC+ZsVCuMvcOYuQ6zj42/idzCkkuknGyyGVmNStMZ/sU3jQpvhnl4SyC0+zBzLiKNZJnJeuPFuieWZvQ==} + ollama@0.5.11: + resolution: {integrity: sha512-lDAKcpmBU3VAOGF05NcQipHNKTdpKfAHpZ7bjCsElkUkmX7SNZImi6lwIxz/l1zQtLq0S3wuLneRuiXxX2KIew==} on-finished@2.4.1: resolution: {integrity: sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==} @@ -10887,7 +10890,7 @@ snapshots: has-symbols: 1.0.3 object-keys: 1.1.1 - ollama@0.5.9: + ollama@0.5.11: dependencies: whatwg-fetch: 3.6.20 From 67edacfb5ea7cd7922e1cdbad55dd2fe031e71dd Mon Sep 17 00:00:00 2001 From: Jacob Cable Date: Thu, 12 Dec 2024 12:45:28 +0000 Subject: [PATCH 2/5] chore(js/plugins/ollama): add JSDoc comments --- js/plugins/ollama/src/embeddings.ts | 16 ++++++++++ js/plugins/ollama/src/index.ts | 36 +++++++++++++++++++--- js/plugins/ollama/src/types.ts | 48 ++++++++++++++++++++++++++--- 3 files changed, 90 insertions(+), 10 
deletions(-) diff --git a/js/plugins/ollama/src/embeddings.ts b/js/plugins/ollama/src/embeddings.ts index a559c62f9..44e3ff489 100644 --- a/js/plugins/ollama/src/embeddings.ts +++ b/js/plugins/ollama/src/embeddings.ts @@ -17,6 +17,16 @@ import { Document, Genkit } from 'genkit'; import { EmbedRequest, EmbedResponse } from 'ollama'; import { DefineOllamaEmbeddingParams, RequestHeaders } from './types.js'; +/** + * Constructs an Ollama embedding request from the provided parameters. + * @param {string} modelName - The name of the Ollama model to use + * @param {number} dimensions - The number of dimensions for the embeddings + * @param {Document[]} documents - The documents to embed + * @param {string} serverAddress - The Ollama server address + * @param {RequestHeaders} [requestHeaders] - Optional headers to include with the request + * @returns {Promise<{url: string, requestPayload: EmbedRequest, headers: Record}>} The prepared request + * @private + */ async function toOllamaEmbedRequest( modelName: string, dimensions: number, @@ -59,6 +69,12 @@ async function toOllamaEmbedRequest( }; } +/** + * Defines and registers an Ollama embedder in the Genkit environment. + * @param {Genkit} ai - The Genkit instance + * @param {DefineOllamaEmbeddingParams} params - Configuration for the embedder + * @returns {Embedder} The defined Genkit embedder + */ export function defineOllamaEmbedder( ai: Genkit, { name, modelName, dimensions, options }: DefineOllamaEmbeddingParams diff --git a/js/plugins/ollama/src/index.ts b/js/plugins/ollama/src/index.ts index 124f947d3..d979bf69b 100644 --- a/js/plugins/ollama/src/index.ts +++ b/js/plugins/ollama/src/index.ts @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -import { Genkit, GenkitError } from 'genkit'; +import { Genkit, GenkitError, Role } from 'genkit'; import { extractJson } from 'genkit/extract'; import { logger } from 'genkit/logging'; import { @@ -39,6 +39,11 @@ import { export { defineOllamaEmbedder }; +/** + * Creates and registers a Genkit plugin for Ollama integration. + * @param {OllamaPluginParams} params - Configuration options for the Ollama plugin + * @returns {GenkitPlugin} A configured Genkit plugin for Ollama + */ export function ollama(params: OllamaPluginParams): GenkitPlugin { return genkitPlugin('ollama', async (ai: Genkit) => { const serverAddress = params.serverAddress; @@ -56,6 +61,15 @@ export function ollama(params: OllamaPluginParams): GenkitPlugin { }); } +/** + * Defines a new Ollama model in the Genkit registry. + * @param {Genkit} ai - The Genkit instance + * @param {ModelDefinition} model - The model configuration + * @param {string} serverAddress - The Ollama server address + * @param {RequestHeaders} [requestHeaders] - Optional headers to include with requests + * @returns {Model} The defined Genkit model + * @private + */ function ollamaModel( ai: Genkit, model: ModelDefinition, @@ -304,18 +318,30 @@ function toOllamaRequest( return request; } -function toOllamaRole(role) { +/** + * Converts a Genkit role to the corresponding Ollama role. + * @param {string} role - The Genkit role to convert + * @returns {string} The corresponding Ollama role + * @private + */ +function toOllamaRole(role: string) { if (role === 'model') { return 'assistant'; } return role; // everything else seems to match } -function toGenkitRole(role) { +/** + * Converts an Ollama role to the corresponding Genkit role. 
+ * @param {string} role - The Ollama role to convert + * @returns {Role} The corresponding Genkit role + * @private + */ +function toGenkitRole(role: string): Role { if (role === 'assistant') { - return 'model'; + return 'model' as Role; } - return role; // everything else seems to match + return role as Role; // everything else seems to match } function readChunks(reader) { diff --git a/js/plugins/ollama/src/types.ts b/js/plugins/ollama/src/types.ts index dd074a786..f7e527649 100644 --- a/js/plugins/ollama/src/types.ts +++ b/js/plugins/ollama/src/types.ts @@ -22,16 +22,29 @@ import { Message as OllamaMessage, } from 'ollama'; -// Define possible API types +/** + * Represents the type of API endpoint to use when communicating with Ollama. + * Can be either 'chat' for conversational models or 'generate' for completion models. + */ export type ApiType = 'chat' | 'generate'; -// Standard model definition - removed format +/** + * Configuration for defining an Ollama model. + * @interface ModelDefinition + * @property {string} name - The name of the Ollama model to use + * @property {ApiType} [type] - Optional API type to use. Defaults to 'chat' if not specified + */ export interface ModelDefinition { name: string; type?: ApiType; } -// Definition for embedding models +/** + * Configuration for defining an Ollama embedding model. + * @interface EmbeddingModelDefinition + * @property {string} name - The name of the Ollama embedding model + * @property {number} dimensions - The number of dimensions in the embedding output + */ export interface EmbeddingModelDefinition { name: string; dimensions: number; @@ -45,6 +58,14 @@ export type OllamaEmbeddingPrediction = z.infer< typeof OllamaEmbeddingPredictionSchema >; +/** + * Parameters for defining an Ollama embedder. 
+ * @interface DefineOllamaEmbeddingParams + * @property {string} name - The name to use for the embedder + * @property {string} modelName - The name of the Ollama model to use + * @property {number} dimensions - The number of dimensions in the embedding output + * @property {OllamaPluginParams} options - Configuration options for the embedder + */ export interface DefineOllamaEmbeddingParams { name: string; modelName: string; @@ -52,7 +73,14 @@ export interface DefineOllamaEmbeddingParams { options: OllamaPluginParams; } -// Parameters for the Ollama plugin configuration +/** + * Configuration options for the Ollama plugin. + * @interface OllamaPluginParams + * @property {ModelDefinition[]} [models] - Array of model definitions to register + * @property {EmbeddingModelDefinition[]} [embedders] - Array of embedding model definitions to register + * @property {string} serverAddress - The base URL of the Ollama server + * @property {RequestHeaders} [requestHeaders] - Optional headers to include with requests + */ export interface OllamaPluginParams { /** * Array of models to be defined. @@ -75,7 +103,17 @@ export interface OllamaPluginParams { requestHeaders?: RequestHeaders; } -// Function type for generating request headers dynamically +/** + * Function for dynamically generating request headers. 
+ * @callback RequestHeaderFunction
+ * @param {Object} params - Parameters for generating headers
+ * @param {string} params.serverAddress - The Ollama server address
+ * @param {ModelDefinition | EmbeddingModelDefinition} params.model - The model being used
+ * @param {GenerateRequest} [params.modelRequest] - The generation request (if applicable)
+ * @param {EmbedRequest} [params.embedRequest] - The embedding request (if applicable)
+ * @param {GenerateRequest} [modelRequest] - @deprecated Legacy parameter for backwards compatibility
+ * @returns {Promise<Record<string, string> | void>} The headers to include with the request
+ */
 export interface RequestHeaderFunction {
   (
     params: {

From 225aa6eac667094da91473f4889b93f4013a67ac Mon Sep 17 00:00:00 2001
From: Jacob Cable
Date: Thu, 12 Dec 2024 12:49:35 +0000
Subject: [PATCH 3/5] chore(js/plugins/ollama): remove unused dependency

---
 js/plugins/ollama/package.json | 3 ---
 js/pnpm-lock.yaml              | 3 ---
 2 files changed, 6 deletions(-)

diff --git a/js/plugins/ollama/package.json b/js/plugins/ollama/package.json
index 290f6f92a..485680bf0 100644
--- a/js/plugins/ollama/package.json
+++ b/js/plugins/ollama/package.json
@@ -48,8 +48,5 @@
       "types": "./lib/index.d.ts",
       "default": "./lib/index.js"
     }
-  },
-  "dependencies": {
-    "zod-to-json-schema": "^3.22.4"
   }
 }
diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml
index ccf3d7285..ac2b223b5 100644
--- a/js/pnpm-lock.yaml
+++ b/js/pnpm-lock.yaml
@@ -683,9 +683,6 @@ importers:
       genkit:
         specifier: workspace:*
         version: link:../../genkit
-      zod-to-json-schema:
-        specifier: ^3.22.4
-        version: 3.22.5(zod@3.23.8)
     devDependencies:
       '@types/node':
         specifier: ^20.11.16

From 86449b4cca1a87e64fa9e5b6e9a7cfd7f75f2aac Mon Sep 17 00:00:00 2001
From: Jacob Cable
Date: Thu, 12 Dec 2024 13:02:54 +0000
Subject: [PATCH 4/5] refactor(js/plugins/ollama): clean up parseMessage

---
 js/plugins/ollama/src/index.ts | 102 ++++++++++++---------------------
 1 file changed, 36 insertions(+), 66 deletions(-)

diff --git
a/js/plugins/ollama/src/index.ts b/js/plugins/ollama/src/index.ts index d979bf69b..0825cbfce 100644 --- a/js/plugins/ollama/src/index.ts +++ b/js/plugins/ollama/src/index.ts @@ -165,7 +165,7 @@ function ollamaModel( const json = extractJson(chunkText) as | ChatResponse | GenerateResponse; - const message = parseMessage(json, type, input); + const message = parseMessage(json, type); streamingCallback({ index: 0, content: message.content, @@ -183,7 +183,7 @@ function ollamaModel( } else { const txtBody = await res.text(); const json = extractJson(txtBody) as ChatResponse | GenerateResponse; - message = parseMessage(json, type, input); + message = parseMessage(json, type); } return { @@ -195,82 +195,52 @@ function ollamaModel( ); } +/** + * Parses the Ollama response into a standardized MessageData format. + * @param {ChatResponse | GenerateResponse} response - The raw response from Ollama + * @param {ApiType} type - The type of API used (chat or generate) + * @returns {MessageData} The parsed message data + * @throws {GenkitError} If the response format is invalid or parsing fails + */ function parseMessage( response: ChatResponse | GenerateResponse, - type: ApiType, - input: GenerateRequest + _type: ApiType ): MessageData { - function isErrorResponse(resp: any): resp is ErrorResponse { - return 'error' in resp && typeof resp.error === 'string'; - } + // Type guards + const isErrorResponse = (resp: any): resp is ErrorResponse => + 'error' in resp && typeof resp.error === 'string'; + const isChatResponse = (resp: any): resp is ChatResponse => 'message' in resp; + const isGenerateResponse = (resp: any): resp is GenerateResponse => + 'response' in resp; + // Handle error responses first if (isErrorResponse(response)) { throw new Error(response.error); } - function isChatResponse(resp: any): resp is ChatResponse { - return 'message' in resp; - } + // Get the text content based on response type + const content = isChatResponse(response) + ? 
response.message.content + : isGenerateResponse(response) + ? response.response + : null; - function isGenerateResponse(resp: any): resp is GenerateResponse { - return 'response' in resp; - } - - // Handle JSON format if requested - if (input.output?.format === 'json' && input.output.schema) { - let rawContent; - if (isChatResponse(response)) { - try { - // Parse the content string into an object - const parsedContent = extractJson(response.message.content); - // Validate against the schema - rawContent = parsedContent; - } catch (e) { - throw new GenkitError({ - message: 'Failed to parse structured response from Ollama model', - status: 'FAILED_PRECONDITION', - }); - } - } else if (isGenerateResponse(response)) { - try { - const parsedContent = extractJson(response.response); - rawContent = parsedContent; - } catch (e) { - throw new GenkitError({ - message: 'Failed to parse structured response from Ollama model', - status: 'FAILED_PRECONDITION', - }); - } - } else { - throw new Error('Invalid response format'); - } - - return { - role: - type === 'chat' && isChatResponse(response) - ? toGenkitRole(response.message.role) - : 'model', - content: [{ text: JSON.stringify(rawContent) }], - }; + if (content === null) { + throw new GenkitError({ + message: 'Invalid response format from Ollama model', + status: 'FAILED_PRECONDITION', + }); } - // Handle regular output - if (isChatResponse(response)) { - return { - role: toGenkitRole(response.message.role), - content: [{ text: response.message.content }], - }; - } else if (isGenerateResponse(response)) { - return { - role: 'model', - content: [{ text: response.response }], - }; - } + // Determine role for chat responses, default to 'model' + const role = isChatResponse(response) + ? 
toGenkitRole(response.message.role)
+    : 'model';
 
-  throw new GenkitError({
-    message: 'Invalid response format from Ollama model',
-    status: 'FAILED_PRECONDITION',
-  });
+  return {
+    role,
+    content: [{ text: content }],
+  };
 }
 
 function toOllamaRequest(

From 0050ae6fffa399c6ffe83bd8d7bf0c55b430135c Mon Sep 17 00:00:00 2001
From: Jacob Cable
Date: Fri, 13 Dec 2024 12:52:10 +0000
Subject: [PATCH 5/5] test(js/plugins/ollama): add structured response to testapp

---
 js/testapps/ollama/src/index.ts | 182 ++++++++++++++++++++------------
 1 file changed, 112 insertions(+), 70 deletions(-)

diff --git a/js/testapps/ollama/src/index.ts b/js/testapps/ollama/src/index.ts
index 3e625f370..2ed939a1a 100644
--- a/js/testapps/ollama/src/index.ts
+++ b/js/testapps/ollama/src/index.ts
@@ -15,7 +15,32 @@
  */
 
 import { genkit, z } from 'genkit';
+import { logger } from 'genkit/logging';
 import { ollama } from 'genkitx-ollama';
 
+// Define our schemas upfront for better type safety and documentation
+const PokemonSchema = z.object({
+  name: z.string(),
+  description: z.string(),
+  type: z.array(
+    z.enum(['Electric', 'Fire', 'Grass', 'Poison', 'Water', 'Normal', 'Fairy'])
+  ),
+  stats: z.object({
+    attack: z.number(),
+    defense: z.number(),
+    speed: z.number(),
+  }),
+});
+
+const QueryResponseSchema = z.object({
+  matchedPokemon: z.array(z.string()),
+  analysis: z.object({
+    answer: z.string(),
+    confidence: z.number().min(0).max(1),
+    relevantTypes: z.array(z.string()),
+  }),
+});
+
+type Pokemon = z.infer<typeof PokemonSchema>;
 
 const ai = genkit({
   plugins: [
@@ -25,7 +50,6 @@ const ai = genkit({
       models: [{ name: 'phi3.5:latest' }],
       requestHeaders: async (params) => {
         console.log('Using server address', params.serverAddress);
-        // Simulate a token-based authentication
         await new Promise((resolve) => setTimeout(resolve, 200));
         return { Authorization: 'Bearer my-token' };
       },
@@ -33,108 +57,126 @@ const ai = genkit({
   ],
 });
 
-interface PokemonInfo {
-  name: string;
-  description: string;
-  embedding: number[] | 
null; -} - -const pokemonList: PokemonInfo[] = [ +// Enhanced Pokemon database with more structured data +const pokemonDatabase: Pokemon[] = [ { name: 'Pikachu', description: 'An Electric-type Pokemon known for its strong electric attacks.', - embedding: null, + type: ['Electric'], + stats: { attack: 55, defense: 40, speed: 90 }, }, { name: 'Charmander', description: 'A Fire-type Pokemon that evolves into the powerful Charizard.', - embedding: null, + type: ['Fire'], + stats: { attack: 52, defense: 43, speed: 65 }, }, { name: 'Bulbasaur', description: 'A Grass/Poison-type Pokemon that grows into a powerful Venusaur.', - embedding: null, + type: ['Grass', 'Poison'], + stats: { attack: 49, defense: 49, speed: 45 }, }, { name: 'Squirtle', description: 'A Water-type Pokemon known for its water-based attacks and high defense.', - embedding: null, + type: ['Water'], + stats: { attack: 48, defense: 65, speed: 43 }, }, { name: 'Jigglypuff', description: 'A Normal/Fairy-type Pokemon with a hypnotic singing ability.', - embedding: null, + type: ['Normal', 'Fairy'], + stats: { attack: 45, defense: 20, speed: 20 }, }, ]; -// Step 1: Embed each Pokemon's description -async function embedPokemon() { - for (const pokemon of pokemonList) { - pokemon.embedding = await ai.embed({ - embedder: 'ollama/nomic-embed-text', - content: pokemon.description, - }); - } -} +export const pokemonFlow = ai.defineFlow( + { + name: 'Pokedex', + inputSchema: z.object({ + question: z.string(), + maxResults: z.number().default(3).optional(), + }), + outputSchema: QueryResponseSchema, + }, + async ({ question, maxResults = 3 }) => { + // Embed the question and all Pokemon descriptions in parallel + const [questionEmbedding, pokemonEmbeddings] = await Promise.all([ + ai.embed({ + embedder: 'ollama/nomic-embed-text', + content: question, + }), + Promise.all( + pokemonDatabase.map((pokemon) => + ai.embed({ + embedder: 'ollama/nomic-embed-text', + content: pokemon.description, + }) + ) + ), + ]); -// Step 
2: Find top 3 Pokemon closest to the input -function findNearestPokemon(inputEmbedding: number[], topN = 3) { - if (pokemonList.some((pokemon) => pokemon.embedding === null)) - throw new Error('Some Pokemon are not yet embedded'); - const distances = pokemonList.map((pokemon) => ({ - pokemon, - distance: cosineDistance(inputEmbedding, pokemon.embedding!), - })); - return distances - .sort((a, b) => a.distance - b.distance) - .slice(0, topN) - .map((entry) => entry.pokemon); -} + // Calculate similarity and sort Pokemon by relevance + const similarityScores = pokemonEmbeddings.map((embedding, index) => ({ + pokemon: pokemonDatabase[index], + similarity: 1 - cosineSimilarity(questionEmbedding, embedding), + })); -// Helper function: cosine distance calculation -function cosineDistance(a: number[], b: number[]) { - const dotProduct = a.reduce((sum, ai, i) => sum + ai * b[i], 0); - const magnitudeA = Math.sqrt(a.reduce((sum, ai) => sum + ai * ai, 0)); - const magnitudeB = Math.sqrt(b.reduce((sum, bi) => sum + bi * bi, 0)); - if (magnitudeA === 0 || magnitudeB === 0) - throw new Error('Invalid input: zero vector'); - return 1 - dotProduct / (magnitudeA * magnitudeB); -} + const topPokemon = similarityScores + .sort((a, b) => b.similarity - a.similarity) + .slice(0, maxResults); -// Step 3: Generate response with RAG results in context -async function generateResponse(question: string) { - const inputEmbedding = await ai.embed({ - embedder: 'ollama/nomic-embed-text', - content: question, - }); + logger.info('Top Pokemon:', topPokemon); - const nearestPokemon = findNearestPokemon(inputEmbedding); - const pokemonContext = nearestPokemon - .map((pokemon) => `${pokemon.name}: ${pokemon.description}`) - .join('\n'); + // Build context for the LLM + const context = topPokemon.map(({ pokemon }) => ({ + name: pokemon.name, + description: pokemon.description, + type: pokemon.type, + stats: pokemon.stats, + })); - return await ai.generate({ - model: 'ollama/phi3.5:latest', - 
prompt: `Given the following context on Pokemon:\n${pokemonContext}\n\nQuestion: ${question}\n\nAnswer:`, - }); -} + // Generate structured response using the LLM + const response = await ai.generate({ + model: 'ollama/phi3.5:latest', + prompt: `Given these Pokemon details: +${JSON.stringify(context, null, 2)} -export const pokemonFlow = ai.defineFlow( - { - name: 'Pokedex', - inputSchema: z.string(), - outputSchema: z.string(), - }, - async (input) => { - await embedPokemon(); - const response = await generateResponse(input); +Answer this question: ${question} - const answer = response.text; +Format your response as JSON with: +- matchedPokemon: names of relevant Pokemon +- analysis.answer: detailed explanation +- analysis.confidence: score from 0-1 indicating certainty +- analysis.relevantTypes: Pokemon types mentioned in answer`, + output: { + format: 'json', + schema: QueryResponseSchema, + }, + }); - return answer; + return ( + response.output || { + matchedPokemon: [], + analysis: { + answer: '', + confidence: 0, + relevantTypes: [], + }, + } + ); } ); + +// Helper function for vector similarity +function cosineSimilarity(a: number[], b: number[]): number { + const dotProduct = a.reduce((sum, ai, i) => sum + ai * b[i], 0); + const magnitudeA = Math.sqrt(a.reduce((sum, ai) => sum + ai * ai, 0)); + const magnitudeB = Math.sqrt(b.reduce((sum, bi) => sum + bi * bi, 0)); + return dotProduct / (magnitudeA * magnitudeB); +}