feat(js/plugins/ollama): add support for structured responses #1501

Open · wants to merge 6 commits into main
56 changes: 56 additions & 0 deletions docs/plugins/ollama.md
@@ -151,3 +151,59 @@ async function getEmbedding() {

getEmbedding().then((e) => console.log(e))
```

## Structured Output

The Ollama plugin supports structured output through JSON schema validation. This allows you to specify the exact structure of the response you expect from the model.

### Basic Usage

Here's a simple example using Zod to define a schema and get structured output:
```ts
import { genkit, z } from 'genkit';
import { ollama } from 'genkitx-ollama';

// Define your schema using Zod
const CountrySchema = z.object({
name: z.string(),
capital: z.string(),
languages: z.array(z.string()),
});

// Initialize Genkit with Ollama plugin
const ai = genkit({
plugins: [
ollama({
models: [{ name: 'your-model', type: 'chat' }],
serverAddress: 'http://localhost:11434',
}),
],
});

// ...
// ...

// Use structured output in your request
const llmResponse = await ai.generate({
model: 'ollama/your-model',
messages: [
{
role: 'system',
content: [
{
text: 'You are a helpful assistant that provides information about countries in a structured format.',
},
],
},
{
role: 'user',
content: [{ text: 'Tell me about Canada.' }],
},
],
output: {
format: 'json',
schema: CountrySchema,
},
});

// Now the output should be typed correctly:
const { name, capital, languages } = llmResponse.output;

```
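
Note that `output` is typed as possibly `null`, so with strict null checks a guard before destructuring keeps the compiler happy. A minimal variant of the last step, using the same `llmResponse` as above:

```ts
// llmResponse.output is nullable in the typings; guard before destructuring.
if (!llmResponse.output) {
  throw new Error('Model did not return valid structured output');
}

const { name, capital, languages } = llmResponse.output;
console.log(`${name}: capital ${capital}, languages ${languages.join(', ')}`);
```
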
2 changes: 1 addition & 1 deletion js/plugins/ollama/package.json
@@ -34,7 +34,7 @@
"devDependencies": {
"@types/node": "^20.11.16",
"npm-run-all": "^4.1.5",
"ollama": "^0.5.9",
"ollama": "^0.5.11",
"rimraf": "^6.0.1",
"tsup": "^8.3.5",
"tsx": "^4.19.2",
16 changes: 16 additions & 0 deletions js/plugins/ollama/src/embeddings.ts
@@ -17,6 +17,16 @@ import { Document, Genkit } from 'genkit';
import { EmbedRequest, EmbedResponse } from 'ollama';
import { DefineOllamaEmbeddingParams, RequestHeaders } from './types.js';

/**
* Constructs an Ollama embedding request from the provided parameters.
* @param {string} modelName - The name of the Ollama model to use
* @param {number} dimensions - The number of dimensions for the embeddings
* @param {Document[]} documents - The documents to embed
* @param {string} serverAddress - The Ollama server address
* @param {RequestHeaders} [requestHeaders] - Optional headers to include with the request
* @returns {Promise<{url: string, requestPayload: EmbedRequest, headers: Record<string, string>}>} The prepared request
* @private
*/
async function toOllamaEmbedRequest(
modelName: string,
dimensions: number,
@@ -59,6 +69,12 @@ async function toOllamaEmbedRequest(
};
}

/**
* Defines and registers an Ollama embedder in the Genkit environment.
* @param {Genkit} ai - The Genkit instance
* @param {DefineOllamaEmbeddingParams} params - Configuration for the embedder
* @returns {Embedder} The defined Genkit embedder
*/
export function defineOllamaEmbedder(
ai: Genkit,
{ name, modelName, dimensions, options }: DefineOllamaEmbeddingParams
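
An embedder registered through `defineOllamaEmbedder` is consumed with `ai.embed`. A minimal sketch, assuming the plugin is configured with an `embedders: [{ name: 'nomic-embed-text', dimensions: 768 }]` entry and that the model has been pulled locally:

```ts
import { genkit } from 'genkit';
import { ollama } from 'genkitx-ollama';

const ai = genkit({
  plugins: [
    ollama({
      serverAddress: 'http://localhost:11434',
      embedders: [{ name: 'nomic-embed-text', dimensions: 768 }],
    }),
  ],
});

// Embed a single string with the Ollama-backed embedder.
async function embedExample() {
  const embedding = await ai.embed({
    embedder: 'ollama/nomic-embed-text',
    content: 'Hello from Genkit',
  });
  console.log(embedding);
}

embedExample();
```
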
126 changes: 90 additions & 36 deletions js/plugins/ollama/src/index.ts
@@ -14,26 +14,36 @@
* limitations under the License.
*/

import { Genkit } from 'genkit';
import { Genkit, GenkitError, Role } from 'genkit';
import { extractJson } from 'genkit/extract';
import { logger } from 'genkit/logging';
import {
GenerateRequest,
GenerateResponseData,
GenerationCommonConfigSchema,
getBasicUsageStats,
MessageData,
getBasicUsageStats,
} from 'genkit/model';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
import { ErrorResponse } from 'ollama';
import { defineOllamaEmbedder } from './embeddings.js';
import {
ApiType,
ChatResponse,
GenerateResponse,
Message,
ModelDefinition,
RequestHeaders,
type OllamaPluginParams,
} from './types.js';

export { type OllamaPluginParams };

/**
* Creates and registers a Genkit plugin for Ollama integration.
* @param {OllamaPluginParams} params - Configuration options for the Ollama plugin
* @returns {GenkitPlugin} A configured Genkit plugin for Ollama
*/
export function ollama(params: OllamaPluginParams): GenkitPlugin {
return genkitPlugin('ollama', async (ai: Genkit) => {
const serverAddress = params.serverAddress;
@@ -51,6 +61,15 @@ export function ollama(params: OllamaPluginParams): GenkitPlugin {
});
}

/**
* Defines a new Ollama model in the Genkit registry.
* @param {Genkit} ai - The Genkit instance
* @param {ModelDefinition} model - The model configuration
* @param {string} serverAddress - The Ollama server address
* @param {RequestHeaders} [requestHeaders] - Optional headers to include with requests
* @returns {Model} The defined Genkit model
* @private
*/
function ollamaModel(
ai: Genkit,
model: ModelDefinition,
@@ -143,7 +162,9 @@ function ollamaModel(
let textResponse = '';
for await (const chunk of readChunks(reader)) {
const chunkText = textDecoder.decode(chunk);
const json = JSON.parse(chunkText);
const json = extractJson(chunkText) as
| ChatResponse
| GenerateResponse;
const message = parseMessage(json, type);
streamingCallback({
index: 0,
@@ -161,9 +182,7 @@ function ollamaModel(
};
} else {
const txtBody = await res.text();
const json = JSON.parse(txtBody);
logger.debug(txtBody, 'ollama raw response');

const json = extractJson(txtBody) as ChatResponse | GenerateResponse;
message = parseMessage(json, type);
}

@@ -176,29 +195,52 @@ function ollamaModel(
);
}

function parseMessage(response: any, type: ApiType): MessageData {
if (response.error) {
/**
* Parses the Ollama response into a standardized MessageData format.
* @param {ChatResponse | GenerateResponse} response - The raw response from Ollama
* @param {ApiType} type - The type of API used (chat or generate)
* @returns {MessageData} The parsed message data
* @throws {GenkitError} If the response format is invalid or parsing fails
*/
function parseMessage(
response: ChatResponse | GenerateResponse,
_type: ApiType
): MessageData {
// Type guards
const isErrorResponse = (resp: any): resp is ErrorResponse =>
'error' in resp && typeof resp.error === 'string';
const isChatResponse = (resp: any): resp is ChatResponse => 'message' in resp;
const isGenerateResponse = (resp: any): resp is GenerateResponse =>
'response' in resp;

// Handle error responses first
if (isErrorResponse(response)) {
throw new Error(response.error);
}
if (type === 'chat') {
return {
role: toGenkitRole(response.message.role),
content: [
{
text: response.message.content,
},
],
};
} else {
return {
role: 'model',
content: [
{
text: response.response,
},
],
};

// Get the text content based on response type
const content = isChatResponse(response)
? response.message.content
: isGenerateResponse(response)
? response.response
: null;

if (content === null) {
throw new GenkitError({
message: 'Invalid response format from Ollama model',
status: 'FAILED_PRECONDITION',
});
}

// Determine role for chat responses, default to 'model'
const role = isChatResponse(response)
? toGenkitRole(response.message.role)
: 'model';

return {
role,
content: [{ text: content }],
};
}
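
For reference, these are the two Ollama response shapes the function above distinguishes (illustrative values only; field names follow the `ollama` client types):

```ts
// Chat API shape: text lives under message.content, role maps via toGenkitRole.
const chatShape = {
  message: { role: 'assistant', content: '{"name":"Canada","capital":"Ottawa"}' },
};

// Generate API shape: text lives under response, role defaults to 'model'.
const generateShape = {
  response: '{"name":"Canada","capital":"Ottawa"}',
};
```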

function toOllamaRequest(
@@ -213,6 +255,12 @@
options,
stream,
};

// Add format and schema if specified in output
if (input.output?.format === 'json' && input.output.schema) {
request.format = input.output.schema;
}

if (type === 'chat') {
const messages: Message[] = [];
input.messages.forEach((m) => {
@@ -240,18 +288,30 @@
return request;
}
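
To make the new `format` handling concrete: with the `CountrySchema` from the docs example and a non-streaming chat call, the body sent to the Ollama server would carry the JSON schema roughly as sketched below (fields abbreviated; the exact schema emitted depends on Genkit's Zod-to-JSON-schema conversion):

```ts
const exampleChatRequest = {
  model: 'your-model',
  messages: [
    { role: 'system', content: 'You are a helpful assistant...' },
    { role: 'user', content: 'Tell me about Canada.' },
  ],
  // Ollama accepts either the string 'json' or a JSON schema object here.
  format: {
    type: 'object',
    properties: {
      name: { type: 'string' },
      capital: { type: 'string' },
      languages: { type: 'array', items: { type: 'string' } },
    },
    required: ['name', 'capital', 'languages'],
  },
  stream: false,
};
```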

function toOllamaRole(role) {
/**
* Converts a Genkit role to the corresponding Ollama role.
* @param {string} role - The Genkit role to convert
* @returns {string} The corresponding Ollama role
* @private
*/
function toOllamaRole(role: string) {
if (role === 'model') {
return 'assistant';
}
return role; // everything else seems to match
}

function toGenkitRole(role) {
/**
* Converts an Ollama role to the corresponding Genkit role.
* @param {string} role - The Ollama role to convert
* @returns {Role} The corresponding Genkit role
* @private
*/
function toGenkitRole(role: string): Role {
if (role === 'assistant') {
return 'model';
return 'model' as Role;
}
return role; // everything else seems to match
return role as Role; // everything else seems to match
}

function readChunks(reader) {
@@ -279,9 +339,3 @@ function getSystemMessage(input: GenerateRequest): string {
.map((m) => m.content.map((c) => c.text).join())
.join();
}

interface Message {
role: string;
content: string;
images?: string[];
}