diff --git a/eval/args.py b/eval/args.py
index 6f1c1cb..29803c1 100644
--- a/eval/args.py
+++ b/eval/args.py
@@ -21,6 +21,14 @@
         "evaluator": "GPTEvaluator",
         "split_sys": True,
     },
+    "claude": {
+        "avail_model": ["claude-3-opus-20240229", "claude-3-sonnet-20240229"],
+        "model_type": "api",
+        "support_input": [0, 1,2,3],
+        "executor": "claude",
+        "evaluator": "ClaudeEvaluator",
+        "split_sys": True,
+    },
     "geminivision": {
         "avail_model": ["gemini-pro-vision", ],
         "model_type": "api",
diff --git a/eval/metrics.py b/eval/metrics.py
index 44e3961..8787be1 100644
--- a/eval/metrics.py
+++ b/eval/metrics.py
@@ -440,4 +440,4 @@ def main(args):
 
 if __name__ == "__main__":
     args = parse_args_for_score()
-    main(args)
+    main(args)
\ No newline at end of file
diff --git a/eval/models/claude.py b/eval/models/claude.py
new file mode 100644
index 0000000..60212fb
--- /dev/null
+++ b/eval/models/claude.py
@@ -0,0 +1,135 @@
+"""Anthropic Claude Evaluator"""
+
+import httpx
+from anthropic import Anthropic
+import requests
+import json
+from tqdm import tqdm
+import random
+import time
+import pdb
+from utils import encode_image_base64
+import re
+
+
+class ClaudeEvaluator:
+    def __init__(self, api_key, model='claude-3-opus-20240229', api_url=None, max_tokens=200, temperature=0.1, top_p=1, presence_penalty=0.0, frequency_penalty=0.0,use_client=False):
+        self.use_client =use_client
+        self.api_key = api_key
+        self.api_url = api_url
+        if self.use_client:
+            self.client = Anthropic(api_key=self.api_key ,base_url=self.api_url) # http_client=httpx.Client(proxies=api_url, transport=httpx.HTTPTransport(local_address="0.0.0.0"))
+        else:
+            # non-client path posts to an OpenAI-compatible proxy (Bearer auth,
+            # responses parsed from "choices") -- TODO confirm endpoint contract
+            self.header = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.api_key}",
+            }
+            self.post_dict = {
+                "model": model,
+                "system": None,
+                "messages": None,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_p": top_p,
+                "presence_penalty": presence_penalty,
+                "frequency_penalty": frequency_penalty,
+            }
+        self.model = model
+
+    def prepare_inputs(self, question):
+        image_list = question.get("image_list")
+        prompted_content = question["prompted_content"]
+        if image_list:
+            match = re.findall(r"\[IMAGE_[0-9]+]", prompted_content)
+            assert len(match) == len(image_list)
+            content = []
+            for i, img_sub in enumerate(match):
+                img_token_start = prompted_content.index(img_sub)
+                prompted_content_split = prompted_content[:img_token_start].strip() + f" Image {i + 1}:"
+                content.append({
+                    "type": "text",
+                    "text": prompted_content_split
+                })
+                prompted_content = prompted_content[img_token_start + len(img_sub):]
+
+                base64_image = encode_image_base64(image_list[i]) # max_size = 512
+                content.append({
+                    "type": "image",
+                    "source": {
+                        "type": "base64",
+                        "media_type": "image/png",
+                        "data": base64_image,
+                    },
+                })
+            content.append({
+                "type": "text",
+                "text": prompted_content
+            })
+
+        else:
+            content = [{
+                "type": "text",
+                "text": prompted_content
+            }]
+
+        return content
+
+    def generate_response(self, question):
+        content = self.prepare_inputs(question)
+        messages = [{
+            "role": "user",
+            "content": content
+        }]
+        system_message = question["prompted_system_content"]
+        if not self.use_client:
+            self.post_dict["system"] = system_message
+            self.post_dict["messages"] = messages
+
+        response = ""
+        i = 0
+        MAX_RETRY = 100
+        response_ = {}  # pre-bind: the except handler below inspects it
+        while i < MAX_RETRY:
+            try:
+                if self.use_client:
+                    response_ = self.client.messages.create(model=self.model, system=system_message, messages=messages)
+                    response = response_.content[0].text  # extract plain text from the Anthropic Message
+                else:
+                    response_ = requests.post(self.api_url, json=self.post_dict, headers=self.header)
+                    response_ = response_.json()
+                    response = response_["choices"][0]["message"]["content"]
+            except KeyboardInterrupt:
+                raise Exception("Terminated by user.")
+            except Exception as e:
+                print(e)
+                i += 1
+                time.sleep(1 + i / 10)
+                if i == 1 or i % 10 == 0:
+                    error_type = response_.get("error", {}).get("type", "") if isinstance(response_, dict) else ""
+                    if error_type == 'upstream_error':
+                        response = ""
+                        feedback = error_type
+                        return response, [system_message, messages], feedback
+                    print(f"Retry {i} times...")
+            else:
+                break
+        if i >= MAX_RETRY:
+            raise Exception("Failed to generate response.")
+        return response, [system_message, messages], None
+
+    def generate_answer(self, question):
+        response, message_, feedback = self.generate_response(question)
+        message = {
+            "system": message_[0],
+            "messages": message_[1]
+        }
+        for i in range(len(message["messages"][0]["content"])):
+            if message["messages"][0]["content"][i]["type"] == "image":
+                message["messages"][0]["content"][i]["source"]["data"] = message["messages"][0]["content"][i]["source"]["data"][:32] + "..."
+        question["input_message"] = message
+        question["prediction"] = response
+        if feedback:
+            question["feedback"] = feedback
+        question.pop("prompted_content")
+        question.pop("prompted_system_content")
+        return question
diff --git a/eval/models/gemini.py b/eval/models/gemini.py
index b6cc012..cb6fcac 100644
--- a/eval/models/gemini.py
+++ b/eval/models/gemini.py
@@ -38,11 +38,10 @@ def prepare_inputs(self, question):
         prompt = question["prompted_system_content"].strip() + "\n" + question["prompted_content"].strip()
         content = [prompt,]
