# Monkey patch layer norm in mllama (#302)

## Summary

Monkey patches layer norm in mllama for conditional generation: a new `layer_norm` flag on `apply_liger_kernel_to_mllama` swaps in `LigerLayerNorm` for `nn.LayerNorm` and patches the vision model's existing LayerNorm modules in place. See the usage sketch below.
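As a quick illustration (not part of this diff), the new flag is used like the existing ones; calling the patch API before constructing the model lets the class-level `nn.LayerNorm` swap take effect. The checkpoint id below is a placeholder.

```python
from transformers import MllamaForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama

# Patch before constructing the model so modules built afterwards pick up
# LigerLayerNorm via the class-level modeling_mllama.nn.LayerNorm swap.
apply_liger_kernel_to_mllama(layer_norm=True, rms_norm=True, swiglu=True)

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct"  # placeholder checkpoint id
)
```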

## Testing Done
Tested that the monkey patching works as intended.

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shivam Sahni <[email protected]>
shivam15s and Shivam Sahni authored Oct 17, 2024
1 parent 24a7efc commit 6ab3b9f
Showing 4 changed files with 66 additions and 5 deletions.
22 changes: 21 additions & 1 deletion src/liger_kernel/transformers/monkey_patch.py
@@ -121,6 +121,7 @@ def apply_liger_kernel_to_mllama(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
layer_norm: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
@@ -151,12 +152,15 @@ def apply_liger_kernel_to_mllama(
MllamaForCausalLM,
MllamaForConditionalGeneration,
MllamaTextModel,
MllamaVisionModel,
)

from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward

if rope:
modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
if layer_norm:
modeling_mllama.nn.LayerNorm = LigerLayerNorm
if rms_norm:
modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
if swiglu:
@@ -174,11 +178,14 @@ def apply_liger_kernel_to_mllama(

if isinstance(model, MllamaForConditionalGeneration):
language_model: MllamaForCausalLM = model.language_model
vision_model: MllamaVisionModel = model.vision_model
text_model: MllamaTextModel = language_model.model
elif isinstance(model, MllamaForCausalLM):
text_model = model.model
vision_model = None
elif isinstance(model, MllamaTextModel):
text_model = model
vision_model = None
else:
raise ValueError(f"Unsupported Mllama model type: {type(model)}")

@@ -194,6 +201,20 @@ def apply_liger_kernel_to_mllama(
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)

if vision_model:
_patch_layer_norm_module(vision_model.layernorm_pre)
_patch_layer_norm_module(vision_model.layernorm_post)

for layer in vision_model.transformer.layers:
if layer_norm:
_patch_layer_norm_module(layer.input_layernorm)
_patch_layer_norm_module(layer.post_attention_layernorm)

for layer in vision_model.global_transformer.layers:
if layer_norm:
_patch_layer_norm_module(layer.input_layernorm)
_patch_layer_norm_module(layer.post_attention_layernorm)


def apply_liger_kernel_to_mistral(
rope: bool = True,
@@ -767,7 +788,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
for key, value in kwargs.items()
if key in apply_fn_signature.parameters
}

logger.info(
f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
)
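For an already-constructed model, the instance-level path added above patches the vision tower's existing modules in place rather than relying on the class-level swap. A minimal sketch (the checkpoint id is a placeholder):

```python
from transformers import MllamaForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision"  # placeholder checkpoint id
)

# With `model=` supplied, layernorm_pre / layernorm_post and every layer of
# vision_model.transformer and vision_model.global_transformer are patched
# in place via _patch_layer_norm_module.
apply_liger_kernel_to_mllama(model=model, layer_norm=True)
```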
5 changes: 1 addition & 4 deletions test/convergence/test_mini_models_multimodal.py
@@ -316,15 +316,12 @@ def run_mini_model_multimodal(
kwargs = {
"rms_norm": True,
"cross_entropy": True,
"layer_norm": True,
}
model_supports_rope = "qwen2_vl" not in model_name
if model_supports_rope:
kwargs["rope"] = True

model_supports_layer_norm = "qwen2_vl" in model_name
if model_supports_layer_norm:
kwargs["layer_norm"] = True

if "gemma" in model_name:
kwargs["geglu"] = True
else:
42 changes: 42 additions & 0 deletions test/transformers/test_monkey_patch.py
@@ -302,6 +302,27 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
layer.post_attention_layernorm.forward
) != inspect.getsource(LigerRMSNorm.forward)

assert inspect.getsource(
dummy_model_instance.vision_model.layernorm_pre.forward
) != inspect.getsource(LigerLayerNorm.forward)
assert inspect.getsource(
dummy_model_instance.vision_model.layernorm_post.forward
) != inspect.getsource(LigerLayerNorm.forward)
for layer in dummy_model_instance.vision_model.transformer.layers:
assert inspect.getsource(
layer.input_layernorm.forward
) != inspect.getsource(LigerLayerNorm.forward)
assert inspect.getsource(
layer.post_attention_layernorm.forward
) != inspect.getsource(LigerLayerNorm.forward)
for layer in dummy_model_instance.vision_model.global_transformer.layers:
assert inspect.getsource(
layer.input_layernorm.forward
) != inspect.getsource(LigerLayerNorm.forward)
assert inspect.getsource(
layer.post_attention_layernorm.forward
) != inspect.getsource(LigerLayerNorm.forward)

# Test applying kernels to the model instance
_apply_liger_kernel_to_instance(model=dummy_model_instance)

@@ -320,6 +341,27 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
layer.post_attention_layernorm.forward
) == inspect.getsource(LigerRMSNorm.forward)

assert inspect.getsource(
dummy_model_instance.vision_model.layernorm_pre.forward
) == inspect.getsource(LigerLayerNorm.forward)
assert inspect.getsource(
dummy_model_instance.vision_model.layernorm_post.forward
) == inspect.getsource(LigerLayerNorm.forward)
for layer in dummy_model_instance.vision_model.transformer.layers:
assert inspect.getsource(
layer.input_layernorm.forward
) == inspect.getsource(LigerLayerNorm.forward)
assert inspect.getsource(
layer.post_attention_layernorm.forward
) == inspect.getsource(LigerLayerNorm.forward)
for layer in dummy_model_instance.vision_model.global_transformer.layers:
assert inspect.getsource(
layer.input_layernorm.forward
) == inspect.getsource(LigerLayerNorm.forward)
assert inspect.getsource(
layer.post_attention_layernorm.forward
) == inspect.getsource(LigerLayerNorm.forward)


def test_apply_liger_kernel_to_instance_for_mllama_for_causal_lm():
# Ensure any monkey patching is cleaned up for subsequent tests
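The assertions above all reduce to a source-comparison check; a standalone sketch of that pattern (the helper name is hypothetical, and the `LigerLayerNorm` import path is assumed from the package layout):

```python
import inspect

from liger_kernel.transformers.layer_norm import LigerLayerNorm


def is_liger_layer_norm(module) -> bool:
    # A module counts as patched once its forward's source matches
    # LigerLayerNorm.forward, mirroring the checks in the test above.
    return inspect.getsource(module.forward) == inspect.getsource(
        LigerLayerNorm.forward
    )
```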
2 changes: 2 additions & 0 deletions test/utils.py
@@ -222,8 +222,10 @@ def revert_liger_kernel_to_mllama():
Revert all Liger kernel patches applied to MLlama.
"""

import torch.nn as nn
from transformers.models.mllama import modeling_mllama

importlib.reload(nn)
importlib.reload(modeling_mllama)
print("Liger kernel patches have been reverted.")

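Reloading `torch.nn` reverts the patch because `modeling_mllama` imports `nn` from `torch`, so the class-level swap rebinds `LayerNorm` on the shared `torch.nn` module; re-executing `torch.nn` restores the original class. A small sketch of that effect (the `LigerLayerNorm` import path is assumed):

```python
import importlib

import torch.nn as nn

from liger_kernel.transformers.layer_norm import LigerLayerNorm

nn.LayerNorm = LigerLayerNorm  # roughly what the class-level monkey patch does
assert nn.LayerNorm is LigerLayerNorm

importlib.reload(nn)  # what revert_liger_kernel_to_mllama now does
assert nn.LayerNorm is not LigerLayerNorm  # original torch LayerNorm is back
```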
