
Commit

support phi3
mikecovlee committed Jul 29, 2024
1 parent 383154b commit 46a49cf
Showing 2 changed files with 55 additions and 17 deletions.
70 changes: 54 additions & 16 deletions mixlora/model.py
@@ -34,6 +34,7 @@ def _mixtral_slice_tensor(
     "qwen2": "_llama_forward",
     "mistral": "_llama_forward",
     "phi": "_phi_forward",
+    "phi3": "_phi3_forward",
 }
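
The table above maps a model's `config.model_type` to the name of the expert-forward method matching that architecture's MLP layout; this commit registers "phi3" alongside the existing entries. A minimal sketch of that name-based dispatch pattern (the class name and placeholder method bodies below are hypothetical illustrations, not MixLoRA's actual code):

# Sketch only: dispatch from model_type to a forward implementation by name.
# The class and placeholder bodies are hypothetical.
import torch

FORWARD_FN_MAP = {
    "qwen2": "_llama_forward",
    "mistral": "_llama_forward",
    "phi": "_phi_forward",
    "phi3": "_phi3_forward",
}


class MoeSketch:
    def __init__(self, model_type: str):
        if model_type not in FORWARD_FN_MAP:
            raise NotImplementedError(f"unsupported model type: {model_type}")
        # Resolve the bound method once so unsupported types fail early.
        self.expert_forward_ = getattr(self, FORWARD_FN_MAP[model_type])

    def _llama_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states  # placeholder body

    def _phi_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states  # placeholder body

    def _phi3_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states  # placeholder body


moe = MoeSketch("phi3")
print(moe.expert_forward_(torch.randn(1, 4)).shape)  # torch.Size([1, 4])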


@@ -67,47 +68,47 @@ def __init__(
     def _llama_forward(
         self, expert_mask: torch.Tensor, hidden_states: torch.Tensor, input_dtype
     ):
-        common_w1 = self.base_layer_.gate_proj(hidden_states.to(input_dtype)).to(
+        common_gate = self.base_layer_.gate_proj(hidden_states.to(input_dtype)).to(
             hidden_states.dtype
         )
-        common_w3 = self.base_layer_.up_proj(hidden_states.to(input_dtype)).to(
+        common_up = self.base_layer_.up_proj(hidden_states.to(input_dtype)).to(
             hidden_states.dtype
         )
         final_expert_states = []
         for expert_idx in range(self.num_experts_):
             _, top_x = torch.where(expert_mask[expert_idx])
-            lora_w1: Optional[Lora] = self.experts_.get(
+            lora_gate: Optional[Lora] = self.experts_.get(
                 f"experts.{expert_idx}.gate_proj", None
             )
-            lora_w2: Optional[Lora] = self.experts_.get(
+            lora_down: Optional[Lora] = self.experts_.get(
                 f"experts.{expert_idx}.down_proj", None
             )
-            lora_w3: Optional[Lora] = self.experts_.get(
+            lora_up: Optional[Lora] = self.experts_.get(
                 f"experts.{expert_idx}.up_proj", None
             )
-            if lora_w1 is not None:
+            if lora_gate is not None:
                 lora_data = _mixtral_slice_tensor(hidden_states, top_x, input_dtype)
-                w1 = lora_w1(
-                    _mixtral_slice_tensor(common_w1, top_x, input_dtype), lora_data
+                gate_states = lora_gate(
+                    _mixtral_slice_tensor(common_gate, top_x, input_dtype), lora_data
                 )
             else:
                 lora_data = None
-                w1 = _mixtral_slice_tensor(common_w1, top_x, input_dtype)
+                gate_states = _mixtral_slice_tensor(common_gate, top_x, input_dtype)

-            if lora_w3 is not None:
+            if lora_up is not None:
                 lora_data = _mixtral_slice_tensor(hidden_states, top_x, input_dtype)
-                w3 = lora_w3(
-                    _mixtral_slice_tensor(common_w3, top_x, input_dtype), lora_data
+                up_states = lora_up(
+                    _mixtral_slice_tensor(common_up, top_x, input_dtype), lora_data
                 )
             else:
                 lora_data = None
-                w3 = _mixtral_slice_tensor(common_w3, top_x, input_dtype)
+                up_states = _mixtral_slice_tensor(common_up, top_x, input_dtype)

-            act_result = self.act_(w1) * w3
+            act_result = self.act_(gate_states) * up_states

-            if lora_w2 is not None:
+            if lora_down is not None:
                 final_expert_states.append(
-                    lora_w2(self.base_layer_.down_proj(act_result), act_result)
+                    lora_down(self.base_layer_.down_proj(act_result), act_result)
                 )
             else:
                 final_expert_states.append(self.base_layer_.down_proj(act_result))
@@ -150,6 +151,43 @@ def _phi_forward(

         return final_expert_states

+    def _phi3_forward(
+        self, expert_mask: torch.Tensor, hidden_states: torch.Tensor, input_dtype
+    ):
+        common_gate_up = self.base_layer_.gate_up_proj(
+            hidden_states.to(input_dtype)
+        ).to(hidden_states.dtype)
+        final_expert_states = []
+        for expert_idx in range(self.num_experts_):
+            _, top_x = torch.where(expert_mask[expert_idx])
+            lora_gate_up: Optional[Lora] = self.experts_.get(
+                f"experts.{expert_idx}.gate_up_proj", None
+            )
+            lora_down: Optional[Lora] = self.experts_.get(
+                f"experts.{expert_idx}.down_proj", None
+            )
+            if lora_gate_up is not None:
+                gate_up_states = lora_gate_up(
+                    _mixtral_slice_tensor(common_gate_up, top_x, input_dtype),
+                    _mixtral_slice_tensor(hidden_states, top_x, input_dtype),
+                )
+            else:
+                gate_up_states = _mixtral_slice_tensor(
+                    common_gate_up, top_x, input_dtype
+                )
+
+            gate_states, up_states = gate_up_states.chunk(2, dim=-1)
+            act_result = up_states * self.act_(gate_states)
+
+            if lora_down is not None:
+                final_expert_states.append(
+                    lora_down(self.base_layer_.down_proj(act_result), act_result)
+                )
+            else:
+                final_expert_states.append(self.base_layer_.down_proj(act_result))
+
+        return final_expert_states
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape

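The new _phi3_forward mirrors _llama_forward, but Phi-3's MLP fuses the gate and up projections into a single gate_up_proj, so one shared base projection and one LoRA expert cover both halves and the fused output is split with chunk(2, dim=-1) before the activation and down_proj. A small self-contained sketch of that fused-MLP pattern (hypothetical sizes, plain PyTorch, no LoRA experts or routing):

# Sketch of a Phi-3-style fused MLP: one gate_up_proj whose output is split
# into gate and up halves, mirroring the chunk(2, dim=-1) step above.
# All sizes are made-up example values.
import torch
import torch.nn as nn


class FusedGateUpMLP(nn.Module):
    def __init__(self, hidden_size: int = 64, intermediate_size: int = 128):
        super().__init__()
        # A single projection producing [gate | up] concatenated on the last dim.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(up * self.act(gate))


x = torch.randn(2, 8, 64)          # (batch, seq_len, hidden_size)
print(FusedGateUpMLP()(x).shape)   # torch.Size([2, 8, 64])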
2 changes: 1 addition & 1 deletion tests/generate.py
@@ -31,7 +31,7 @@ def main(
         outputs.detach().cpu().numpy(), skip_special_tokens=True
     )[0][input_ids.shape[-1] :]

-    print(output)
+    print(f"\nOutput: {prompter.get_response(output)}\n")


 if __name__ == "__main__":