-        image_list = question.get("question_image_list")
+        image_list = question.get("image_list")
         if image_list:
             for image_path in image_list:
-                max_size = 512
-                image = encode_image_PIL(image_path, max_size=max_size)
+                image = encode_image_PIL(image_path) # max_size = 512
                 content.append(image)
         return content
 
 
@@ -57,7 +56,8 @@ def generate_response(self, question):
             if len(content) > 1:
                 response_ = self.model_with_vision.generate_content(content)
                 message = [content[0], ]
-                message.append(f"image no.{i+1}" for i in range(len(content) - 1))
+                for i in range(len(content) - 1):
+                    message.append(str(content[i+1]))
             else:
                 response_ = self.model_without_vision.generate_content(content)
                 message = content
@@ -70,7 +70,7 @@
                 time.sleep(1 + i / 10)
                 if i == 1 or i % 10 == 0:
                     if str(e).endswith("if the prompt was blocked.") or str(e).endswith("lookup instead."):
-                        response = "Gemini refused to answer this question."
+                        response = ""
                         feedback = str(response_.prompt_feedback)
                         return response, message, feedback
                     print(f"Retry {i} times...")
diff --git a/eval/models/gpt.py b/eval/models/gpt.py
index dbd5711..ee7ff64 100644
--- a/eval/models/gpt.py
+++ b/eval/models/gpt.py
@@ -29,7 +29,7 @@ def __init__(self, api_key, model='gpt-3.5-turbo', api_url="https://api.openai.c
         }
 
     def prepare_inputs(self, question):
-        image_list = question.get("question_image_list")
+        image_list = question.get("image_list")
         messages = [{
             "role": "system",
             "content": question["prompted_system_content"]
@@ -43,14 +43,12 @@
                 "text": question["prompted_content"]
             },]}
             for image_path in image_list:
-                max_size = 512
-                base64_image, origin_pixels = encode_image_base64(image_path, max_size=max_size)
-                detail = "high" if origin_pixels > max_size * max_size / 2 else "low"
+                base64_image = encode_image_base64(image_path) # max_size = 512
                 user_message["content"].append({
                     "type": "image_url",
                     "image_url": {
                         "url": f"data:image/png;base64,{base64_image}",
-                        "detail": detail, # "auto"
+                        "detail": "auto"
                     },},)
             messages.append(user_message)
         else:
@@ -87,10 +85,10 @@
                 i += 1
                 time.sleep(1 + i / 10)
                 if i == 1 or i % 10 == 0:
-                    if error.startswith("This model's maximum context length"):
+                    if error.startswith("This model's maximum context length") or error.startswith("Your input image may contain"):
                         response = ""
                         feedback = error
-                        return response, message,feedback
+                        return response, message, feedback
                     print(f"Retry {i} times...")
             else:
                 break
@@ -100,6 +98,10 @@
 
     def generate_answer(self, question):
         response, message, feedback = self.generate_response(question)
+        if not isinstance(message[1]["content"], str):
+            for i in range(len(message[1]["content"])):
+                if message[1]["content"][i]["type"] == "image_url":
+                    message[1]["content"][i]["image_url"]["url"] = message[1]["content"][i]["image_url"]["url"][:64]+"..."
         question["input_message"] = message
         question["prediction"] = response
         if feedback:
diff --git a/eval/utils.py b/eval/utils.py
index 3268369..1fe7e3a 100644
--- a/eval/utils.py
+++ b/eval/utils.py
@@ -6,6 +6,7 @@
 from typing import Tuple, List
 from PIL import Image
 import pdb
+import io
 import base64
 
 
@@ -50,24 +51,26 @@
     return image
 
 
-def encode_image_base64(image_path,max_size=0):
+def encode_image_base64(image_path,max_size=-1):
     with open(image_path, "rb") as image_file:
         if max_size > 0:
             image = Image.open(image_file)
-            size = image.size
             image.thumbnail((max_size, max_size))
-            image_file = image
-    return base64.b64encode(image_file.read()).decode('utf-8'),size.width*size.height
-
-
+            output_buffer = io.BytesIO()
+            image.save(output_buffer, format='png')
+            image_bytes = output_buffer.getvalue()
+        else:
+            image_bytes = image_file.read()
+    return base64.b64encode(image_bytes).decode('utf-8')
+
+def encode_image_PIL(image_path,max_size=-1):
+    if max_size > 0:
+        image = Image.open(image_path)
+        image.thumbnail((max_size, max_size))
+    else:
+        image = Image.open(image_path)
+    return image
 
-def encode_image_PIL(image_path,max_size=0):
-    with open(image_path, "rb") as image_file:
-        if max_size > 0:
-            image = Image.open(image_file)
-            image.thumbnail((max_size, max_size))
-            image_file = image
-    return image_file.read()
 
 def infer_lang_from_question(question):
     question_type = question["question_type"]