Skip to content

Commit

Permalink
Support CE after grad acc fix (#375)
Browse files Browse the repository at this point in the history
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->
Based on #374, but make it
leaner
1. The use of cross entropy in model code has changed after grad fix
2. It changed from module CrossEntropy to functional cross_entropy
3. Our monkey patching needs to change accordingly
4. While also make sure backward compatibility by adding a condition for
different versions

Notable Changes

1. Add a functional api for CE to take keyword args
2. Add back conv test with logits to test CE convergence
3. Add back comp test for transformers 4.44

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
  • Loading branch information
ByronHsu authored Nov 12, 2024
1 parent b2b6970 commit 5ef09d5
Show file tree
Hide file tree
Showing 8 changed files with 881 additions and 12 deletions.
27 changes: 26 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,31 @@ jobs:
python -m pip install --upgrade pip
pip install modal
- name: Run unit tests
- name: Run tests
run: |
modal run dev.modal.tests
tests-bwd:
runs-on: ubuntu-latest
needs: [checkstyle]
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install modal
- name: Run tests
run: |
modal run dev.modal.tests_bwd
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ checkstyle:
# Command to run pytest for convergence tests
# We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
test-convergence:
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py

# Command to run all benchmark scripts and update benchmarking data file
# By default this doesn't overwrite existing data for the same benchmark experiment
Expand Down
28 changes: 28 additions & 0 deletions dev/modal/tests_bwd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from pathlib import Path

import modal

ROOT_PATH = Path(__file__).parent.parent.parent

# tests_bwd is to ensure the backward compatibility of liger with older transformers
image = (
modal.Image.debian_slim()
.pip_install_from_pyproject(
ROOT_PATH / "pyproject.toml", optional_dependencies=["dev"]
)
.pip_install("transformers==4.44.2")
)

app = modal.App("liger_tests", image=image)

# mount: add local files to the remote container
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")


@app.function(gpu="A10G", mounts=[repo], timeout=60 * 10)
def liger_tests():
import subprocess

subprocess.run(["pip", "install", "-e", "."], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test"], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test-convergence"], check=True, cwd="/root/liger-kernel")
33 changes: 32 additions & 1 deletion src/liger_kernel/transformers/functional.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
from liger_kernel.ops.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyFunction,
Expand All @@ -13,7 +15,6 @@
from liger_kernel.ops.swiglu import LigerSiLUMulFunction

liger_swiglu = LigerSiLUMulFunction.apply
liger_cross_entropy = LigerCrossEntropyFunction.apply
liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
liger_geglu = LigerGELUMulFunction.apply
liger_rms_norm = LigerRMSNormFunction.apply
Expand All @@ -23,3 +24,33 @@
liger_jsd = LigerJSDFunction.apply
liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
liger_group_norm = LigerGroupNormFunction.apply


# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
# `weight` and `size_average` are placeholders and not implemented yet
def liger_cross_entropy(
input,
target,
weight=None,
size_average=None,
ignore_index: int = -100,
reduce=None,
reduction: str = "mean",
label_smoothing: float = 0.0,
lse_square_scale: float = 0.0,
softcap: Optional[float] = None,
return_z_loss: bool = False,
):
loss, z_loss = LigerCrossEntropyFunction.apply(
input,
target,
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap,
return_z_loss,
)
if not return_z_loss:
return loss
return loss, z_loss
62 changes: 55 additions & 7 deletions src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transformers import PreTrainedModel

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
Expand Down Expand Up @@ -111,8 +112,16 @@ def apply_liger_kernel_to_llama(
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP

if cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss

if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
Expand Down Expand Up @@ -192,7 +201,13 @@ def apply_liger_kernel_to_mllama(
if swiglu:
modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
if cross_entropy:
modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
Expand Down Expand Up @@ -342,7 +357,14 @@ def apply_liger_kernel_to_mixtral(
if rms_norm:
modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
if cross_entropy:
modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss

if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
Expand Down Expand Up @@ -417,7 +439,13 @@ def apply_liger_kernel_to_gemma(
if rms_norm:
modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
if cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if geglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if fused_linear_cross_entropy:
Expand Down Expand Up @@ -474,6 +502,7 @@ def apply_liger_kernel_to_gemma2(
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."

from transformers.models.gemma2 import modeling_gemma2
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

Expand All @@ -490,7 +519,13 @@ def apply_liger_kernel_to_gemma2(
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
if cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
Expand Down Expand Up @@ -562,8 +597,15 @@ def apply_liger_kernel_to_qwen2(
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm

if cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss

# import pdb; pdb.set_trace()
if fused_linear_cross_entropy:
Expand Down Expand Up @@ -710,7 +752,13 @@ def apply_liger_kernel_to_phi3(
if swiglu:
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
if cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
Expand Down
Loading

0 comments on commit 5ef09d5

Please sign in to comment.