From b26fd218241e2df5c5de4416fcb658b5de3581c5 Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Wed, 20 Dec 2023 16:16:38 +0100 Subject: [PATCH] x/evals: better prompting (#2975) --- x/spolu/research/evals/lib/algorithms/CoT.ts | 31 +++----------- x/spolu/research/evals/lib/datasets.ts | 8 ++-- x/spolu/research/evals/lib/datasets/MATH.ts | 40 ++++++++++++------- x/spolu/research/evals/lib/datasets/game24.ts | 38 ++++++++++-------- x/spolu/research/evals/lib/models/openai.ts | 2 + 5 files changed, 59 insertions(+), 60 deletions(-) diff --git a/x/spolu/research/evals/lib/algorithms/CoT.ts b/x/spolu/research/evals/lib/algorithms/CoT.ts index 2ce4de7ced69..6568f11b4078 100644 --- a/x/spolu/research/evals/lib/algorithms/CoT.ts +++ b/x/spolu/research/evals/lib/algorithms/CoT.ts @@ -39,10 +39,8 @@ export class CoT extends Algorithm { let prompt = `INSTRUCTIONS:\n`; prompt += ` ${this.dataset.instructions()}`; prompt += "\n\n"; - prompt += `Start by providing a REASONING consisting in multiple steps, using one line per step.`; + prompt += `Provide a reasoning consisting in multiple steps, using one line per step.`; prompt += ` ${this.dataset.reasoningStepInstructions()}`; - prompt += ` Finally provide a final ANSWER.`; - prompt += ` ${this.dataset.answerInstructions()}`; // prompt += // ` Do not perform multiple reasoning attempts per question,` + // ` do not backtrack in your reasoning steps.`; @@ -52,7 +50,6 @@ export class CoT extends Algorithm { for (const e of examples.slice(0, 4)) { prompt += `\nQUESTION: ${e.question}\n`; prompt += `REASONING:\n${e.reasoning.join("\n")}\n`; - prompt += `ANSWER: ${e.answer}\n`; } messages.push({ @@ -67,7 +64,7 @@ export class CoT extends Algorithm { }); messages.push({ role: "assistant", - content: `REASONING:\n${e.reasoning.join("\n")}\nANSWER: ${e.answer}`, + content: `REASONING:\n${e.reasoning.join("\n")}`, }); } @@ -79,18 +76,14 @@ export class CoT extends Algorithm { // console.log(prompt); // console.log(messages); - let 
maxTokens: number | undefined = undefined; - const datasetMaxTokens = this.dataset.maxTokens(); - if (datasetMaxTokens.reasoning && datasetMaxTokens.answer) { - maxTokens = datasetMaxTokens.reasoning + datasetMaxTokens.answer; - } - const query: ChatQuery = { provider: this.model.provider, model: this.model.model(), messages, temperature: this.TEMPERATURE, - maxTokens, + maxTokens: + this.dataset.maxTokens().reasoningStep * + this.dataset.maxTokens().maxStepCount, }; const c = await this.runCompletion(query); @@ -125,18 +118,7 @@ export class CoT extends Algorithm { console.log("+++++++++++++++++++++++++"); } - if (!c.content || !c.content.includes("REASONING:")) { - return await finish(test, c, query, false, ""); - } - - const content = c.content.split("REASONING:")[1].trim(); - - if (!content.includes("ANSWER:")) { - return await finish(test, c, query, false, ""); - } - - const reasoning = content.split("ANSWER:")[0].trim().split("\n"); - const answer = content.split("ANSWER:")[1].trim(); + const answer = this.dataset.parseAnswer(c.content); let check = false; try { @@ -146,7 +128,6 @@ export class CoT extends Algorithm { } if (debug) { - console.log(`REASONING: ${reasoning.join(" ")}`); console.log(`ANSWER: ${answer}`); console.log(`CHECK: ${check}`); console.log("-------------------------"); diff --git a/x/spolu/research/evals/lib/datasets.ts b/x/spolu/research/evals/lib/datasets.ts index 7f951f896542..aea27a852e43 100644 --- a/x/spolu/research/evals/lib/datasets.ts +++ b/x/spolu/research/evals/lib/datasets.ts @@ -24,14 +24,14 @@ export abstract class Dataset { abstract instructions(): string; abstract reasoningStepInstructions(): string; - abstract answerInstructions(): string; abstract maxTokens(): { - resaoningStep: number | null; - reasoning: number | null; - answer: number | null; + reasoningStep: number; + maxStepCount: number; }; + abstract parseAnswer(str: string): string; + abstract tests({ count }: { count: number }): Test[]; abstract examples({ diff 
--git a/x/spolu/research/evals/lib/datasets/MATH.ts b/x/spolu/research/evals/lib/datasets/MATH.ts index d9ba943bee9f..6d01c81d589f 100644 --- a/x/spolu/research/evals/lib/datasets/MATH.ts +++ b/x/spolu/research/evals/lib/datasets/MATH.ts @@ -36,6 +36,7 @@ export class MATH extends Dataset { d[e.type][e.level] = []; } d[e.type][e.level].push(e); + // console.log(e.reasoning.length); } return d; @@ -69,30 +70,39 @@ export class MATH extends Dataset { } instructions(): string { - return `Find a solution to the provided mathematical problem below.`; + return ( + "Find a solution to the provided mathematical problem." + + " The answer is a unique mathematical expression presented in LaTeX `\\boxed{}` directive. " + + " (example: `\\boxed{4}` or `\\boxed{3\\pi}`). Formatting instructions: " + + " fractions should be represented in the LaTeX form `\\frac{a}{b}` (not `\\frac12`)," + + " units should not be included," + + " square roots should be presented in the LaTeX form `\\sqrt{c}` (not `\\sqrt2`)," + + " all spaces and non critical parentheses or formatting should be stripped," + + " rational numbers should be presented with a leading `0`." + ); } reasoningStepInstructions(): string { - return `A reasoning step is one coherent step of mathematical reasoning it should held in one line.`; - } - - answerInstructions(): string { return ( - ` The answer is a unique mathematical expression presented in a LaTeX '\\boxed' directive` + - ` (eg: \\boxed{4} or \\boxed{3\\pi}). Formatting instructions:` + - ` fractions should be represented in the LaTeX form \\frac{a}{b} (not \\frac12),` + - ` units should not be included,` + - ` square roots should be presented in the LaTeX form \\sqrt{c} (not \\sqrt2),` + - ` all spaces and non critical parentheses or formatting should be stripped,` + - ` rational numbers should be presented with a leading 0.` + "A reasoning step is one coherent step of mathematical reasoning. It should hold in one line" + + " of at most 500 characters." 
+ + " If an answer is reached as part of the reasoning, it should be included" + + " in the reasoning step using the `\\boxed{}` directive." ); } + parseAnswer(str: string): string { + const boxed = str.match(/\\boxed{([^}]*)}/g); + if (!boxed) { + return ""; + } + return boxed[boxed.length - 1]; + } + maxTokens() { return { - resaoningStep: 512, - reasoning: 3584, - answer: 64, + reasoningStep: 256, + maxStepCount: 16, }; } diff --git a/x/spolu/research/evals/lib/datasets/game24.ts b/x/spolu/research/evals/lib/datasets/game24.ts index 5371a34963b9..bc62fe07cf03 100644 --- a/x/spolu/research/evals/lib/datasets/game24.ts +++ b/x/spolu/research/evals/lib/datasets/game24.ts @@ -80,7 +80,7 @@ export class Game24 extends Dataset { if (result !== 24) { throw new Error("Unexpected non 24 result"); } - const r = `${a}${op}${b}=${result}`; + const r = `${a}${op}${b}=${result}, \\boxed{${solution}}`; reasoning.push(r); } } @@ -116,33 +116,39 @@ export class Game24 extends Dataset { instructions(): string { return ( - `Given a set of 4 input numbers, find a mathematical expression using each number` + - ` exactly once that symbolically evaluates to 24 (Game of 24).` + - ` The available operators are [+,-,*,/]` + - ` (the division operator / is the symbolic division (eg: 2/(3-5/2) = 2/(1/2) = 4)).` + "Given a set of 4 input numbers, find a mathematical expression using each number" + + " exactly once that symbolically evaluates to 24 (Game of 24)." + + " The available operators are [+,-,*,/]" + + " (the division operator / is the symbolic division (`2/(3-5/2) = 2/(1/2) = 4`))." 
); } reasoningStepInstructions(): string { return ( - `A reasoning step is one operation involving 2 numbers followed by the numbers left to form` + - ` 24 after that operation (eg: '10*7=70, left: 70 2 11').` + - ` There is always exactly 3 reasoning steps per question.` + "A reasoning step is one operation involving 2 numbers followed by the numbers left to form" + + " 24 after that operation, separated by a comma (example: `10*7=70, left: 70 2 11`)." + + " There is always exactly 3 reasoning steps per question in Game of 24." + + " The last step should present the last operation and the solution expression" + + " using the `\\boxed{}` directive, separated by a comma" + + " (example: `35-11=24, \\boxed{(6+1)*5-11}`)." ); } - answerInstructions(): string { - return ( - `The answer should be a valid solution expression without space using each number` + - ` exactly once (eg: '(6+1)*5-11' or '(9-1)*9/3').` - ); + parseAnswer(str: string): string { + const boxed = str.match(/\\boxed{([^}]*)}/g); + if (!boxed) { + return ""; + } + // remove the \boxed{} directive + const answer = boxed.map((s) => s.slice(7, s.length - 1)); + // return the last one + return answer[answer.length - 1]; } maxTokens() { return { - resaoningStep: 32, - reasoning: 32 * 3, - answer: 16, + reasoningStep: 32, + maxStepCount: 3, }; } diff --git a/x/spolu/research/evals/lib/models/openai.ts b/x/spolu/research/evals/lib/models/openai.ts index f734bc058ef6..3bd4ef2b6fd3 100644 --- a/x/spolu/research/evals/lib/models/openai.ts +++ b/x/spolu/research/evals/lib/models/openai.ts @@ -28,8 +28,10 @@ export class OpenAIModel extends Model { messages: query.messages, max_tokens: query.maxTokens, temperature: query.temperature, + // logprobs: true, }); + // console.log(JSON.stringify(completion)); const m = completion.choices[0].message; if (m.content === null) {