Commit

init internvl_hf.py

JamesZhutheThird committed Feb 10, 2024
1 parent 1422fad commit 3cb4d97

Showing 3 changed files with 80 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
@@ -4,7 +4,7 @@

🌐 [Website](https://opendfm.github.io/MULTI-Benchmark/)

📃 [Paper](https://arxiv.org/abs/2402.03173/)
📃 [Paper](https://arxiv.org/abs/2402.03173/)

🤗 [Dataset](https://opendfm.github.io/MULTI-Benchmark/) (Coming Soon)

7 changes: 7 additions & 0 deletions eval/args.py
@@ -79,6 +79,13 @@
"evaluator": "DFMEvaluator",
"split_sys": True,
},
"intern-vl": {
"model_type": "local",
"support_input": [0, 1, 2, 3],
"executor": "internvl",
"evaluator": "InternVLEvaluator",
"split_sys": False,
},
}


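For context, the entry above registers intern-vl in the model configuration table used by the evaluation harness. A minimal sketch of how such an entry might be used to select an evaluator class follows; the MODEL_CONFIGS name, the load_evaluator helper, and the eval.models.<executor>_hf module naming are illustrative assumptions, not the repository's actual dispatch code.

# Hypothetical dispatch sketch; names and module layout are assumptions, not the repo's API.
import importlib

MODEL_CONFIGS = {
    "intern-vl": {
        "model_type": "local",
        "support_input": [0, 1, 2, 3],
        "executor": "internvl",
        "evaluator": "InternVLEvaluator",
        "split_sys": False,
    },
}

def load_evaluator(model_name):
    cfg = MODEL_CONFIGS[model_name]
    # e.g. "internvl" -> eval.models.internvl_hf, which defines InternVLEvaluator
    module = importlib.import_module(f"eval.models.{cfg['executor']}_hf")
    return getattr(module, cfg["evaluator"])()
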
72 changes: 72 additions & 0 deletions eval/models/internvl_hf.py
@@ -0,0 +1,72 @@
"""InternVL Evaluator with HuggingFace Transformers"""

from transformers import AutoTokenizer, AutoModel, CLIPImageProcessor
import torch
from PIL import Image


class InternVLEvaluator():
def __init__(self, model_dir="OpenGVLab/InternVL-Chat-Chinese-V1-1", device_map="auto"):
self.model_dir = model_dir
self.sample_params = {
"max_new_tokens": 512,
"do_sample": False,
"num_beams": 1,
}

self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
self.image_processor = CLIPImageProcessor.from_pretrained(self.model_dir)
self.model = AutoModel.from_pretrained(self.model_dir, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval()

# self.model.generation_config.__dict__.update(self.sample_params)

def generate_response(self, input):
if isinstance(input, dict):
question = input
image_path = question.get("image_list", [""])[0]
image = Image.open(image_path).convert('RGB')
image = image.resize((448, 448))
pixel_values = self.image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()
message = question["prompted_content"]
response, _ = self.model.chat(self.tokenizer, pixel_values, message, None, **self.sample_params)
return response, message

elif isinstance(input, tuple):
# question with multiple images
assert len(input) == 3, "Input tuple must have 3 elements. (prompt, image_path, history)"
message, image_path, history = input
image = Image.open(image_path).convert('RGB')
image = image.resize((448, 448))
pixel_values = self.image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()
response, history = self.model.chat(self.tokenizer, pixel_values, message, history, **self.sample_params)
return response, history, message
else:
raise ValueError(f"input type not supported: {type(input)}")

def generate_answer(self, question):
if question.get("prompted_content"):
            assert len(question.get("image_list", [""])) <= 1, "The InternVL model only supports one image at a time."
response, message = self.generate_response(question)
question["input_message"] = message
question.pop("prompted_content")
elif question.get("prompted_content_list"):
            # The model accepts only one image per call, so questions with multiple
            # images are handled as a multi-round chat, sending one image per round.
prompted_content_list = question.get("prompted_content_list")
image_list = question.get("image_list").copy()
# image_list.append("")
history = None
assert len(prompted_content_list) == len(image_list), f"Length of prompted_content_list and image_list must be the same. \n{question}"
question["answer_history"] = []
question["input_message_list"] = []
for multi_rounds_prompt, image_path in zip(prompted_content_list, image_list):
response, history, message = self.generate_response((multi_rounds_prompt, image_path, history))
question["answer_history"].append(response)
question["input_message_list"].append(message)
question.pop("prompted_content_list")
else:
raise ValueError(f"Question not supported: {question}")
question["prediction"] = response
return question
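
A minimal usage sketch for the evaluator added above, assuming a question dict that carries the keys the code reads (prompted_content and image_list) and that the repository root is on PYTHONPATH; the prompt text and image path are placeholders.

# Hypothetical usage sketch; the prompt text and image path are placeholders.
from eval.models.internvl_hf import InternVLEvaluator

evaluator = InternVLEvaluator(model_dir="OpenGVLab/InternVL-Chat-Chinese-V1-1")
question = {
    "prompted_content": "Describe the image and answer the question.",
    "image_list": ["path/to/example.png"],  # placeholder path
}
question = evaluator.generate_answer(question)
print(question["prediction"])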
