diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml
new file mode 100644
index 0000000000..9129d01221
--- /dev/null
+++ b/examples/llava/lora-7b.yaml
@@ -0,0 +1,63 @@
+base_model: llava-hf/llava-1.5-7b-hf
+processor_type: AutoProcessor
+strict: false
+
+# these 3 lines are needed for now to handle vision chat templates with images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+chat_template: llava
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+    field_messages: messages
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+local_rank:
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml
new file mode 100644
index 0000000000..ab70afcda8
--- /dev/null
+++ b/examples/pixtral/lora-12b.yml
@@ -0,0 +1,65 @@
+base_model: mistral-community/pixtral-12b
+processor_type: AutoProcessor
+strict: false
+
+# these 3 lines are needed for now to handle vision chat templates with images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+chat_template: pixtral
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+    field_messages: messages
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+local_rank:
+logging_steps: 1
+flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  pad_token: <|end_of_text|>
diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml
new file mode 100644
index 0000000000..e7ab13ddb3
--- /dev/null
+++ b/examples/qwen2-vl/lora-7b.yaml
@@ -0,0 +1,63 @@
+base_model: Qwen/Qwen2-VL-7B-Instruct
+processor_type: AutoProcessor
+strict: false
+
+# these 3 lines are needed for now to handle vision chat templates with images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+chat_template: qwen2_vl
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+    field_messages: messages
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+local_rank:
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
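All three example configs read the same HuggingFaceH4/llava-instruct-mix-vsft split via `field_messages: messages`. For orientation, a row in that shape looks roughly like the sketch below; the concrete strings are illustrative (not taken from the dataset), and the `messages`/`images` keys are what the chat templates and collator code further down consume.

```python
# Illustrative multimodal training row (hypothetical content).
example_row = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "A dog playing in a park."}],
        },
    ],
    # One PIL.Image per image placeholder above; a single path string is also
    # accepted and opened by ProcessingStrategy.process_images.
    "images": [],
}
```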
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index e81740399a..a7cdd014ee 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -53,6 +53,7 @@
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
+from axolotl.processing_strategies import get_processing_strategy
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
@@ -2015,8 +2016,9 @@ def build_collator(
         else:
             if self.cfg.processor_type and self.processor:
                 collator = MultiModalChatDataCollator
-                kwargs["processor"] = self.processor
-                kwargs["chat_template"] = training_args.chat_template
+                kwargs["processing_strategy"] = get_processing_strategy(
+                    self.processor, training_args.chat_template, self.cfg.chat_template
+                )
             elif self.cfg.batch_flattening:
                 collator = DataCollatorWithFlattening
                 collator_args.pop(0)
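With this change the collator receives a strategy object instead of a raw processor plus template string. A minimal sketch of the same wiring outside the trainer, assuming an illustrative checkpoint name and a `chat_template=None` fallback (inside the trainer, the resolved Jinja string from `training_args.chat_template` is passed instead):

```python
from transformers import AutoProcessor

from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator

# Processor for the multimodal base model (checkpoint name is illustrative).
processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")

# Select the strategy by chat template name; None falls back to the
# processor's built-in template (an assumption for this sketch).
strategy = get_processing_strategy(processor, None, "pixtral")

collator = MultiModalChatDataCollator(
    tokenizer=processor.tokenizer,
    processing_strategy=strategy,
)
```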
diff --git a/src/axolotl/processing_strategies/__init__.py b/src/axolotl/processing_strategies/__init__.py
new file mode 100644
index 0000000000..054426d383
--- /dev/null
+++ b/src/axolotl/processing_strategies/__init__.py
@@ -0,0 +1,210 @@
+"""Module containing the ProcessingStrategy base class and its derivatives for different multimodal model types"""
+from copy import deepcopy
+from typing import Optional
+
+from PIL import Image
+from transformers import ProcessorMixin
+
+
+class ProcessingStrategy:
+    """Base Processing Strategy class"""
+
+    def __init__(self, processor: ProcessorMixin, chat_template: Optional[str] = None):
+        self.processor = processor
+        self.chat_template = chat_template
+        try:
+            self.image_token = processor.image_token
+            self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
+                self.image_token
+            )
+        except AttributeError:
+            pass
+
+    @staticmethod
+    def preprocess(examples: list[dict]) -> list[dict]:
+        """
+        Preprocess conversation examples to ensure consistent format.
+
+        Converts different conversation formats to OpenAI format with 'messages'.
+        Supports two formats:
+        1. OpenAI format with 'messages'
+        2. Legacy format with 'conversations'
+
+        Args:
+            examples: list of conversation dictionaries
+
+        Returns:
+            list of dicts in OpenAI format with a 'messages' key
+
+        Raises:
+            ValueError: If the conversation format is not supported
+        """
+        role_mapping = {
+            "human": "user",
+            "gpt": "assistant",
+        }
+
+        def normalize_role(role: str) -> str:
+            """Normalize role names to OpenAI format. Default to the original role if not found."""
+            return role_mapping.get(role, role)
+
+        def convert_legacy_format(example: dict) -> dict:
+            """Convert legacy 'conversations' format to OpenAI 'messages' format."""
+            messages = [
+                {
+                    "role": normalize_role(convo["from"]),
+                    "content": convo["value"],
+                }
+                for convo in example["conversations"]
+            ]
+
+            # Create new dict without 'conversations' key
+            result = deepcopy(example)
+            result.pop("conversations")
+            result["messages"] = messages
+            return result
+
+        processed_examples = []
+        for example in examples:
+            # OpenAI format
+            if "messages" in example:
+                processed_examples.append(example)
+
+            # Legacy format
+            elif "conversations" in example:
+                processed_examples.append(convert_legacy_format(example))
+
+            else:
+                raise ValueError(
+                    "Only `messages` and `conversations` message keys are currently supported."
+                )
+
+        return processed_examples
+
+    @staticmethod
+    def process_images(examples, max_images):
+        """
+        Process images from examples, ensuring consistent image presence and applying the max_images limit.
+
+        Args:
+            examples: List of dictionaries that may contain an 'images' key
+            max_images: Maximum number of images to keep per example (0 means no limit)
+
+        Returns:
+            Either None (if no example has images) or a list of Image objects (if all examples have images)
+
+        Raises:
+            ValueError: If there's a mix of None and non-None images
+        """
+
+        def get_image(example):
+            if "images" not in example:
+                return None
+            images = example["images"]
+            if isinstance(images, str):
+                return Image.open(images)
+            return images
+
+        images = [get_image(example) for example in examples]
+
+        # Count None and non-None images
+        none_count = sum(1 for img in images if img is None)
+
+        # All images are None
+        if none_count == len(images):
+            return None
+
+        # Mix of None and non-None images
+        if none_count > 0:
+            raise ValueError(
+                "Images must be consistent across the batch: "
+                "either every example has images or none of them do."
+            )
+
+        # Apply max_images limit if specified
+        if max_images > 0:
+            images = [
+                (
+                    img_batch[:max_images]
+                    if isinstance(img_batch, (list, tuple))
+                    else img_batch
+                )
+                for img_batch in images
+            ]
+
+        return images
+
+    def process_texts(self, examples):
+        texts = [
+            self.processor.apply_chat_template(
+                example["messages"], chat_template=self.chat_template, tokenize=False
+            )
+            for example in examples
+        ]
+        return texts
+
+
+class PixtralProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for Pixtral"""
+
+    @staticmethod
+    def pixtral_chat_conversion(messages):
+        is_single_message = not isinstance(messages, list)
+        if is_single_message:
+            messages = [messages]
+
+        for i, message in enumerate(messages):
+            if message["role"] == "user":
+                for j, content in enumerate(message["content"]):
+                    if "type" in content and content["type"] == "text":
+                        messages[i]["content"][j] = {
+                            "type": "text",
+                            "content": content["text"],
+                        }
+
+            if message["role"] == "assistant":
+                messages[i]["content"] = message["content"][0]["text"]
+
+        if is_single_message:
+            return messages[0]
+        return messages
+
+    def process_texts(self, examples):
+        texts = [
+            self.processor.apply_chat_template(
+                __class__.pixtral_chat_conversion(example["messages"]),
+                chat_template=self.chat_template,
+                tokenize=False,
+            )
+            for example in examples
+        ]
+        return texts
+
+
+class Qwen2VLProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for Qwen2-VL"""
+
+    def __init__(self, processor: ProcessorMixin, chat_template: Optional[str] = None):
+        super().__init__(processor, chat_template)
+        self.image_token = "<|image_pad|>"  # nosec
+        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
+            self.image_token
+        )
+
+
+class LlavaProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for Llava"""
+
+    @staticmethod
+    def process_images(examples, max_images):
+        images = ProcessingStrategy.process_images(examples, max_images)
+        images = [image[0] for image in images]
+        return images
+
+
+def get_processing_strategy(
+    processor: ProcessorMixin, chat_template, chat_template_type
+):
+    if chat_template_type == "pixtral":
+        return PixtralProcessingStrategy(processor, chat_template)
+    if chat_template_type == "llava":
+        return LlavaProcessingStrategy(processor, chat_template)
+    if chat_template_type == "qwen2_vl":
+        return Qwen2VLProcessingStrategy(processor, chat_template)
+    return ProcessingStrategy(processor, chat_template)
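A quick sketch of what the base strategy's `preprocess` normalization does to a legacy `conversations` row (the message text here is made up):

```python
from axolotl.processing_strategies import ProcessingStrategy

legacy_rows = [
    {
        "conversations": [
            {"from": "human", "value": "What is in the image?"},
            {"from": "gpt", "value": "A cat sitting on a sofa."},
        ]
    }
]

rows = ProcessingStrategy.preprocess(legacy_rows)
# rows == [{"messages": [
#     {"role": "user", "content": "What is in the image?"},
#     {"role": "assistant", "content": "A cat sitting on a sofa."},
# ]}]
```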
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index 682a0449e8..6a365456eb 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -24,6 +24,7 @@
     "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." 
}}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n', + "llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 
'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", @@ -31,6 +32,8 @@ "qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", "exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}", "metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' 
%}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}", + "pixtral": '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["content"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {{- message["content"] + eos_token}}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}', + "qwen2_vl": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}", } diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index 6f8a64ad85..dea1371272 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -2,15 +2,15 @@ Collators for multi-modal chat messages and packing """ -from copy import deepcopy from dataclasses import dataclass from typing import Any, Optional, Union -from PIL import Image -from transformers import PreTrainedTokenizerBase, ProcessorMixin +from transformers import PreTrainedTokenizerBase from transformers.data.data_collator import DataCollatorMixin from transformers.utils import PaddingStrategy +from ...processing_strategies import ProcessingStrategy + @dataclass class MultiModalChatDataCollator(DataCollatorMixin): @@ -19,9 +19,8 @@ class 
MultiModalChatDataCollator(DataCollatorMixin): """ tokenizer: PreTrainedTokenizerBase - processor: ProcessorMixin + processing_strategy: ProcessingStrategy return_tensors: str = "pt" - chat_template: Optional[str] = None packing: bool = False max_images: int = -1 padding: Union[bool, str, PaddingStrategy] = True @@ -35,154 +34,42 @@ def torch_call( self, examples: list[Union[list[int], Any, dict[str, Any]]] ) -> dict[str, Any]: # Handle dict or lists with proper padding and conversion to tensor. - return self.__class__.process_rows( - examples, self.processor, self.chat_template, self.max_images + examples, + self.processing_strategy, + self.max_images, ) @staticmethod - def process_rows(examples, processor, chat_template, max_images, length_only=False): + def process_rows( + examples, + processing_strategy: ProcessingStrategy, + max_images, + length_only=False, + ): # HINT: use `_torch_collate_batch` to stack and pad tensors # see also DataCollatorWithFlattening and DefaultDataCollator # *** This is COPIED from the trl example sft_vlm.py code *** # use this as a starting point - def _preprocess(examples: list[dict]) -> list[dict]: - """ - Preprocess conversation examples to ensure consistent format. - - Converts different conversation formats to OpenAI format with 'messages'. - Supports two formats: - 1. OpenAI format with 'messages' - 2. Legacy format with 'conversations' - - Args: - examples: list of conversation dictionaries - - Returns: - dict in OpenAI format with 'messages' key - - Raises: - ValueError: If the conversation format is not supported - """ - role_mapping = { - "human": "user", - "gpt": "assistant", - } - - def normalize_role(role: str) -> str: - """Normalize role names to OpenAI format. Default to original role if not found.""" - return role_mapping.get(role, role) - - def convert_legacy_format(example: dict) -> dict: - """Convert legacy 'conversations' format to OpenAI 'messages' format.""" - messages = [ - { - "role": normalize_role(convo["from"]), - "content": convo["value"], - } - for convo in example["conversations"] - ] - - # Create new dict without 'conversations' key - result = deepcopy(example) - result.pop("conversations") - return {"messages": messages, **result} - - processed_examples = [] - for example in examples: - # OpenAI format - if "messages" in example: - processed_examples.append(example) - - # Legacy format - elif "conversations" in example: - processed_examples.append(convert_legacy_format(example)) - - else: - raise ValueError( - "Only `messages` and `conversations` message keys are currently supported." - ) - - return processed_examples - - def _process_images(examples, max_images): - """ - Process images from examples, ensuring consistency in image presence and applying max_images limit. 
- - Args: - examples: List of dictionaries that may contain 'images' key - max_images: Maximum number of images to keep per example (0 means no limit) - - Returns: - Either None (if no images) or List[Image objects] (if all examples have images) - - Raises: - ValueError: If there's a mix of None and non-None images - """ - - def get_image(example): - if "images" not in example: - return None - images = example["images"] - if isinstance(images, str): - return Image.open(images) - return images - - images = [get_image(example) for example in examples] - - # Count None and non-None images - none_count = sum(1 for img in images if img is None) - - # All images are None - if none_count == len(images): - return None - - # Mix of None and non-None images - if none_count > 0: - raise ValueError( - "All images should be either None or not None. " - "Please provide images for all examples or None." - ) - - # Apply max_images limit if specified - if max_images > 0: - images = [ - ( - img_batch[:max_images] - if isinstance(img_batch, (list, tuple)) - else img_batch - ) - for img_batch in images - ] - - return images - # Preprocess the examples - examples = _preprocess(examples) + examples = processing_strategy.preprocess(examples) # Get the texts and images, and apply the chat template - texts = [ - processor.apply_chat_template( - example["messages"], chat_template=chat_template, tokenize=False - ) - for example in examples - ] - - images = _process_images(examples, max_images=max_images) + texts = processing_strategy.process_texts(examples) + images = processing_strategy.process_images(examples, max_images) # Tokenize the texts and process the images - batch = processor(text=texts, images=images, return_tensors="pt", padding=True) + batch = processing_strategy.processor( + text=texts, images=images, return_tensors="pt", padding=True + ) # The labels are the input_ids, and we mask the padding tokens in the loss computation labels = batch["input_ids"].clone() - labels[labels == processor.tokenizer.pad_token_id] = -100 # + labels[labels == processing_strategy.processor.tokenizer.pad_token_id] = -100 # # Ignore the image token index in the loss computation (model specific) - image_token_id = processor.tokenizer.convert_tokens_to_ids( - processor.image_token - ) - labels[labels == image_token_id] = -100 + labels[labels == processing_strategy.image_token_id] = -100 batch["labels"] = labels if length_only: diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index c23359f34d..d997591006 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -132,7 +132,7 @@ def normalize_config(cfg): cfg.is_multimodal = ( hasattr(model_config, "model_type") - and model_config.model_type in ["llava", "mllama"] + and model_config.model_type in ["llava", "mllama", "qwen2_vl"] or any( multimodal_name in cfg.base_model.lower() for multimodal_name in [ @@ -145,7 +145,10 @@ def normalize_config(cfg): cfg.processor_config = ( cfg.processor_config or cfg.base_model_config or cfg.base_model ) - model_config = model_config.text_config + if hasattr(model_config, "text_config"): + model_config = model_config.text_config + elif hasattr(model_config, "get_text_config"): + model_config = model_config.get_text_config() cfg.model_config_type = model_config.model_type diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 0781c67989..8e40034e22 100644 --- 
a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -60,6 +60,9 @@ class ChatTemplate(str, Enum): tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name exaone = "exaone" # pylint: disable=invalid-name metharme = "metharme" # pylint: disable=invalid-name + pixtral = "pixtral" # pylint: disable=invalid-name + llava = "llava" # pylint: disable=invalid-name + qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name class DeprecatedParameters(BaseModel): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 523fd76feb..d4a2769e02 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -91,7 +91,10 @@ def get_module_class_from_name(module, name): def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): if cfg.is_multimodal: - model_config = model_config.text_config + if hasattr(model_config, "text_config"): + model_config = model_config.text_config + elif hasattr(model_config, "get_text_config"): + model_config = model_config.get_text_config() quant_config_exists = ( hasattr(model_config, "quantization_config") @@ -367,7 +370,11 @@ def __init__( # init model config self.model_config = load_model_config(cfg) if cfg.is_multimodal: - self.text_model_config = self.model_config.text_config + try: + self.text_model_config = self.model_config.text_config + except AttributeError: + # for qwen2_vl + self.text_model_config = self.model_config.get_text_config() else: self.text_model_config = self.model_config @@ -1060,7 +1067,9 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: and self.model.get_input_embeddings().num_embeddings < embeddings_len ): resize_kwargs = {} - if self.cfg.mean_resizing_embeddings is not None: + if self.cfg.mean_resizing_embeddings is not None and not ( + self.model_config.model_type == "llava" + ): resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings self.model.resize_token_embeddings(embeddings_len, **resize_kwargs) else: