From a9e63ed0a43524e1db4ad4f2069582d29553ff1e Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 19 Dec 2024 13:57:32 -0500 Subject: [PATCH] fix(js): make sure middleware is applied by prompts (#1534) --- js/ai/src/prompt.ts | 9 +- js/genkit/src/genkit.ts | 18 +- js/genkit/tests/prompts_test.ts | 244 +++++++++++++++++++++- js/plugins/dotprompt/src/index.ts | 6 +- js/plugins/dotprompt/src/metadata.ts | 4 + js/plugins/dotprompt/src/prompt.ts | 15 +- js/plugins/dotprompt/tests/prompt_test.ts | 8 +- 7 files changed, 292 insertions(+), 12 deletions(-) diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index 4570fa57c..2e47d534d 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -27,6 +27,7 @@ import { GenerateRequestSchema, GenerateResponseChunkSchema, ModelArgument, + ModelMiddleware, } from './model.js'; import { ToolAction } from './tool.js'; @@ -51,6 +52,7 @@ export type PromptAction = Action< type: 'prompt'; }; }; + __config: PromptConfig; }; /** @@ -62,6 +64,7 @@ export interface PromptConfig { inputSchema?: I; inputJsonSchema?: JSONSchema7; metadata?: Record; + use?: ModelMiddleware[]; } /** @@ -147,12 +150,16 @@ export function definePrompt( const a = defineAction( registry, { - ...config, + name: config.name, + inputJsonSchema: config.inputJsonSchema, + inputSchema: config.inputSchema, + description: config.description, actionType: 'prompt', metadata: { ...(config.metadata || { prompt: {} }), type: 'prompt' }, }, fn ); + (a as PromptAction).__config = config; return a as PromptAction; } diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 1601e7539..d212afdc9 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -131,6 +131,7 @@ import { loadPromptFolder, prompt, toFrontmatter, + type DotpromptAction, } from '@genkit-ai/dotprompt'; import { v4 as uuidv4 } from 'uuid'; import { BaseEvalDataPointSchema } from './evaluator.js'; @@ -463,13 +464,13 @@ export class Genkit implements HasRegistry { ); return this.wrapPromptActionInExecutablePrompt( dotprompt.promptAction! as PromptAction, - options, - dotprompt + options ); } else { const p = definePrompt( this.registry, { + ...options, name: options.name!, description: options.description, inputJsonSchema: options.input?.jsonSchema, @@ -513,8 +514,7 @@ export class Genkit implements HasRegistry { promptAction: PromptAction | Promise>, options: | Partial> - | Promise>>, - dotprompt?: Dotprompt> + | Promise>> ): ExecutablePrompt { const executablePrompt = async ( input?: z.infer, @@ -558,6 +558,9 @@ export class Genkit implements HasRegistry { const p = await promptAction; // If it's a dotprompt template, we invoke dotprompt template directly // because it can take in more PromptGenerateOptions (not just inputs). + const dotprompt: Dotprompt> | undefined = ( + p as DotpromptAction> + ).__dotprompt; const promptResult = await (dotprompt ? dotprompt.render(opt) : p(opt.input)); @@ -584,6 +587,13 @@ export class Genkit implements HasRegistry { }, model, } as GenerateOptions; + if ((promptResult as GenerateOptions).use) { + resultOptions.use = (promptResult as GenerateOptions).use; + } else if (p.__config?.use) { + resultOptions.use = p.__config?.use; + } else if (opt.use) { + resultOptions.use = opt.use; + } delete (resultOptions as any).input; if ((promptResult as GenerateOptions).prompt) { resultOptions.prompt = (promptResult as GenerateOptions).prompt; diff --git a/js/genkit/tests/prompts_test.ts b/js/genkit/tests/prompts_test.ts index 8a2b1e2c5..bc6a2a240 100644 --- a/js/genkit/tests/prompts_test.ts +++ b/js/genkit/tests/prompts_test.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { modelRef } from '@genkit-ai/ai/model'; +import { ModelMiddleware, modelRef } from '@genkit-ai/ai/model'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { Genkit, genkit } from '../src/genkit'; @@ -26,6 +26,168 @@ import { defineStaticResponseModel, } from './helpers'; +const wrapRequest: ModelMiddleware = async (req, next) => { + return next({ + ...req, + messages: [ + { + role: 'user', + content: [ + { + text: + '(' + + req.messages + .map((m) => m.content.map((c) => c.text).join()) + .join() + + ')', + }, + ], + }, + ], + }); +}; +const wrapResponse: ModelMiddleware = async (req, next) => { + const res = await next(req); + return { + message: { + role: 'model', + content: [ + { + text: '[' + res.message!.content.map((c) => c.text).join() + ']', + }, + ], + }, + finishReason: res.finishReason, + }; +}; + +describe('definePrompt - functional', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({ + model: 'echoModel', + }); + defineEchoModel(ai); + }); + + it('should apply middleware to a prompt call', async () => { + const prompt = ai.definePrompt( + { + name: 'hi', + input: { + schema: z.object({ + name: z.string(), + }), + }, + }, + async (input) => { + return { + messages: [ + { + role: 'user', + content: [{ text: `hi ${input.name}` }], + }, + ], + }; + } + ); + + const response = await prompt( + { name: 'Genkit' }, + { use: [wrapRequest, wrapResponse] } + ); + assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]'); + }); + + it.only('should apply middleware configured on a prompt', async () => { + const prompt = ai.definePrompt( + { + name: 'hi', + input: { + schema: z.object({ + name: z.string(), + }), + }, + use: [wrapRequest, wrapResponse], + }, + async (input) => { + return { + messages: [ + { + role: 'user', + content: [{ text: `hi ${input.name}` }], + }, + ], + }; + } + ); + + const response = await prompt({ name: 'Genkit' }); + assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]'); + }); + + it.only('should apply middleware to a prompt call on a looked up prompt', async () => { + ai.definePrompt( + { + name: 'hi', + input: { + schema: z.object({ + name: z.string(), + }), + }, + use: [wrapRequest, wrapResponse], + }, + async (input) => { + return { + messages: [ + { + role: 'user', + content: [{ text: `hi ${input.name}` }], + }, + ], + }; + } + ); + + const prompt = ai.prompt('hi'); + + const response = await prompt({ name: 'Genkit' }); + assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]'); + }); + + it('should apply middleware configured on a prompt on a looked up prompt', async () => { + ai.definePrompt( + { + name: 'hi', + input: { + schema: z.object({ + name: z.string(), + }), + }, + }, + async (input) => { + return { + messages: [ + { + role: 'user', + content: [{ text: `hi ${input.name}` }], + }, + ], + }; + } + ); + + const prompt = ai.prompt('hi'); + + const response = await prompt( + { name: 'Genkit' }, + { use: [wrapRequest, wrapResponse] } + ); + assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]'); + }); +}); + describe('definePrompt - dotprompt', () => { describe('default model', () => { let ai: Genkit; @@ -95,6 +257,86 @@ describe('definePrompt - dotprompt', () => { const response = await hi({ name: 'Genkit' }); assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); + + it('should apply middleware to a prompt call', async () => { + const prompt = ai.definePrompt( + { + name: 'hi', + input: { + schema: z.object({ + name: z.string(), + }), + }, + }, + 'hi {{ name }}' + ); + + const response = await prompt( + { name: 'Genkit' }, + { use: [wrapRequest, wrapResponse] } + ); + assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]'); + }); + + it('should apply middleware configured on a prompt', async () => { + const prompt = ai.definePrompt( + { + name: 'hi', + input: { + schema: z.object({ + name: z.string(), + }), + }, + use: [wrapRequest, wrapResponse], + }, + 'hi {{ name }}' + ); + + const response = await prompt({ name: 'Genkit' }); + assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]'); + }); + + it.only('should apply middleware to a prompt call on a looked up prompt', async () => { + ai.definePrompt( + { + name: 'hi', + input: { + schema: z.object({ + name: z.string(), + }), + }, + use: [wrapRequest, wrapResponse], + }, + 'hi {{ name }}' + ); + + const prompt = ai.prompt('hi'); + + const response = await prompt({ name: 'Genkit' }); + assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]'); + }); + + it.only('should apply middleware configured on a prompt on a looked up prompt', async () => { + ai.definePrompt( + { + name: 'hi', + input: { + schema: z.object({ + name: z.string(), + }), + }, + }, + 'hi {{ name }}' + ); + + const prompt = ai.prompt('hi'); + + const response = await prompt( + { name: 'Genkit' }, + { use: [wrapRequest, wrapResponse] } + ); + assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]'); + }); }); describe('default model ref', () => { diff --git a/js/plugins/dotprompt/src/index.ts b/js/plugins/dotprompt/src/index.ts index d68d20c64..ee07ec62a 100644 --- a/js/plugins/dotprompt/src/index.ts +++ b/js/plugins/dotprompt/src/index.ts @@ -19,9 +19,10 @@ import { readFileSync } from 'fs'; import { basename } from 'path'; import { toFrontmatter } from './metadata.js'; import { - defineDotprompt, Dotprompt, DotpromptRef, + defineDotprompt, + type DotpromptAction, type PromptGenerateOptions, } from './prompt.js'; import { loadPromptFolder, lookupPrompt } from './registry.js'; @@ -29,10 +30,11 @@ import { loadPromptFolder, lookupPrompt } from './registry.js'; export { type PromptMetadata } from './metadata.js'; export { defineHelper, definePartial } from './template.js'; export { - defineDotprompt, Dotprompt, + defineDotprompt, loadPromptFolder, toFrontmatter, + type DotpromptAction, type PromptGenerateOptions, }; diff --git a/js/plugins/dotprompt/src/metadata.ts b/js/plugins/dotprompt/src/metadata.ts index 3917fda9b..686d0b211 100644 --- a/js/plugins/dotprompt/src/metadata.ts +++ b/js/plugins/dotprompt/src/metadata.ts @@ -22,6 +22,7 @@ import { GenerationCommonConfigSchema, ModelArgument, + ModelMiddleware, } from '@genkit-ai/ai/model'; import { ToolArgument } from '@genkit-ai/ai/tool'; import { z } from '@genkit-ai/core'; @@ -79,6 +80,9 @@ export interface PromptMetadata< /** Arbitrary metadata to be used by code, tools, and libraries. */ metadata?: Record; + + /** Middleware to be used with this model call. */ + use?: ModelMiddleware[]; } /** diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index 52dbdfcff..83bbeec5c 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -48,6 +48,10 @@ import { compile } from './template.js'; export type PromptData = PromptFrontmatter & { template: string }; +export type DotpromptAction = PromptAction & { + __dotprompt: Dotprompt; +}; + export type PromptGenerateOptions< V = unknown, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, @@ -78,6 +82,7 @@ export class Dotprompt implements PromptMetadata { output?: PromptMetadata['output']; tools?: PromptMetadata['tools']; config?: PromptMetadata['config']; + use?: PromptMetadata['use']; private _promptAction?: PromptAction; @@ -143,6 +148,7 @@ export class Dotprompt implements PromptMetadata { this.output = options.output; this.tools = options.tools; this.config = options.config; + this.use = options.use; this.template = template; this.hash = createHash('sha256').update(JSON.stringify(this)).digest('hex'); @@ -216,6 +222,7 @@ export class Dotprompt implements PromptMetadata { async (input?: I) => toGenerateRequest(this.registry, this.render({ input })) ); + (this._promptAction as DotpromptAction).__dotprompt = this; } get promptAction(): PromptAction | undefined { @@ -239,7 +246,7 @@ export class Dotprompt implements PromptMetadata { renderedPrompt = undefined; renderedMessages = messages; } - return { + const res = { model: options.model || this.model!, config: { ...this.config, ...options.config }, messages: renderedMessages, @@ -254,8 +261,12 @@ export class Dotprompt implements PromptMetadata { onChunk: options.onChunk ?? options.streamingCallback, returnToolRequests: options.returnToolRequests, maxTurns: options.maxTurns, - use: options.use, } as GenerateOptions; + const middleware = (options.use || []).concat(this.use || []); + if (middleware.length > 0) { + res.use = middleware; + } + return res; } /** diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index 19577dc2c..52da58f69 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -128,7 +128,11 @@ describe('Prompt', () => { ); const streamingCallback = (c) => console.log(c); - const middleware = []; + const middleware = [ + async (req, next) => { + return next(); + }, + ]; const rendered = prompt.render({ input: { name: 'Michael' }, @@ -140,7 +144,7 @@ describe('Prompt', () => { assert.strictEqual(rendered.onChunk, streamingCallback); assert.strictEqual(rendered.returnToolRequests, true); assert.strictEqual(rendered.maxTurns, 17); - assert.strictEqual(rendered.use, middleware); + assert.deepStrictEqual(rendered.use, middleware); }); it('should support system prompt with history', () => {