Skip to content

Commit

Permalink
fix(js): make sure middleware is applied by prompts (#1534)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Dec 19, 2024
1 parent 4c89f5f commit a9e63ed
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 12 deletions.
9 changes: 8 additions & 1 deletion js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
GenerateRequestSchema,
GenerateResponseChunkSchema,
ModelArgument,
ModelMiddleware,
} from './model.js';
import { ToolAction } from './tool.js';

Expand All @@ -51,6 +52,7 @@ export type PromptAction<I extends z.ZodTypeAny = z.ZodTypeAny> = Action<
type: 'prompt';
};
};
__config: PromptConfig;
};

/**
Expand All @@ -62,6 +64,7 @@ export interface PromptConfig<I extends z.ZodTypeAny = z.ZodTypeAny> {
inputSchema?: I;
inputJsonSchema?: JSONSchema7;
metadata?: Record<string, any>;
use?: ModelMiddleware[];
}

/**
Expand Down Expand Up @@ -147,12 +150,16 @@ export function definePrompt<I extends z.ZodTypeAny>(
const a = defineAction(
registry,
{
...config,
name: config.name,
inputJsonSchema: config.inputJsonSchema,
inputSchema: config.inputSchema,
description: config.description,
actionType: 'prompt',
metadata: { ...(config.metadata || { prompt: {} }), type: 'prompt' },
},
fn
);
(a as PromptAction<I>).__config = config;
return a as PromptAction<I>;
}

Expand Down
18 changes: 14 additions & 4 deletions js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ import {
loadPromptFolder,
prompt,
toFrontmatter,
type DotpromptAction,
} from '@genkit-ai/dotprompt';
import { v4 as uuidv4 } from 'uuid';
import { BaseEvalDataPointSchema } from './evaluator.js';
Expand Down Expand Up @@ -463,13 +464,13 @@ export class Genkit implements HasRegistry {
);
return this.wrapPromptActionInExecutablePrompt(
dotprompt.promptAction! as PromptAction<I>,
options,
dotprompt
options
);
} else {
const p = definePrompt(
this.registry,
{
...options,
name: options.name!,
description: options.description,
inputJsonSchema: options.input?.jsonSchema,
Expand Down Expand Up @@ -513,8 +514,7 @@ export class Genkit implements HasRegistry {
promptAction: PromptAction<I> | Promise<PromptAction<I>>,
options:
| Partial<PromptMetadata<I, CustomOptions>>
| Promise<Partial<PromptMetadata<I, CustomOptions>>>,
dotprompt?: Dotprompt<z.infer<I>>
| Promise<Partial<PromptMetadata<I, CustomOptions>>>
): ExecutablePrompt<I, O, CustomOptions> {
const executablePrompt = async (
input?: z.infer<I>,
Expand Down Expand Up @@ -558,6 +558,9 @@ export class Genkit implements HasRegistry {
const p = await promptAction;
// If it's a dotprompt template, we invoke dotprompt template directly
// because it can take in more PromptGenerateOptions (not just inputs).
const dotprompt: Dotprompt<z.infer<I>> | undefined = (
p as DotpromptAction<z.infer<I>>
).__dotprompt;
const promptResult = await (dotprompt
? dotprompt.render(opt)
: p(opt.input));
Expand All @@ -584,6 +587,13 @@ export class Genkit implements HasRegistry {
},
model,
} as GenerateOptions<O, CustomOptions>;
if ((promptResult as GenerateOptions).use) {
resultOptions.use = (promptResult as GenerateOptions).use;
} else if (p.__config?.use) {
resultOptions.use = p.__config?.use;
} else if (opt.use) {
resultOptions.use = opt.use;
}
delete (resultOptions as any).input;
if ((promptResult as GenerateOptions).prompt) {
resultOptions.prompt = (promptResult as GenerateOptions).prompt;
Expand Down
244 changes: 243 additions & 1 deletion js/genkit/tests/prompts_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { modelRef } from '@genkit-ai/ai/model';
import { ModelMiddleware, modelRef } from '@genkit-ai/ai/model';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { Genkit, genkit } from '../src/genkit';
Expand All @@ -26,6 +26,168 @@ import {
defineStaticResponseModel,
} from './helpers';

const wrapRequest: ModelMiddleware = async (req, next) => {
return next({
...req,
messages: [
{
role: 'user',
content: [
{
text:
'(' +
req.messages
.map((m) => m.content.map((c) => c.text).join())
.join() +
')',
},
],
},
],
});
};
const wrapResponse: ModelMiddleware = async (req, next) => {
const res = await next(req);
return {
message: {
role: 'model',
content: [
{
text: '[' + res.message!.content.map((c) => c.text).join() + ']',
},
],
},
finishReason: res.finishReason,
};
};

describe('definePrompt - functional', () => {
let ai: Genkit;

beforeEach(() => {
ai = genkit({
model: 'echoModel',
});
defineEchoModel(ai);
});

it('should apply middleware to a prompt call', async () => {
const prompt = ai.definePrompt(
{
name: 'hi',
input: {
schema: z.object({
name: z.string(),
}),
},
},
async (input) => {
return {
messages: [
{
role: 'user',
content: [{ text: `hi ${input.name}` }],
},
],
};
}
);

const response = await prompt(
{ name: 'Genkit' },
{ use: [wrapRequest, wrapResponse] }
);
assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]');
});

it.only('should apply middleware configured on a prompt', async () => {
const prompt = ai.definePrompt(
{
name: 'hi',
input: {
schema: z.object({
name: z.string(),
}),
},
use: [wrapRequest, wrapResponse],
},
async (input) => {
return {
messages: [
{
role: 'user',
content: [{ text: `hi ${input.name}` }],
},
],
};
}
);

const response = await prompt({ name: 'Genkit' });
assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]');
});

it.only('should apply middleware to a prompt call on a looked up prompt', async () => {
ai.definePrompt(
{
name: 'hi',
input: {
schema: z.object({
name: z.string(),
}),
},
use: [wrapRequest, wrapResponse],
},
async (input) => {
return {
messages: [
{
role: 'user',
content: [{ text: `hi ${input.name}` }],
},
],
};
}
);

const prompt = ai.prompt('hi');

const response = await prompt({ name: 'Genkit' });
assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]');
});

it('should apply middleware configured on a prompt on a looked up prompt', async () => {
ai.definePrompt(
{
name: 'hi',
input: {
schema: z.object({
name: z.string(),
}),
},
},
async (input) => {
return {
messages: [
{
role: 'user',
content: [{ text: `hi ${input.name}` }],
},
],
};
}
);

const prompt = ai.prompt('hi');

const response = await prompt(
{ name: 'Genkit' },
{ use: [wrapRequest, wrapResponse] }
);
assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]');
});
});

describe('definePrompt - dotprompt', () => {
describe('default model', () => {
let ai: Genkit;
Expand Down Expand Up @@ -95,6 +257,86 @@ describe('definePrompt - dotprompt', () => {
const response = await hi({ name: 'Genkit' });
assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}');
});

it('should apply middleware to a prompt call', async () => {
const prompt = ai.definePrompt(
{
name: 'hi',
input: {
schema: z.object({
name: z.string(),
}),
},
},
'hi {{ name }}'
);

const response = await prompt(
{ name: 'Genkit' },
{ use: [wrapRequest, wrapResponse] }
);
assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]');
});

it('should apply middleware configured on a prompt', async () => {
const prompt = ai.definePrompt(
{
name: 'hi',
input: {
schema: z.object({
name: z.string(),
}),
},
use: [wrapRequest, wrapResponse],
},
'hi {{ name }}'
);

const response = await prompt({ name: 'Genkit' });
assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]');
});

it.only('should apply middleware to a prompt call on a looked up prompt', async () => {
ai.definePrompt(
{
name: 'hi',
input: {
schema: z.object({
name: z.string(),
}),
},
use: [wrapRequest, wrapResponse],
},
'hi {{ name }}'
);

const prompt = ai.prompt('hi');

const response = await prompt({ name: 'Genkit' });
assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]');
});

it.only('should apply middleware configured on a prompt on a looked up prompt', async () => {
ai.definePrompt(
{
name: 'hi',
input: {
schema: z.object({
name: z.string(),
}),
},
},
'hi {{ name }}'
);

const prompt = ai.prompt('hi');

const response = await prompt(
{ name: 'Genkit' },
{ use: [wrapRequest, wrapResponse] }
);
assert.strictEqual(response.text, '[Echo: (hi Genkit),; config: {}]');
});
});

describe('default model ref', () => {
Expand Down
6 changes: 4 additions & 2 deletions js/plugins/dotprompt/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,22 @@ import { readFileSync } from 'fs';
import { basename } from 'path';
import { toFrontmatter } from './metadata.js';
import {
defineDotprompt,
Dotprompt,
DotpromptRef,
defineDotprompt,
type DotpromptAction,
type PromptGenerateOptions,
} from './prompt.js';
import { loadPromptFolder, lookupPrompt } from './registry.js';

export { type PromptMetadata } from './metadata.js';
export { defineHelper, definePartial } from './template.js';
export {
defineDotprompt,
Dotprompt,
defineDotprompt,
loadPromptFolder,
toFrontmatter,
type DotpromptAction,
type PromptGenerateOptions,
};

Expand Down
Loading

0 comments on commit a9e63ed

Please sign in to comment.