[WIP] support for more vlms #390

Open · wants to merge 2 commits into base: main
10 changes: 7 additions & 3 deletions auto_round/mllm/autoround_mllm.py
@@ -39,8 +39,8 @@ def _only_text_test(model, tokenizer, device):
     text = ["only text", "test"]
     tokenizer.padding_side = 'left'
     if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    if device.split(':')[0] != model.device.type:
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.unk_token
+    if device.split(':')[0] != model.device.type:  # TODO: OOM
         model = model.to(device)
     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
     model(**inputs)
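For orientation, this helper appears to probe whether the VLM accepts a plain text batch; a hedged sketch of how a caller might use it (the surrounding fallback logic is an assumption, not part of this diff):

```python
# Assumed usage: try a text-only forward pass; if the model insists on image
# inputs, fall back to a multimodal calibration dataset instead of a text-only one.
try:
    _only_text_test(model, tokenizer, "cuda:0")
    text_only_ok = True
except Exception:
    text_only_ok = False
```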
@@ -160,6 +160,9 @@ def __init__(
             self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
         dataset = self.template.default_dataset if dataset is None else dataset

+        if model.config.model_type == "deepseek_vl_v2":
+            model.forward = model.language.forward
+
         from ..calib_dataset import CALIB_DATASETS
         from .mllm_dataset import MLLM_DATASET
         if isinstance(dataset, str):

Review comment (Contributor), on the `deepseek_vl_v2` check above: "The setting here is a little tricky. Could quantizing the non-text modules still be supported?"
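To see what the rebinding implies, here is a toy sketch (hypothetical classes, not the DeepSeek-VL2 API): once `model.forward` points at `model.language.forward`, calling `model(**inputs)` exercises only the language tower, which is what the reviewer's question is about.

```python
import torch
import torch.nn as nn

class TextTower(nn.Module):
    def forward(self, input_ids=None, **kwargs):
        # stand-in for the language model's text-only forward
        return {"logits": input_ids.float()}

class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.language = TextTower()

    def forward(self, input_ids=None, images=None, **kwargs):
        raise RuntimeError("the multimodal forward expects image inputs")

model = ToyVLM()
model.forward = model.language.forward            # same rebinding as in the diff
out = model(input_ids=torch.tensor([[1, 2, 3]]))  # vision path is bypassed entirely
```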
@@ -258,6 +261,7 @@ def calib(self, nsamples, bs):
             template=self.template,
             model=self.model,
             tokenizer=self.tokenizer,
+            processor=self.processor,
             image_processor=self.image_processor,
             dataset=dataset,
             extra_data_dir=self.extra_data_dir,
@@ -326,7 +330,7 @@ def calib(self, nsamples, bs):
         data_new = {}
         for key in data.keys():
             data_new[key] = to_device(data[key], self.model.device)
-            if key == 'images':
+            if key in ['images', 'pixel_values']:
                 data_new[key] = to_dtype(data_new[key], self.model.dtype)
         input_ids = data_new["input_ids"]

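The extra 'pixel_values' key matters because Hugging Face processors usually return the image tensor under that name and in float32; a minimal torch-only sketch of the cast (shapes and dtypes assumed):

```python
import torch

# Processors typically emit float32 pixel_values; a bf16/fp16 model will not
# accept them, so image tensors are cast to the model dtype before the forward.
data = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "pixel_values": torch.rand(1, 3, 224, 224),   # float32 from the processor
}
model_dtype = torch.bfloat16
for key in data:
    if key in ["images", "pixel_values"]:
        data[key] = data[key].to(model_dtype)
print(data["pixel_values"].dtype)  # torch.bfloat16
```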
3 changes: 2 additions & 1 deletion auto_round/mllm/mllm_dataset.py
@@ -190,6 +190,7 @@ def get_mllm_dataloader(
         template,
         model,
         tokenizer,
+        processor,
         image_processor=None,
         dataset="liuhaotian/llava_conv_58k",
         extra_data_dir=None,
@@ -222,7 +223,7 @@
     """
     if isinstance(template, str):
         from .template import get_template
-        template = get_template(template, model=model, tokenizer=tokenizer, image_processor=image_processor)
+        template = get_template(template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)

     if os.path.isfile(dataset) or dataset in MLLM_DATASET.keys():
         dataset = MLLM_DATASET['liuhaotian/llava'](
45 changes: 45 additions & 0 deletions auto_round/mllm/processor.py
@@ -111,6 +111,51 @@ def squeeze_result(ret):
         return ret


+@regist_processor("hf")
+class HFProcessor(BasicProcessor):
+    IMAGE_TOKEN = '<image>'
+    def __init__(self):
+        pass
+
+    def post_init(self, model, tokenizer, processor=None, image_processor=None, **kwargs):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.processor = processor
+        if image_processor is not None:
+            self.image_processor = image_processor
+        else:
+            self.image_processor = self.default_image_processor
+
+    def get_input(
+            self,
+            text,
+            images,
+            return_tensors="pt",
+            squeeze=True,
+            max_length=None,
+            truncation=False,
+            truncation_strategy="text",
+            **kwargs):
+
+        messages = []
+        for content in text:
+            messages.append({
+                "role": content['role'],
+                "content": [
+                    {"text": content["content"].replace(self.IMAGE_TOKEN, ""), "type": "text"}
+                ]
+            })
+            if self.IMAGE_TOKEN in content['content']:
+                messages[-1]["content"].append({"text": None, "type": "image"})
+        text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+        if images is not None:
+            images = self.image_processor(images)
+        ret = self.processor(text=text, images=images, return_tensors="pt")
+        if squeeze:
+            ret = self.squeeze_result(ret)
+        return ret
+
+
 @regist_processor("qwen2_vl")
 class Qwen2VLProcessor(BasicProcessor):
     @staticmethod
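To make the new processor concrete, here is a sketch of the intermediate structure `get_input` builds for one user turn containing the '<image>' placeholder (the conversation content is invented; the final two calls mirror the diff above and are shown as comments):

```python
# Hypothetical conversation passed in as `text`:
text = [{"role": "user", "content": "<image>\nDescribe the picture."}]

# get_input strips the placeholder and rebuilds the HF multimodal chat format:
messages = [{
    "role": "user",
    "content": [
        {"text": "\nDescribe the picture.", "type": "text"},
        {"text": None, "type": "image"},   # appended because '<image>' was present
    ],
}]

# The wrapped transformers processor then renders and tokenizes everything:
# prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
# inputs = processor(text=prompt, images=image_processor(images), return_tensors="pt")
```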
2 changes: 1 addition & 1 deletion auto_round/mllm/templates/default.json
@@ -10,5 +10,5 @@
     "replace_tokens": null,
     "extra_encode" : false,
     "default_dataset": "NeelNanda/pile-10k",
-    "processor": "basic"
+    "processor": "hf"
 }
13 changes: 12 additions & 1 deletion auto_round/script/mllm.py
Expand Up @@ -270,7 +270,8 @@ def tune(args):
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
args.device = ",".join(map(str, range(len(devices))))
devices = args.device.replace(" ", "").split(',')
use_auto_mapping = True
if len(devices) > 1: ##for 70B model on single card, use auto will cause some layer offload to cpu
use_auto_mapping = True
elif args.device == "auto":
use_auto_mapping == True

@@ -289,6 +290,13 @@
             model_name, model_base=None, model_name=model_name,
             torch_dtype=torch_dtype)
         model_type = "llava"
+    elif "deepseek" in model_name.lower():
+        from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
+        processor = DeepseekVLV2Processor.from_pretrained(model_name)
+        tokenizer = processor.tokenizer
+        model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
+            model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
+            device_map="auto" if use_auto_mapping else None)
     else:
         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
         processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
@@ -302,6 +310,9 @@
     elif "mllama" in model_type:
         from transformers import MllamaForConditionalGeneration
         cls = MllamaForConditionalGeneration
+    elif "idefics3" in model_type:
+        from transformers import AutoModelForVision2Seq
+        cls = AutoModelForVision2Seq
     else:
         cls = AutoModelForCausalLM

13 changes: 9 additions & 4 deletions auto_round/utils.py
@@ -402,11 +402,16 @@ def get_multimodal_block_names(model, quant_vision=False):
     """
     block_names = []
     target_modules = []
-    vison_blocks_tuple = ("vision", "visual",)
+    vison_blocks_tuple = ("vision", "visual", "projector")
+    module_list_type = ("ModuleList", "Sequential")
+    last_module_list = None
     for n, m in model.named_modules():
-        if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
-            if quant_vision or all(key not in n.lower() for key in (vison_blocks_tuple)):
-                target_modules.append((n, m))
+        # if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
+        if hasattr(type(m), "__name__") and any(key in type(m).__name__ for key in module_list_type):
+            if quant_vision or all(key not in n.lower() for key in vison_blocks_tuple):
+                if last_module_list is None or last_module_list not in n:
+                    last_module_list = n
+                    target_modules.append((n, m))
     validate_modules(target_modules, quant_vision, vison_blocks_tuple)
     for i, target_m in enumerate(target_modules):
         block_names.append([])
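To illustrate the new `last_module_list` guard, a self-contained sketch with a made-up toy model (not from the PR): nested containers such as per-layer expert lists share their parent's name prefix, so they are no longer collected twice, and the extended keyword tuple now also filters `projector` blocks when vision quantization is off.

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(4, 4)])   # nested container inside a block

class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([Block(), Block()])   # main decoder blocks
        self.vision = nn.Sequential(nn.Linear(4, 4))      # vision tower
        self.projector = nn.Sequential(nn.Linear(4, 4))   # multimodal projector

vison_blocks_tuple = ("vision", "visual", "projector")
module_list_type = ("ModuleList", "Sequential")
target_modules, last_module_list = [], None
for n, m in ToyVLM().named_modules():
    if any(key in type(m).__name__ for key in module_list_type):
        if all(key not in n.lower() for key in vison_blocks_tuple):
            if last_module_list is None or last_module_list not in n:
                last_module_list = n
                target_modules.append(n)

print(target_modules)  # ['layers'] -- nested 'layers.*.experts', vision and projector blocks are skipped
```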