Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Prompt Node Options Structure #967

Merged
merged 4 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions packages/core/src/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1087,20 +1087,17 @@ export function tracePromptResult(
export function appendUserMessage(
messages: ChatCompletionMessageParam[],
content: string,
options?: { ephemeral?: boolean }
options?: ContextExpansionOptions
) {
if (!content) return
const { ephemeral } = options || {}
const { cacheControl } = options || {}
let last = messages.at(-1) as ChatCompletionUserMessageParam
if (
last?.role !== "user" ||
!!ephemeral !== (last?.cacheControl === "ephemeral")
) {
if (last?.role !== "user" || options?.cacheControl !== last?.cacheControl) {
last = {
role: "user",
content: "",
} satisfies ChatCompletionUserMessageParam
if (ephemeral) last.cacheControl = "ephemeral"
if (cacheControl) last.cacheControl = cacheControl
messages.push(last)
}
if (last.content) {
Expand All @@ -1112,20 +1109,20 @@ export function appendUserMessage(
export function appendAssistantMessage(
messages: ChatCompletionMessageParam[],
content: string,
options?: { ephemeral?: boolean }
options?: ContextExpansionOptions
) {
if (!content) return
const { ephemeral } = options || {}
const { cacheControl } = options || {}
let last = messages.at(-1) as ChatCompletionAssistantMessageParam
if (
last?.role !== "assistant" ||
!!ephemeral !== (last?.cacheControl === "ephemeral")
options?.cacheControl !== last?.cacheControl
) {
last = {
role: "assistant",
content: "",
} satisfies ChatCompletionAssistantMessageParam
if (ephemeral) last.cacheControl = "ephemeral"
if (cacheControl) last.cacheControl = cacheControl
messages.push(last)
}
if (last.content) {
Expand All @@ -1137,21 +1134,21 @@ export function appendAssistantMessage(
export function appendSystemMessage(
messages: ChatCompletionMessageParam[],
content: string,
options?: { ephemeral?: boolean }
options?: ContextExpansionOptions
) {
if (!content) return
const { ephemeral } = options || {}
const { cacheControl } = options || {}

let last = messages[0] as ChatCompletionSystemMessageParam
if (
last?.role !== "system" ||
!!ephemeral !== (last?.cacheControl === "ephemeral")
options?.cacheControl !== last?.cacheControl
) {
last = {
role: "system",
content: "",
} as ChatCompletionSystemMessageParam
if (ephemeral) last.cacheControl = "ephemeral"
if (cacheControl) last.cacheControl = cacheControl
messages.unshift(last)
}
if (last.content) {
Expand Down
40 changes: 22 additions & 18 deletions packages/core/src/promptdom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@ export interface PromptNode extends ContextExpansionOptions {
children?: PromptNode[] // Child nodes for hierarchical structure
error?: unknown // Error information if present
tokens?: number // Token count for the node
/**
* Definte a prompt caching breakpoint.
* This prompt prefix (including this text) is cacheable for a short amount of time.
*/
ephemeral?: boolean

/**
* Rendered markdown preview of the node
Expand Down Expand Up @@ -237,6 +232,15 @@ export function createDef(
return { type: "def", name, value, ...(options || {}) }
}

function cloneContextFields(n: PromptNode): Partial<PromptNode> {
const r = {} as Partial<PromptNode>
r.maxTokens = n.maxTokens
r.priority = n.priority
r.flex = n.flex
r.cacheControl = n.cacheControl
return r
}

export function createDefDiff(
name: string,
left: string | WorkspaceFile,
Expand Down Expand Up @@ -327,7 +331,7 @@ function renderDefNode(def: PromptDefNode): string {
}

async function renderDefDataNode(n: PromptDefDataNode): Promise<string> {
const { name, headers, priority, ephemeral, query } = n
const { name, headers, priority, cacheControl, query } = n
let data = n.resolved
let format = n.format
if (
Expand Down Expand Up @@ -680,7 +684,7 @@ async function resolvePromptNode(
const rendered = renderDefNode(n)
n.preview = rendered
n.tokens = estimateTokens(rendered, encoder)
n.children = [createTextNode(rendered)]
n.children = [createTextNode(rendered, cloneContextFields(n))]
} catch (e) {
n.error = e
}
Expand All @@ -693,7 +697,7 @@ async function resolvePromptNode(
const rendered = await renderDefDataNode(n)
n.preview = rendered
n.tokens = estimateTokens(rendered, encoder)
n.children = [createTextNode(rendered)]
n.children = [createTextNode(rendered, cloneContextFields(n))]
} catch (e) {
n.error = e
}
Expand Down Expand Up @@ -929,7 +933,7 @@ async function truncatePromptNode(
n.tokens = estimateTokens(n.resolved.content, encoder)
const rendered = renderDefNode(n)
n.preview = rendered
n.children = [createTextNode(rendered)]
n.children = [createTextNode(rendered, cloneContextFields(n))]
truncated = true
trace.log(
`truncated def ${n.name} to ${n.tokens} tokens (max ${n.maxTokens})`
Expand Down Expand Up @@ -1062,14 +1066,14 @@ async function validateSafetyPromptNode(
def: async (n) => {
if (!n.detectPromptInjection || !n.resolved?.content) return

const detectPromptInjection = await resolveContentSafety()
const detectPromptInjectionFn = await resolveContentSafety()
if (
(!detectPromptInjection && n.detectPromptInjection === true) ||
(!detectPromptInjectionFn && n.detectPromptInjection === true) ||
n.detectPromptInjection === "always"
)
throw new Error("content safety service not available")
const { attackDetected } =
(await detectPromptInjection?.(n.resolved)) || {}
(await detectPromptInjectionFn?.(n.resolved)) || {}
if (attackDetected) {
mod = true
n.resolved = {
Expand All @@ -1087,14 +1091,14 @@ async function validateSafetyPromptNode(
defData: async (n) => {
if (!n.detectPromptInjection || !n.preview) return

const detectPromptInjection = await resolveContentSafety()
const detectPromptInjectionFn = await resolveContentSafety()
if (
(!detectPromptInjection && n.detectPromptInjection === true) ||
(!detectPromptInjectionFn && n.detectPromptInjection === true) ||
n.detectPromptInjection === "always"
)
throw new Error("content safety service not available")
const { attackDetected } =
(await detectPromptInjection?.(n.preview)) || {}
(await detectPromptInjectionFn?.(n.preview)) || {}
if (attackDetected) {
mod = true
n.children = []
Expand Down Expand Up @@ -1164,13 +1168,13 @@ export async function renderPromptNode(
if (safety) await tracePromptNode(trace, node, { label: "safety" })

const messages: ChatCompletionMessageParam[] = []
const appendSystem = (content: string, options: { ephemeral?: boolean }) =>
const appendSystem = (content: string, options: ContextExpansionOptions) =>
appendSystemMessage(messages, content, options)
const appendUser = (content: string, options: { ephemeral?: boolean }) =>
const appendUser = (content: string, options: ContextExpansionOptions) =>
appendUserMessage(messages, content, options)
const appendAssistant = (
content: string,
options: { ephemeral?: boolean }
options: ContextExpansionOptions
) => appendAssistantMessage(messages, content, options)

const images: PromptImage[] = []
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/runpromptcontext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ export function createChatTurnGenerationContext(
return res
},
cacheControl: (cc) => {
current.ephemeral = cc === "ephemeral"
current.cacheControl = cc
return res
},
} satisfies PromptTemplateString)
Expand Down
5 changes: 0 additions & 5 deletions packages/core/src/types/prompt_template.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -951,11 +951,6 @@ interface ContextExpansionOptions {
*/
flex?: number

/**
* @deprecated use cacheControl instead
*/
ephemeral?: boolean

/**
* Caching policy for this text. `ephemeral` means the prefix can be cached for a short amount of time.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
script({
title: "summarize all files with caching",
files: "src/rag/markdown.md",
model: "small",
tests: [
{
files: "src/rag/markdown.md",
Expand Down
Loading