diff --git a/src/anthropicExplainPatch.js b/src/anthropicExplainPatch.js index 2c11a64..0a582f7 100644 --- a/src/anthropicExplainPatch.js +++ b/src/anthropicExplainPatch.js @@ -1,36 +1,18 @@ import Anthropic from '@anthropic-ai/sdk' import { countTokens } from '@anthropic-ai/tokenizer' +import { SYSTEM_PROMPT, explainPatchHelper } from './utils.js' /* eslint-disable camelcase */ export default async function explainPatch ({ apiKey, patchBody, owner, repo, models = ['claude-3-opus-20240229'], - system = ` -You are an expert software engineer reviewing a pull request on Github. Lines that start with "+" have been added, lines that start with "-" have been deleted. Use markdown for formatting your review. - -Desired format: -### Description - // How does this PR change the codebase? What is the motivation for this change? - -### Changes - // Describe the main changes in the PR, organizing them by filename - -### Security Hotspots - // Describe locations for possible vulnerabilities in the change, order by risk -\n`, + system = SYSTEM_PROMPT, max_tokens = 2048, temperature = 1, top_p = 1, amplification = 2, debug = false }) { - const userPrompt = `Repository: https://github.com/${owner}/${repo}\n\nThis is the PR diff\n\`\`\`\n${patchBody}\n\`\`\`` - const realModels = Array.isArray(models) ? models : models.split(' ') - - if (debug) { - console.log(`user_prompt:\n\n${userPrompt}`) - } - const pLen = countTokens(patchBody) if (pLen === 0) { throw new Error('The patch is empty, cannot summarize!') } if (pLen < amplification * max_tokens) { @@ -39,15 +21,10 @@ Desired format: const anthropic = new Anthropic({ apiKey }) - let model = null - - // iterate over the models, and find one that works, or throw the last error - // catch all the original errors in between, and throw the last one - for (let i = 0; i < realModels.length; i++) { - try { - model = realModels[i] - - var aiResponse = await anthropic.messages.create({ + return await explainPatchHelper( + patchBody, owner, repo, models, debug, + async (userPrompt, model) => { + const aiResponse = await anthropic.messages.create({ max_tokens, temperature, model, @@ -60,29 +37,12 @@ Desired format: } ] }) - - break - } catch (e) { - if (i + 1 === realModels.length) { - // last model - throw e + const response = aiResponse.content + if (debug) { + console.log(response) } - - console.log(e) - continue + return response[0].text } - } - - let response = aiResponse.content - - if (debug) { - console.log(response) - } - - response = response[0].text.replaceAll('### Changes', '
\nChanges\n\n### Changes') - response = response.replaceAll('### Security Hotspots', '
\n\n### Security Hotspots') - response += `\n\n` - - return response + ) } /* eslint-enable camelcase */ diff --git a/src/bedrockExplainPatch.js b/src/bedrockExplainPatch.js index 2647467..7389231 100644 --- a/src/bedrockExplainPatch.js +++ b/src/bedrockExplainPatch.js @@ -3,6 +3,7 @@ import { BedrockRuntimeClient, InvokeModelCommand } from '@aws-sdk/client-bedrock-runtime' +import { SYSTEM_PROMPT, explainPatchHelper } from './utils.js' const COUNT_TOKENS_HASHFUN = { 'amazon.titan-text-express-v1': null, @@ -45,19 +46,7 @@ const countTokens = (text, modelId) => { export default async function explainPatch ({ patchBody, owner, repo, models = ['anthropic.claude-3-opus-20240229-v1:0'], - system = ` -You are an expert software engineer reviewing a pull request on Github. Lines that start with "+" have been added, lines that start with "-" have been deleted. Use markdown for formatting your review. - -Desired format: -### Description - // How does this PR change the codebase? What is the motivation for this change? - -### Changes - // Describe the main changes in the PR, organizing them by filename - -### Security Hotspots - // Describe locations for possible vulnerabilities in the change, order by risk -\n`, + system = SYSTEM_PROMPT, max_tokens = 2048, temperature = 1, top_p = 1, @@ -65,31 +54,18 @@ Desired format: region = 'us-east-1', debug = false }) { - const userPrompt = `Repository: https://github.com/${owner}/${repo}\n\nThis is the PR diff\n\`\`\`\n${patchBody}\n\`\`\`` - const realModels = Array.isArray(models) ? models : models.split(' ') - - if (debug) { - console.log(`user_prompt:\n\n${userPrompt}`) - } - const client = new BedrockRuntimeClient({ region }) - let modelId = null - let response = null - - // iterate over the models, and find one that works, or throw the last error - // catch all the original errors in between, and throw the last one - for (let i = 0; i < realModels.length; i++) { - try { - modelId = realModels[i] - - const pLen = countTokens(patchBody, modelId) + return await explainPatchHelper( + patchBody, owner, repo, models, debug, + async (userPrompt, model) => { + const pLen = countTokens(patchBody, model) if (pLen === 0) { throw new Error('The patch is empty, cannot summarize!') } if (pLen < amplification * max_tokens) { throw new Error('The patch is trivial, no need for a summarization') } const command = new InvokeModelCommand({ - modelId, + modelId: model, contentType: 'application/json', accept: 'application/json', body: JSON.stringify({ @@ -112,28 +88,12 @@ Desired format: if (debug) { console.log(`response:\n\n${textResponse}`) } - response = JSON.parse(textResponse) - - break - } catch (e) { - if (i + 1 === realModels.length) { - // last model - throw e + const response = JSON.parse(textResponse) + if (debug) { + console.log(response) } - - console.log(e) - continue + return response.content[0].text } - } - - if (debug) { - console.log(response) - } - - response = response.content[0].text.replaceAll('### Changes', '
\nChanges\n\n### Changes') - response = response.replaceAll('### Security Hotspots', '
\n\n### Security Hotspots') - response += `\n\n` - - return response + ) } /* eslint-enable camelcase */ diff --git a/src/openaiExplainPatch.js b/src/openaiExplainPatch.js index 408796f..6d0d853 100644 --- a/src/openaiExplainPatch.js +++ b/src/openaiExplainPatch.js @@ -1,23 +1,12 @@ import OpenAI from 'openai' /* eslint-disable camelcase */ import { encoding_for_model } from 'tiktoken' +import { SYSTEM_PROMPT, explainPatchHelper } from './utils.js' export default async function explainPatch ({ apiKey, patchBody, owner, repo, models = ['gpt-4-0125-preview', 'gpt-3.5-turbo-0125'], - system = ` -You are an expert software engineer reviewing a pull request on Github. Lines that start with "+" have been added, lines that start with "-" have been deleted. Use markdown for formatting your review. - -Desired format: -### Description - // How does this PR change the codebase? What is the motivation for this change? - -### Changes - // Describe the main changes in the PR, organizing them by filename - -### Security Hotspots - // Describe locations for possible vulnerabilities in the change, order by risk -\n`, + system = SYSTEM_PROMPT, max_tokens = 2048, temperature = 1, top_p = 1, @@ -27,26 +16,18 @@ Desired format: debug = false }) { const openai = new OpenAI({ apiKey }) - const realModels = Array.isArray(models) ? models : models.split(' ') - const userPrompt = `Repository: https://github.com/${owner}/${repo}\n\nThis is the PR diff\n\`\`\`\n${patchBody}\n\`\`\`` - if (debug) { - console.log(`user_prompt:\n\n${userPrompt}`) - } + return await explainPatchHelper( + patchBody, owner, repo, models, debug, + async (userPrompt, model) => { + const enc = encoding_for_model(model) + const pLen = enc.encode(patchBody).length - let m = null - let enc - let pLen - - for (let i = 0; i < realModels.length; i++) { - try { - m = realModels[i] - enc = encoding_for_model(m) - pLen = enc.encode(patchBody).length if (pLen === 0) { throw new Error('The patch is empty, cannot summarize!') } if (pLen < amplification * max_tokens) { throw new Error('The patch is trivial, no need for a summarization') } - var aiResponse = await openai.chat.completions.create({ - model: m, + + const aiResponse = await openai.chat.completions.create({ + model, messages: [ { role: 'system', @@ -63,29 +44,12 @@ Desired format: frequency_penalty, presence_penalty }) - break - } catch (e) { - if (i + 1 === realModels.length) { - // last model - throw e + if (debug) { + console.log(aiResponse) + console.log(aiResponse.choices[0].message) } - - console.log(e) - continue + return aiResponse.choices[0].message.content } - } - - if (debug) { - console.log(aiResponse) - console.log(aiResponse.choices[0].message) - } - - let response = aiResponse.choices[0].message.content - - response = response.replaceAll('### Changes', '
\nChanges\n\n### Changes') - response = response.replaceAll('### Security Hotspots', '
\n\n### Security Hotspots') - response += `\n\n` - - return response + ) } /* eslint-enable camelcase */ diff --git a/src/utils.js b/src/utils.js new file mode 100644 index 0000000..3d9a8f8 --- /dev/null +++ b/src/utils.js @@ -0,0 +1,47 @@ +export const SYSTEM_PROMPT = ` +You are an expert software engineer reviewing a pull request on Github. Lines that start with "+" have been added, lines that start with "-" have been deleted. Use markdown for formatting your review. + +Desired format: +### Description + // How does this PR change the codebase? What is the motivation for this change? + +### Changes + // Describe the main changes in the PR, organizing them by filename + +### Security Hotspots + // Describe locations for possible vulnerabilities in the change, order by risk +\n` + +export async function explainPatchHelper (patchBody, owner, repo, models, debug, getResponse) { + models = Array.isArray(models) ? models : models.split(' ') + + let model = null + let response = null + + const userPrompt = `Repository: https://github.com/${owner}/${repo}\n\nThis is the PR diff\n\`\`\`\n${patchBody}\n\`\`\`` + + if (debug) { + console.log(`user_prompt:\n\n${userPrompt}`) + } + + for (let i = 0; i < this.models.length; i++) { + try { + model = this.models[i] + response = await this.getResponse(userPrompt, model) + break + } catch (e) { + if (i + 1 === this.models.length) { + // last model + throw e + } + + console.log(e) + continue + } + } + + response = response.replaceAll('### Changes', '
\nChanges\n\n### Changes') + response = response.replaceAll('### Security Hotspots', '
\n\n### Security Hotspots') + response += `\n\n` + return response +}