From 383154bf3a6531580e661dc3cd5acbf8c0319bf8 Mon Sep 17 00:00:00 2001
From: Mike Lee
Date: Mon, 29 Jul 2024 16:36:33 +0800
Subject: [PATCH] support phi model

---
 mixlora/config.py | 99 +++++++++++++++++++++++++++++------------------
 mixlora/model.py  | 43 ++++++++++++++++++--
 pyproject.toml    |  6 +--
 tests/generate.py |  2 +-
 4 files changed, 105 insertions(+), 45 deletions(-)

diff --git a/mixlora/config.py b/mixlora/config.py
index 06f7208..b5ec25a 100644
--- a/mixlora/config.py
+++ b/mixlora/config.py
@@ -1,6 +1,6 @@
 import copy
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import torch
 from transformers.activations import ACT2FN
@@ -14,6 +14,30 @@ class AdapterConfig:
     dtype_: torch.dtype = None
 
 
+lora_target_modules = {
+    # LLaMA names
+    "q_proj": False,
+    "k_proj": False,
+    "v_proj": False,
+    "o_proj": False,
+    "gate_proj": False,
+    "down_proj": False,
+    "up_proj": False,
+    # Phi names
+    "q_proj": False,
+    "k_proj": False,
+    "v_proj": False,
+    "dense": False,
+    "fc1": False,
+    "fc2": False,
+    # Phi3 names
+    "qkv_proj": False,
+    "o_proj": False,
+    "gate_up_proj": False,
+    "down_proj": False,
+}
+
+
 @dataclass
 class LoraConfig(AdapterConfig):
     # Weight-Decomposed Low-Rank Adaptation
@@ -45,35 +69,28 @@ def check(self) -> "LoraConfig":
 
         return self
 
-    def from_config(self, config: Dict[str, any]) -> "LoraConfig":
-        self.use_dora_ = config.get("use_dora", False)
-        self.use_rslora_ = config.get("use_rslora", False)
-        self.lora_init_ = config.get("lora_init", "original")
-        self.lora_r_ = config["r"]
-        self.lora_alpha_ = config["lora_alpha"]
-        self.lora_dropout_ = config["lora_dropout"]
-        self.target_modules_ = {
-            # LLaMA names
-            "q_proj": False,
-            "k_proj": False,
-            "v_proj": False,
-            "o_proj": False,
-            "gate_proj": False,
-            "down_proj": False,
-            "up_proj": False,
-        }
+    @staticmethod
+    def from_config(config: Dict[str, any]) -> "LoraConfig":
+        lora_config = LoraConfig()
+        lora_config.use_dora_ = config.get("use_dora", False)
+        lora_config.use_rslora_ = config.get("use_rslora", False)
+        lora_config.lora_init_ = config.get("lora_init", "original")
+        lora_config.lora_r_ = config["r"]
+        lora_config.lora_alpha_ = config["lora_alpha"]
+        lora_config.lora_dropout_ = config["lora_dropout"]
+        lora_config.target_modules_ = copy.deepcopy(lora_target_modules)
         if isinstance(config["target_modules"], List):
             for target in config["target_modules"]:
-                if target in self.target_modules_:
-                    self.target_modules_[target] = True
+                if target in lora_target_modules:
+                    lora_config.target_modules_[target] = True
         elif isinstance(config["target_modules"], Dict):
             for target, value in config["target_modules"].items():
-                if target in self.target_modules_:
-                    self.target_modules_[target] = value
+                if target in lora_target_modules:
+                    lora_config.target_modules_[target] = value
         else:
             raise ValueError("broken config item: target_modules")
 
-        return self
+        return lora_config
 
     def export(self) -> Dict[str, any]:
         config = {}
@@ -109,7 +126,7 @@ class MixLoraConfig(LoraConfig):
     jitter_noise_: float = None
     router_loss_: bool = True
     num_experts_: int = None
-    act_fn_: str = None
+    act_fn_: Optional[str] = None
     # mixtral config
     top_k_: int = None
 
@@ -141,30 +158,36 @@ def check(self) -> "MixLoraConfig":
 
         return self
 
-    def from_config(self, config: Dict[str, any]) -> "MixLoraConfig":
-        assert config["peft_type"] == "MIXLORA"
-        super().from_config(config)
+    @staticmethod
+    def from_config(config: Dict[str, any]) -> "MixLoraConfig":
+        lora_config = MixLoraConfig(**LoraConfig.from_config(config).__dict__)
+        assert (
+            "peft_type" in config
+            and config["peft_type"] == "MIXLORA"
+            and "routing_strategy" in config
+            and config["routing_strategy"] == "mixtral"
+        ), "MixLoraConfig only supports MixLoRA models with 'mixtral' routing_strategy."
         if "expert_lora" in config:
             expert_config = copy.deepcopy(config)
             expert_config.update(config["expert_lora"])
-            self.expert_config_ = LoraConfig().from_config(expert_config)
-        self.router_aux_loss_coef_ = config.get(
+            lora_config.expert_config_ = LoraConfig().from_config(expert_config)
+        lora_config.router_aux_loss_coef_ = config.get(
             "router_aux_loss_coef", 0.001
         )  # for training
-        self.routing_strategy_ = config["routing_strategy"]
-        self.router_loss_ = config.get("router_loss", True)
-        self.num_experts_ = config["num_experts"]
+        lora_config.routing_strategy_ = config["routing_strategy"]
+        lora_config.router_loss_ = config.get("router_loss", True)
+        lora_config.num_experts_ = config["num_experts"]
         # silu for mixtral or gelu_new for switch transformers
         # left blank to automatically use the original act_fn of FFN
-        self.act_fn_ = config.get("act_fn", None)
-        if self.routing_strategy_ == "mixtral":
-            self.router_init_range_ = config.get("router_init_range", 0.02)
-            self.jitter_noise_ = config.get("jitter_noise", 0.0)
-            self.top_k_ = config.get("top_k", 2)
+        lora_config.act_fn_ = config.get("act_fn", None)
+        if lora_config.routing_strategy_ == "mixtral":
+            lora_config.router_init_range_ = config.get("router_init_range", 0.02)
+            lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
+            lora_config.top_k_ = config.get("top_k", 2)
         else:
             raise NotImplementedError()
 
-        return self
+        return lora_config
 
     def export(self) -> Dict[str, any]:
         config = super().export()
diff --git a/mixlora/model.py b/mixlora/model.py
index 7715573..40507d9 100644
--- a/mixlora/model.py
+++ b/mixlora/model.py
@@ -33,6 +33,7 @@ def _mixtral_slice_tensor(
     "gemma2": "_llama_forward",
     "qwen2": "_llama_forward",
     "mistral": "_llama_forward",
+    "phi": "_phi_forward",
 }
 
 
@@ -113,6 +114,42 @@ def _llama_forward(
 
         return final_expert_states
 
+    def _phi_forward(
+        self, expert_mask: torch.Tensor, hidden_states: torch.Tensor, input_dtype
+    ):
+        common_fc1 = self.base_layer_.fc1(hidden_states.to(input_dtype)).to(
+            hidden_states.dtype
+        )
+        final_expert_states = []
+        for expert_idx in range(self.num_experts_):
+            _, top_x = torch.where(expert_mask[expert_idx])
+            lora_fc1: Optional[Lora] = self.experts_.get(
+                f"experts.{expert_idx}.fc1", None
+            )
+            lora_fc2: Optional[Lora] = self.experts_.get(
+                f"experts.{expert_idx}.fc2", None
+            )
+            if lora_fc1 is not None:
+                lora_data = _mixtral_slice_tensor(hidden_states, top_x, input_dtype)
+                act_result = self.act_(
+                    lora_fc1(
+                        _mixtral_slice_tensor(common_fc1, top_x, input_dtype), lora_data
+                    )
+                )
+            else:
+                act_result = self.act_(
+                    _mixtral_slice_tensor(common_fc1, top_x, input_dtype)
+                )
+
+            if lora_fc2 is not None:
+                final_expert_states.append(
+                    lora_fc2(self.base_layer_.fc2(act_result), act_result)
+                )
+            else:
+                final_expert_states.append(self.base_layer_.fc2(act_result))
+
+        return final_expert_states
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
 
@@ -260,9 +297,9 @@ def load_adapter_weights(
     with open(
         name_or_path + os.sep + "adapter_config.json", "r", encoding="utf8"
     ) as fp:
-        config = MixLoraConfig(adapter_name_=adapter_name, dtype_=dtype).from_config(
-            json.load(fp)
-        )
+        config = MixLoraConfig.from_config(json.load(fp))
+        config.adapter_name_ = adapter_name
+        config.dtype_ = dtype
 
     weights = torch.load(
         name_or_path + os.sep + "adapter_model.bin", map_location=device
diff --git a/pyproject.toml b/pyproject.toml
index 60319ee..73f4515 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mixlora"
-version = "0.1.1"
+version = "0.1.2"
 description = "State-of-the-art Parameter-Efficient MoE Fine-tuning Method"
 readme = "README.md"
 requires-python = ">=3.8"
@@ -14,8 +14,8 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 dependencies = [
-    "torch==2.3.1",
-    "transformers==4.42.4",
+    "torch>=2.3.0,<2.4.0",
+    "transformers>=4.43.0,<4.44.0",
     "huggingface_hub",
 ]
 
diff --git a/tests/generate.py b/tests/generate.py
index 77c9a87..96c40c4 100644
--- a/tests/generate.py
+++ b/tests/generate.py
@@ -29,7 +29,7 @@ def main(
     )
     output = tokenizer.batch_decode(
         outputs.detach().cpu().numpy(), skip_special_tokens=True
-    )[0][len(instruction) :]
+    )[0][input_ids.shape[-1] :]
 
     print(output)
 
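Reviewer note (not part of the patch): a minimal sketch of how the reworked static constructor would be exercised for a Phi base model once this patch lands. The adapter dict below is illustrative only; its keys are the ones read by MixLoraConfig.from_config above, the target module names come from the new lora_target_modules table, and the import path is assumed from the file layout in this diff.

    # Illustrative sketch: build a MixLoraConfig for a Phi-style model via the
    # new static MixLoraConfig.from_config added in this patch.
    import torch

    from mixlora.config import MixLoraConfig

    adapter_config = {
        "peft_type": "MIXLORA",         # checked by the new assert
        "routing_strategy": "mixtral",  # only strategy accepted by this patch
        "r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        # Phi attention/MLP projection names from lora_target_modules
        "target_modules": ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
        "num_experts": 8,
        "top_k": 2,
    }

    config = MixLoraConfig.from_config(adapter_config)
    # adapter_name_ and dtype_ are no longer constructor arguments; they are
    # assigned afterwards, the way load_adapter_weights now does it.
    config.adapter_name_ = "default"
    config.dtype_ = torch.float16
    print(config.target_modules_["fc1"], config.target_modules_["dense"])  # True True

Turning from_config into a @staticmethod is what lets load_adapter_weights parse adapter_config.json first and attach adapter_name_ and dtype_ afterwards, as shown in the mixlora/model.py hunk.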