
Commit

update
Signed-off-by: n1ck-guo <[email protected]>
n1ck-guo committed Nov 1, 2024
1 parent 9370532 commit 364e3d7
Showing 2 changed files with 38 additions and 41 deletions.
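
In short, the diff below removes the module-level "from vlmeval ... import ..." statements in auto_round/mllm/eval.py and routes everything through a LazyImport("vlmeval") proxy: entries in MODEL_TYPE_TO_VLMEVAL_MODEL now hold class names as strings that are resolved with getattr(vlmeval.vlm, ...) inside mllm_eval, and build_dataset, the infer_data_job helpers, listinstr, MMBenchOfficialServer, and the result-transfer utilities are likewise called through the proxy. A guarded import of AutoRoundConfig (falling back to AutoHfQuantizer) is also added at the start of mllm_eval. An illustrative sketch of the lazy-import pattern follows the diff.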
1 change: 1 addition & 0 deletions auto_round/mllm/__init__.py
@@ -15,4 +15,5 @@
from .mllm_dataset import get_mllm_dataloader
from .template import Template, get_template, TEMPLATES
from .autoround_mllm import AutoRoundMLLM
from ..utils import LazyImport
from .eval import mllm_eval
78 changes: 37 additions & 41 deletions auto_round/mllm/eval.py
@@ -35,41 +35,31 @@
from ..utils import logger, LazyImport
vlmeval = LazyImport("vlmeval")

from vlmeval.vlm import *
from vlmeval.api import *
from vlmeval.config import supported_VLM
from vlmeval.dataset import build_dataset
from vlmeval.inference import infer_data_job
from vlmeval.inference_video import infer_data_job_video
from vlmeval.inference_mt import infer_data_job_mt
from vlmeval.smp import listinstr, MMBenchOfficialServer
from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer


MODEL_TYPE_TO_VLMEVAL_MODEL = {
#model_name
"Qwen-VL": dict(cls=QwenVL),
"Qwen-VL-Chat": dict(cls=QwenVLChat),
"Qwen2-VL": dict(cls=Qwen2VLChat, min_pixels=1280*28*28, max_pixels=16384*28*28),
"Llama-3.2": dict(cls=llama_vision),
"Phi-3-vision": dict(cls=Phi3Vision),
"Phi-3.5-vision": dict(cls=Phi3_5Vision),
"llava_v1.5": dict(cls=LLaVA),
"llava_v1.6": dict(cls=LLaVA_Next),
"llava-onevision-qwen2": dict(cls=LLaVA_OneVision),
"cogvlm2": dict(cls=CogVlm),
"SliME": dict(cls=SliME),
"Eagle": dict(cls=Eagle),
"Molmo": dict(cls=molmo),
"Qwen-VL": dict(cls="QwenVL"),
"Qwen-VL-Chat": dict(cls="QwenVLChat"),
"Qwen2-VL": dict(cls="Qwen2VLChat", min_pixels=1280*28*28, max_pixels=16384*28*28),
"Llama-3.2": dict(cls="llama_vision"),
"Phi-3-vision": dict(cls="Phi3Vision"),
"Phi-3.5-vision": dict(cls="Phi3_5Vision"),
"llava_v1.5": dict(cls="LLaVA"),
"llava_v1.6": dict(cls="LLaVA_Next"),
"llava-onevision-qwen2": dict(cls="LLaVA_OneVision"),
"cogvlm2": dict(cls="CogVlm"),
"SliME": dict(cls="SliME"),
"Eagle": dict(cls="Eagle"),
"Molmo": dict(cls="molmo"),

# config.model_type
"qwen2_vl": dict(cls=Qwen2VLChat, min_pixels=1280*28*28, max_pixels=16384*28*28),
"qwen": dict(cls=QwenVL),
"qwen_chat": dict(cls=QwenVLChat),
"llava": dict(cls=LLaVA),
"llava_next": dict(cls=LLaVA_Next),
"phi3_v": dict(cls=Phi3Vision),
"mllama": dict(cls=llama_vision),
"qwen2_vl": dict(cls="Qwen2VLChat", min_pixels=1280*28*28, max_pixels=16384*28*28),
"qwen": dict(cls="QwenVL"),
"qwen_chat": dict(cls="QwenVLChat"),
"llava": dict(cls="LLaVA"),
"llava_next": dict(cls="LLaVA_Next"),
"phi3_v": dict(cls="Phi3Vision"),
"mllama": dict(cls="llama_vision"),
}

def mllm_eval(
@@ -87,6 +77,11 @@ def mllm_eval(
mode: str = 'all',
ignore: bool = False
):

try:
from auto_round import AutoRoundConfig
except:
from auto_round.auto_quantizer import AutoHfQuantizer

model = None
if data_store_dir is not None:
@@ -117,7 +112,8 @@ def mllm_eval(
kwargs = MODEL_TYPE_TO_VLMEVAL_MODEL[model_type]
kwargs["model_path"] = pretrained_model_name_or_path
model_cls = kwargs.pop("cls")
supported_VLM[model_name] = partial(model_cls, **kwargs)
model_cls = getattr(vlmeval.vlm, model_cls)
vlmeval.config.supported_VLM[model_name] = partial(model_cls, **kwargs)

pred_root = os.path.join(work_dir, model_name)
os.makedirs(pred_root, exist_ok=True)
Expand All @@ -132,7 +128,7 @@ def mllm_eval(
if dataset_name == 'Video-MME':
dataset_kwargs['use_subtitle'] = use_subtitle

dataset = build_dataset(dataset_name, **dataset_kwargs)
dataset = vlmeval.dataset.build_dataset(dataset_name, **dataset_kwargs)
if dataset is None:
logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
continue
@@ -173,7 +169,7 @@ def mllm_eval(

# Perform the Inference
if dataset.MODALITY == 'VIDEO':
model = infer_data_job_video(
model = vlmeval.inference_video.infer_data_job_video(
model,
work_dir=pred_root,
model_name=model_name,
@@ -184,15 +180,15 @@ def mllm_eval(
subtitle=use_subtitle,
fps=fps)
elif dataset.TYPE == 'MT':
model = infer_data_job_mt(
model = vlmeval.inference_mt.infer_data_job_mt(
model,
work_dir=pred_root,
model_name=model_name,
dataset=dataset,
verbose=verbose,
ignore_failed=ignore)
else:
model = infer_data_job(
model = vlmeval.inference.infer_data_job(
model,
work_dir=pred_root,
model_name=model_name,
@@ -207,12 +203,12 @@ def mllm_eval(
if judge is not None:
judge_kwargs['model'] = judge
else:
if dataset.TYPE in ['MCQ', 'Y/N'] or listinstr(['MathVerse'], dataset_name):
if dataset.TYPE in ['MCQ', 'Y/N'] or vlmeval.smp.listinstr(['MathVerse'], dataset_name):
judge_kwargs['model'] = 'chatgpt-0125'
elif listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video', 'MathVision'],
elif vlmeval.smp.listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video', 'MathVision'],
dataset_name):
judge_kwargs['model'] = 'gpt-4-turbo'
elif listinstr(['MMLongBench', 'MMDU', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI'],
elif vlmeval.smp.listinstr(['MMLongBench', 'MMDU', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI'],
dataset_name):
judge_kwargs['model'] = 'gpt-4o'
if 'OPENAI_API_KEY_JUDGE' in os.environ and len(os.environ['OPENAI_API_KEY_JUDGE']):
@@ -221,12 +217,12 @@ def mllm_eval(
judge_kwargs['api_base'] = os.environ['OPENAI_API_BASE_JUDGE']

if dataset_name in ['MMMU_TEST']:
result_json = MMMU_result_transfer(result_file)
result_json = vlmeval.utils.result_transfer.MMMU_result_transfer(result_file)
logger.info(f'Transfer MMMU_TEST result to json for official evaluation, '
f'json file saved in {result_json}') # noqa: E501
continue
elif 'MMT-Bench_ALL' in dataset_name:
submission_file = MMTBench_result_transfer(result_file, **judge_kwargs)
submission_file = vlmeval.utils.result_transfer.MMTBench_result_transfer(result_file, **judge_kwargs)
logger.info(f'Extract options from prediction of MMT-Bench FULL split for official evaluation '
f'(https://eval.ai/web/challenges/challenge-page/2328/overview), '
f'submission file saved in {submission_file}') # noqa: E501
@@ -246,7 +242,7 @@ def mllm_eval(
'MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMBench', 'MMBench_CN',
'MMBench_TEST_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_V11', 'MMBench_CN_V11'
]:
if not MMBenchOfficialServer(dataset_name):
if not vlmeval.smp.MMBenchOfficialServer(dataset_name):
logger.error(
f'Can not evaluate {dataset_name} on non-official servers, '
'will skip the evaluation. '
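The core technique in this commit is deferring the vlmeval import until evaluation actually runs, presumably so that importing auto_round.mllm no longer requires vlmeval to be installed. Below is a minimal, self-contained sketch of that pattern; the LazyImport class, MODEL_TABLE, and register_model helper are simplified stand-ins written for illustration, not the actual auto_round.utils implementation.

import importlib
from functools import partial


class LazyImport:
    """Sketch of a lazy module proxy: the real import happens on first attribute access."""

    def __init__(self, module_name):
        self.module_name = module_name
        self._module = None

    def __getattr__(self, name):
        if self._module is None:
            self._module = importlib.import_module(self.module_name)
        try:
            return getattr(self._module, name)
        except AttributeError:
            # Fall back to a submodule import, e.g. vlmeval.vlm or vlmeval.config.
            return importlib.import_module(f"{self.module_name}.{name}")


# Defining the proxy imports nothing; vlmeval is only touched once evaluation is invoked.
vlmeval = LazyImport("vlmeval")

# Class names instead of class objects, so this table can be built without vlmeval installed.
MODEL_TABLE = {
    "Qwen2-VL": dict(cls="Qwen2VLChat", min_pixels=1280 * 28 * 28, max_pixels=16384 * 28 * 28),
    "llava_v1.5": dict(cls="LLaVA"),
}


def register_model(model_name, model_type, model_path):
    # Resolve the class by name at call time and register a factory, mirroring the diff above.
    kwargs = dict(MODEL_TABLE[model_type])
    kwargs["model_path"] = model_path
    model_cls = getattr(vlmeval.vlm, kwargs.pop("cls"))  # first real vlmeval import happens here
    vlmeval.config.supported_VLM[model_name] = partial(model_cls, **kwargs)

The same resolve-by-name approach is what lets the updated eval.py call vlmeval.dataset.build_dataset, vlmeval.inference.infer_data_job, and vlmeval.smp.listinstr without any module-level vlmeval imports.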
