Skip to content

Commit

Permalink
Refactor common explainPatch code into utility func
Browse files Browse the repository at this point in the history
  • Loading branch information
diracdeltas committed May 9, 2024
1 parent 26f55f9 commit c9d7e2f
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 154 deletions.
62 changes: 11 additions & 51 deletions src/anthropicExplainPatch.js
Original file line number Diff line number Diff line change
@@ -1,36 +1,18 @@
import Anthropic from '@anthropic-ai/sdk'
import { countTokens } from '@anthropic-ai/tokenizer'
import { SYSTEM_PROMPT, explainPatchHelper } from './utils.js'

/* eslint-disable camelcase */
export default async function explainPatch ({
apiKey, patchBody, owner, repo,
models = ['claude-3-opus-20240229'],
system = `
You are an expert software engineer reviewing a pull request on Github. Lines that start with "+" have been added, lines that start with "-" have been deleted. Use markdown for formatting your review.
Desired format:
### Description
<description_of_PR> // How does this PR change the codebase? What is the motivation for this change?
### Changes
<list_of_changes> // Describe the main changes in the PR, organizing them by filename
### Security Hotspots
<list_of_security_hotspots> // Describe locations for possible vulnerabilities in the change, order by risk
\n`,
system = SYSTEM_PROMPT,
max_tokens = 2048,
temperature = 1,
top_p = 1,
amplification = 2,
debug = false
}) {
const userPrompt = `Repository: https://github.com/${owner}/${repo}\n\nThis is the PR diff\n\`\`\`\n${patchBody}\n\`\`\``
const realModels = Array.isArray(models) ? models : models.split(' ')

if (debug) {
console.log(`user_prompt:\n\n${userPrompt}`)
}

const pLen = countTokens(patchBody)
if (pLen === 0) { throw new Error('The patch is empty, cannot summarize!') }
if (pLen < amplification * max_tokens) {
Expand All @@ -39,15 +21,10 @@ Desired format:

const anthropic = new Anthropic({ apiKey })

let model = null

// iterate over the models, and find one that works, or throw the last error
// catch all the original errors in between, and throw the last one
for (let i = 0; i < realModels.length; i++) {
try {
model = realModels[i]

var aiResponse = await anthropic.messages.create({
return await explainPatchHelper(
patchBody, owner, repo, models, debug,
async (userPrompt, model) => {
const aiResponse = await anthropic.messages.create({
max_tokens,
temperature,
model,
Expand All @@ -60,29 +37,12 @@ Desired format:
}
]
})

break
} catch (e) {
if (i + 1 === realModels.length) {
// last model
throw e
const response = aiResponse.content
if (debug) {
console.log(response)
}

console.log(e)
continue
return response[0].text
}
}

let response = aiResponse.content

if (debug) {
console.log(response)
}

response = response[0].text.replaceAll('### Changes', '<details>\n<summary><i>Changes</i></summary>\n\n### Changes')
response = response.replaceAll('### Security Hotspots', '</details>\n\n### Security Hotspots')
response += `\n\n<!-- Generated by ${model} -->`

return response
)
}
/* eslint-enable camelcase */
64 changes: 12 additions & 52 deletions src/bedrockExplainPatch.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
BedrockRuntimeClient,
InvokeModelCommand
} from '@aws-sdk/client-bedrock-runtime'
import { SYSTEM_PROMPT, explainPatchHelper } from './utils.js'

const COUNT_TOKENS_HASHFUN = {
'amazon.titan-text-express-v1': null,
Expand Down Expand Up @@ -45,51 +46,26 @@ const countTokens = (text, modelId) => {
export default async function explainPatch ({
patchBody, owner, repo,
models = ['anthropic.claude-3-opus-20240229-v1:0'],
system = `
You are an expert software engineer reviewing a pull request on Github. Lines that start with "+" have been added, lines that start with "-" have been deleted. Use markdown for formatting your review.
Desired format:
### Description
<description_of_PR> // How does this PR change the codebase? What is the motivation for this change?
### Changes
<list_of_changes> // Describe the main changes in the PR, organizing them by filename
### Security Hotspots
<list_of_security_hotspots> // Describe locations for possible vulnerabilities in the change, order by risk
\n`,
system = SYSTEM_PROMPT,
max_tokens = 2048,
temperature = 1,
top_p = 1,
amplification = 2,
region = 'us-east-1',
debug = false
}) {
const userPrompt = `Repository: https://github.com/${owner}/${repo}\n\nThis is the PR diff\n\`\`\`\n${patchBody}\n\`\`\``
const realModels = Array.isArray(models) ? models : models.split(' ')

if (debug) {
console.log(`user_prompt:\n\n${userPrompt}`)
}

const client = new BedrockRuntimeClient({ region })

let modelId = null
let response = null

// iterate over the models, and find one that works, or throw the last error
// catch all the original errors in between, and throw the last one
for (let i = 0; i < realModels.length; i++) {
try {
modelId = realModels[i]

const pLen = countTokens(patchBody, modelId)
return await explainPatchHelper(
patchBody, owner, repo, models, debug,
async (userPrompt, model) => {
const pLen = countTokens(patchBody, model)

if (pLen === 0) { throw new Error('The patch is empty, cannot summarize!') }
if (pLen < amplification * max_tokens) { throw new Error('The patch is trivial, no need for a summarization') }

const command = new InvokeModelCommand({
modelId,
modelId: model,
contentType: 'application/json',
accept: 'application/json',
body: JSON.stringify({
Expand All @@ -112,28 +88,12 @@ Desired format:
if (debug) {
console.log(`response:\n\n${textResponse}`)
}
response = JSON.parse(textResponse)

break
} catch (e) {
if (i + 1 === realModels.length) {
// last model
throw e
const response = JSON.parse(textResponse)
if (debug) {
console.log(response)
}

console.log(e)
continue
return response.content[0].text
}
}

if (debug) {
console.log(response)
}

response = response.content[0].text.replaceAll('### Changes', '<details>\n<summary><i>Changes</i></summary>\n\n### Changes')
response = response.replaceAll('### Security Hotspots', '</details>\n\n### Security Hotspots')
response += `\n\n<!-- Generated by ${modelId} -->`

return response
)
}
/* eslint-enable camelcase */
66 changes: 15 additions & 51 deletions src/openaiExplainPatch.js
Original file line number Diff line number Diff line change
@@ -1,23 +1,12 @@
import OpenAI from 'openai'
/* eslint-disable camelcase */
import { encoding_for_model } from 'tiktoken'
import { SYSTEM_PROMPT, explainPatchHelper } from './utils.js'

export default async function explainPatch ({
apiKey, patchBody, owner, repo,
models = ['gpt-4-0125-preview', 'gpt-3.5-turbo-0125'],
system = `
You are an expert software engineer reviewing a pull request on Github. Lines that start with "+" have been added, lines that start with "-" have been deleted. Use markdown for formatting your review.
Desired format:
### Description
<description_of_PR> // How does this PR change the codebase? What is the motivation for this change?
### Changes
<list_of_changes> // Describe the main changes in the PR, organizing them by filename
### Security Hotspots
<list_of_security_hotspots> // Describe locations for possible vulnerabilities in the change, order by risk
\n`,
system = SYSTEM_PROMPT,
max_tokens = 2048,
temperature = 1,
top_p = 1,
Expand All @@ -27,26 +16,18 @@ Desired format:
debug = false
}) {
const openai = new OpenAI({ apiKey })
const realModels = Array.isArray(models) ? models : models.split(' ')
const userPrompt = `Repository: https://github.com/${owner}/${repo}\n\nThis is the PR diff\n\`\`\`\n${patchBody}\n\`\`\``

if (debug) {
console.log(`user_prompt:\n\n${userPrompt}`)
}
return await explainPatchHelper(
patchBody, owner, repo, models, debug,
async (userPrompt, model) => {
const enc = encoding_for_model(model)
const pLen = enc.encode(patchBody).length

let m = null
let enc
let pLen

for (let i = 0; i < realModels.length; i++) {
try {
m = realModels[i]
enc = encoding_for_model(m)
pLen = enc.encode(patchBody).length
if (pLen === 0) { throw new Error('The patch is empty, cannot summarize!') }
if (pLen < amplification * max_tokens) { throw new Error('The patch is trivial, no need for a summarization') }
var aiResponse = await openai.chat.completions.create({
model: m,

const aiResponse = await openai.chat.completions.create({
model,
messages: [
{
role: 'system',
Expand All @@ -63,29 +44,12 @@ Desired format:
frequency_penalty,
presence_penalty
})
break
} catch (e) {
if (i + 1 === realModels.length) {
// last model
throw e
if (debug) {
console.log(aiResponse)
console.log(aiResponse.choices[0].message)
}

console.log(e)
continue
return aiResponse.choices[0].message.content
}
}

if (debug) {
console.log(aiResponse)
console.log(aiResponse.choices[0].message)
}

let response = aiResponse.choices[0].message.content

response = response.replaceAll('### Changes', '<details>\n<summary><i>Changes</i></summary>\n\n### Changes')
response = response.replaceAll('### Security Hotspots', '</details>\n\n### Security Hotspots')
response += `\n\n<!-- Generated by ${m} -->`

return response
)
}
/* eslint-enable camelcase */
47 changes: 47 additions & 0 deletions src/utils.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
export const SYSTEM_PROMPT = `
You are an expert software engineer reviewing a pull request on Github. Lines that start with "+" have been added, lines that start with "-" have been deleted. Use markdown for formatting your review.
Desired format:
### Description
<description_of_PR> // How does this PR change the codebase? What is the motivation for this change?
### Changes
<list_of_changes> // Describe the main changes in the PR, organizing them by filename
### Security Hotspots
<list_of_security_hotspots> // Describe locations for possible vulnerabilities in the change, order by risk
\n`

export async function explainPatchHelper (patchBody, owner, repo, models, debug, getResponse) {
models = Array.isArray(models) ? models : models.split(' ')

let model = null
let response = null

const userPrompt = `Repository: https://github.com/${owner}/${repo}\n\nThis is the PR diff\n\`\`\`\n${patchBody}\n\`\`\``

if (debug) {
console.log(`user_prompt:\n\n${userPrompt}`)
}

for (let i = 0; i < this.models.length; i++) {
try {
model = this.models[i]
response = await this.getResponse(userPrompt, model)
break
} catch (e) {
if (i + 1 === this.models.length) {
// last model
throw e
}

console.log(e)
continue
}
}

response = response.replaceAll('### Changes', '<details>\n<summary><i>Changes</i></summary>\n\n### Changes')
response = response.replaceAll('### Security Hotspots', '</details>\n\n### Security Hotspots')
response += `\n\n<!-- Generated by ${model} -->`
return response
}

0 comments on commit c9d7e2f

Please sign in to comment.