add support for Anthropic Claude, fix several bugs
JamesZhutheThird committed Mar 12, 2024
1 parent aa29f8e commit db67800
Showing 6 changed files with 174 additions and 26 deletions.
8 changes: 8 additions & 0 deletions eval/args.py
@@ -21,6 +21,14 @@
"evaluator": "GPTEvaluator",
"split_sys": True,
},
"claude": {
"avail_model": ["claude-3-opus-20240229", "claude-3-sonnet-20240229"],
"model_type": "api",
"support_input": [0, 1,2,3],
"executor": "claude",
"evaluator": "ClaudeEvaluator",
"split_sys": True,
},
"geminivision": {
"avail_model": ["gemini-pro-vision", ],
"model_type": "api",
2 changes: 1 addition & 1 deletion eval/metrics.py
@@ -440,4 +440,4 @@ def main(args):
 
 if __name__ == "__main__":
     args = parse_args_for_score()
-    main(args)
\ No newline at end of file
+    main(args)
135 changes: 135 additions & 0 deletions eval/models/claude.py
@@ -0,0 +1,135 @@
"""Anthropic Claude Evaluator"""

import httpx
from anthropic import Anthropic
import requests
import json
from tqdm import tqdm
import random
import time
import pdb
from utils import encode_image_base64
import re


class ClaudeEvaluator:
def __init__(self, api_key, model='claude-3-opus-20240229', api_url=None, max_tokens=200, temperature=0.1, top_p=1, presence_penalty=0.0, frequency_penalty=0.0,use_client=False):
self.use_client =use_client
self.api_key = api_key
self.api_url = api_url
if self.use_client:
self.client = Anthropic(api_key=self.api_key ,base_url=self.api_url) # http_client=httpx.Client(proxies=api_url, transport=httpx.HTTPTransport(local_address="0.0.0.0"))
else:
self.header = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
self.post_dict = {
"model": model,
"system": None,
"messages": None,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
}
self.model = model

def prepare_inputs(self, question):
image_list = question.get("image_list")
prompted_content = question["prompted_content"]
if image_list:
match = re.findall("\[IMAGE_[0-9]+]", prompted_content)
assert len(match) == len(image_list)
content = []
for i, img_sub in enumerate(match):
img_token_start = prompted_content.index(img_sub)
prompted_content_split = prompted_content[:img_token_start].strip() + f" Image {i + 1}:"
content.append({
"type": "text",
"text": prompted_content_split
})
prompted_content = prompted_content[img_token_start + len(img_sub):]

base64_image = encode_image_base64(image_list[i]) # max_size = 512
content.append({
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": base64_image,
},
})
content.append({
"type": "text",
"text": prompted_content
})

else:
content = [{
"type": "text",
"text": prompted_content
}]

return content

def generate_response(self, question):
content = self.prepare_inputs(question)
messages = [{
"role": "user",
"content": content
}]
system_message = question["prompted_system_content"]
if not self.use_client:
self.post_dict["system"] = system_message
self.post_dict["messages"] = messages

response = ""
i = 0
MAX_RETRY = 100

while i < MAX_RETRY:
try:
if self.use_client:
response_ = self.client.messages.create(model=self.model, system=system_message, messages=messages)
response = response_ # THIS HAS NOT BEEN VERIFIED
else:
response_ = requests.post(self.api_url, json=self.post_dict, headers=self.header)
response_ = response_.json()
response = response_["choices"][0]["message"]["content"]
except KeyboardInterrupt:
raise Exception("Terminated by user.")
except Exception as e:
print(e)
i += 1
time.sleep(1 + i / 10)
if i == 1 or i % 10 == 0:
error_type = response_.get("error", {}).get("type", "")
if error_type == 'upstream_error':
response = ""
feedback = error_type
return response, [system_message, messages], feedback
print(f"Retry {i} times...")
else:
break
if i >= MAX_RETRY:
raise Exception("Failed to generate response.")
return response, [system_message, messages], None

def generate_answer(self, question):
response, message_, feedback = self.generate_response(question)
message = {
"system": message_[0],
"messages": message_[1]
}
for i in range(len(message["messages"][0]["content"])):
if message["messages"][0]["content"][i]["type"] == "image":
message["messages"][0]["content"][i]["source"]["data"] = message["messages"][0]["content"][i]["source"]["data"][:32] + "..."
question["input_message"] = message
question["prediction"] = response
if feedback:
question["feedback"] = feedback
question.pop("prompted_content")
question.pop("prompted_system_content")
return question
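
For orientation, a minimal sketch of how the new evaluator could be driven, assuming eval/ is on the import path; the api_key and question values are illustrative, with field names taken from what the class reads:

# Hypothetical usage sketch for ClaudeEvaluator (not part of this commit).
from models.claude import ClaudeEvaluator

evaluator = ClaudeEvaluator(
    api_key="sk-ant-...",            # placeholder credential
    model="claude-3-opus-20240229",
    use_client=True,                 # use the official Anthropic SDK client
)

question = {
    "prompted_system_content": "You are a careful multimodal exam solver.",
    "prompted_content": "Look at [IMAGE_1] and describe the highlighted region.",
    "image_list": ["data/sample.png"],  # one image per [IMAGE_n] placeholder
}

result = evaluator.generate_answer(question)
print(result["prediction"])
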
10 changes: 5 additions & 5 deletions eval/models/gemini.py
@@ -38,11 +38,10 @@ def prepare_inputs(self, question):
         prompt = question["prompted_system_content"].strip() + "\n" + question["prompted_content"].strip()
         content = [prompt,]
 
-        image_list = question.get("question_image_list")
+        image_list = question.get("image_list")
         if image_list:
             for image_path in image_list:
-                max_size = 512
-                image = encode_image_PIL(image_path, max_size=max_size)
+                image = encode_image_PIL(image_path)  # max_size = 512
                 content.append(image)
         return content

@@ -57,7 +56,8 @@ def generate_response(self, question):
                 if len(content) > 1:
                     response_ = self.model_with_vision.generate_content(content)
                     message = [content[0], ]
-                    message.append(f"image no.{i+1}" for i in range(len(content) - 1))
+                    for i in range(len(content) - 1):
+                        message.append(str(content[i+1]))
                 else:
                     response_ = self.model_without_vision.generate_content(content)
                     message = content
@@ -70,7 +70,7 @@ def generate_response(self, question):
                 time.sleep(1 + i / 10)
                 if i == 1 or i % 10 == 0:
                     if str(e).endswith("if the prompt was blocked.") or str(e).endswith("lookup instead."):
-                        response = "Gemini refused to answer this question."
+                        response = ""
                         feedback = str(response_.prompt_feedback)
                         return response, message, feedback
                     print(f"Retry {i} times...")
16 changes: 9 additions & 7 deletions eval/models/gpt.py
@@ -29,7 +29,7 @@ def __init__(self, api_key, model='gpt-3.5-turbo', api_url="https://api.openai.c
         }
 
     def prepare_inputs(self, question):
-        image_list = question.get("question_image_list")
+        image_list = question.get("image_list")
         messages = [{
             "role": "system",
             "content": question["prompted_system_content"]
@@ -43,14 +43,12 @@ def prepare_inputs(self, question):
                     "text": question["prompted_content"]
                 },]}
             for image_path in image_list:
-                max_size = 512
-                base64_image, origin_pixels = encode_image_base64(image_path, max_size=max_size)
-                detail = "high" if origin_pixels > max_size * max_size / 2 else "low"
+                base64_image = encode_image_base64(image_path)  # max_size = 512
                 user_message["content"].append({
                     "type": "image_url",
                     "image_url": {
                         "url": f"data:image/png;base64,{base64_image}",
-                        "detail": detail,  # "auto"
+                        "detail": "auto"
                     },},)
             messages.append(user_message)
         else:
@@ -87,10 +85,10 @@ def generate_response(self, question):
                 i += 1
                 time.sleep(1 + i / 10)
                 if i == 1 or i % 10 == 0:
-                    if error.startswith("This model's maximum context length"):
+                    if error.startswith("This model's maximum context length") or error.startswith("Your input image may contain"):
                         response = ""
                         feedback = error
-                        return response, message,feedback
+                        return response, message, feedback
                     print(f"Retry {i} times...")
             else:
                 break
@@ -100,6 +98,10 @@ def generate_response(self, question):
 
     def generate_answer(self, question):
         response, message, feedback = self.generate_response(question)
+        if not isinstance(message[1]["content"], str):
+            for i in range(len(message[1]["content"])):
+                if message[1]["content"][i]["type"] == "image_url":
+                    message[1]["content"][i]["image_url"]["url"] = message[1]["content"][i]["image_url"]["url"][:64] + "..."
         question["input_message"] = message
         question["prediction"] = response
         if feedback:
29 changes: 16 additions & 13 deletions eval/utils.py
@@ -6,6 +6,7 @@
 from typing import Tuple, List
 from PIL import Image
 import pdb
+import io
 
 import base64
@@ -50,24 +51,26 @@ def open_image(image_path, force_blank_return=True):
     return image
 
 
-def encode_image_base64(image_path,max_size=0):
+def encode_image_base64(image_path, max_size=-1):
     with open(image_path, "rb") as image_file:
         if max_size > 0:
             image = Image.open(image_file)
-            size = image.size
             image.thumbnail((max_size, max_size))
-            image_file = image
-    return base64.b64encode(image_file.read()).decode('utf-8'),size.width*size.height
-
-
-def encode_image_PIL(image_path,max_size=0):
-    with open(image_path, "rb") as image_file:
-        if max_size > 0:
-            image = Image.open(image_file)
-            image.thumbnail((max_size, max_size))
-            image_file = image
-        return image_file.read()
+            output_buffer = io.BytesIO()
+            image.save(output_buffer, format='png')
+            image_bytes = output_buffer.getvalue()
+        else:
+            image_bytes = image_file.read()
+    return base64.b64encode(image_bytes).decode('utf-8')
+
+
+def encode_image_PIL(image_path, max_size=-1):
+    if max_size > 0:
+        image = Image.open(image_path)
+        image.thumbnail((max_size, max_size))
+    else:
+        image = Image.open(image_path)
+    return image

 def infer_lang_from_question(question):
     question_type = question["question_type"]
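
A quick sketch of the reworked helpers' contract (the file path is hypothetical): encode_image_base64 now returns only the base64 string rather than the old (string, pixel-count) tuple, and encode_image_PIL returns a PIL.Image instead of raw bytes:

# Illustrative calls against the updated helpers (hypothetical file path).
from utils import encode_image_base64, encode_image_PIL

b64_full = encode_image_base64("sample.png")                  # max_size=-1: encode the file as-is
b64_thumb = encode_image_base64("sample.png", max_size=512)   # downscale to fit 512x512, re-encode as PNG
image = encode_image_PIL("sample.png", max_size=512)          # returns a PIL.Image, as the Gemini path expects
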
