
Commit

add support for InternVL
JamesZhutheThird committed Feb 11, 2024
1 parent 3cb4d97 commit 3b310ac
Showing 2 changed files with 16 additions and 6 deletions.
20 changes: 14 additions & 6 deletions eval/models/internvl_hf.py
@@ -3,7 +3,7 @@
 from transformers import AutoTokenizer, AutoModel, CLIPImageProcessor
 import torch
 from PIL import Image

 import pdb

 class InternVLEvaluator():
     def __init__(self, model_dir="OpenGVLab/InternVL-Chat-Chinese-V1-1", device_map="auto"):
@@ -24,30 +24,38 @@ def generate_response(self, input):
         if isinstance(input, dict):
             question = input
             image_path = question.get("image_list", [""])[0]
-            image = Image.open(image_path).convert('RGB')
-            image = image.resize((448, 448))
+            if len(image_path) > 0:
+                image = Image.open(image_path).convert('RGB')
+                image = image.resize((448, 448))
+            else:
+                image = Image.new("RGB", (448,448), (0, 0, 0))
             pixel_values = self.image_processor(images=image, return_tensors='pt').pixel_values
             pixel_values = pixel_values.to(torch.bfloat16).cuda()
             message = question["prompted_content"]
-            response, _ = self.model.chat(self.tokenizer, pixel_values, message, None, **self.sample_params)
+            response = self.model.chat(self.tokenizer, pixel_values, message, self.sample_params)
             return response, message

         elif isinstance(input, tuple):
             # question with multiple images

             assert len(input) == 3, "Input tuple must have 3 elements. (prompt, image_path, history)"
             message, image_path, history = input
+            '''
             image = Image.open(image_path).convert('RGB')
             image = image.resize((448, 448))
             pixel_values = self.image_processor(images=image, return_tensors='pt').pixel_values
             pixel_values = pixel_values.to(torch.bfloat16).cuda()
-            response, history = self.model.chat(self.tokenizer, pixel_values, message, history, **self.sample_params)
+            response, history = self.model.chat(self.tokenizer, pixel_values, message, self.sample_params)
             return response, history, message
+            '''
+            print(f"multiple image input is not supported")
+            return "", "", message
         else:
             raise ValueError(f"input type not supported: {type(input)}")

     def generate_answer(self, question):
         if question.get("prompted_content"):
-            assert len(question.get("image_list", [""])) <= 1, "VisualGLM model only supports one image at one time."
+            # assert len(question.get("image_list", [""])) <= 1, "InternVL model only supports one image at one time."
             response, message = self.generate_response(question)
             question["input_message"] = message
             question.pop("prompted_content")
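For context, a minimal usage sketch of the updated single-image path. The import path, file names, and prompt text are illustrative assumptions; only the InternVLEvaluator constructor and the question-dict keys (image_list, prompted_content) come from the diff above.

# Illustrative sketch, not part of the commit. Assumes the package layout
# makes InternVLEvaluator importable from eval/models/internvl_hf.py.
from eval.models.internvl_hf import InternVLEvaluator

evaluator = InternVLEvaluator(model_dir="OpenGVLab/InternVL-Chat-Chinese-V1-1")

# Single-image question: the first path in image_list is opened and resized
# to 448x448 before being passed to model.chat().
question = {
    "image_list": ["figures/example.png"],  # hypothetical path
    "prompted_content": "Describe the figure.",
}
response, message = evaluator.generate_response(question)

# Text-only question: with an empty image path, the new branch substitutes a
# black 448x448 placeholder image instead of failing on Image.open("").
text_only = {"image_list": [""], "prompted_content": "What is 2 + 2?"}
response, message = evaluator.generate_response(text_only)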
2 changes: 2 additions & 0 deletions eval/prompts.py
@@ -229,6 +229,8 @@ def postprocess_prompt(content, in_turn=True, remove_image_token=False):
     img_token_start = content.index(img_sub)
     if remove_image_token:
         prompted_content_list.append(content[:img_token_start].strip())
+    else:
+        prompted_content_list.append(content[:img_token_start].strip() + img_sub)
     content = content[img_token_start + len(img_sub):]
     prompted_content_list.append(content.strip())
     prompted_content_list[-2] += '\n' + prompted_content_list[-1]
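To show what the new else branch changes, here is a self-contained toy version of the token-splitting step. The function name, the surrounding while loop, and the img_sub value are simplified stand-ins; only the branch logic mirrors the diffed lines, this is not the repository's actual postprocess_prompt.

# Toy reproduction of the diffed logic, not the repository's actual function.
def split_on_image_token(content, img_sub="<img>", remove_image_token=False):
    prompted_content_list = []
    while img_sub in content:
        img_token_start = content.index(img_sub)
        if remove_image_token:
            # Pre-existing branch: keep only the text, drop the image token.
            prompted_content_list.append(content[:img_token_start].strip())
        else:
            # New branch from this commit: keep the image token attached to
            # the preceding text so the model still sees the placeholder.
            prompted_content_list.append(content[:img_token_start].strip() + img_sub)
        content = content[img_token_start + len(img_sub):]
    prompted_content_list.append(content.strip())
    return prompted_content_list

print(split_on_image_token("Look at <img> and answer."))
# ['Look at<img>', 'and answer.']
print(split_on_image_token("Look at <img> and answer.", remove_image_token=True))
# ['Look at', 'and answer.']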
