Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests] skip failed tests for xpu #498

Merged
merged 11 commits into from
Jan 24, 2025
31 changes: 18 additions & 13 deletions test/convergence/test_mini_models_multimodal.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
import functools
import os

import pytest
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerFast

from liger_kernel.transformers import apply_liger_kernel_to_mllama
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from test.utils import FAKE_CONFIGS_PATH
from test.utils import UNTOKENIZED_DATASET_PATH
from test.utils import MiniModelConfig
Expand All @@ -22,6 +13,16 @@
from test.utils import supports_bfloat16
from test.utils import train_bpe_tokenizer

import pytest
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerFast

from liger_kernel.transformers import apply_liger_kernel_to_mllama
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

try:
# Qwen2-VL is only available in transformers>=4.45.0
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
Expand Down Expand Up @@ -346,10 +347,13 @@ def run_mini_model_multimodal(
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN2_VL_AVAILABLE,
reason="Qwen2-VL not available in this version of transformers",
),
marks=[
pytest.mark.skipif(
not QWEN2_VL_AVAILABLE,
reason="Qwen2-VL not available in this version of transformers",
),
pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
],
),
pytest.param(
"mini_qwen2_vl",
Expand All @@ -368,6 +372,7 @@ def run_mini_model_multimodal(
not QWEN2_VL_AVAILABLE,
reason="Qwen2-VL not available in this version of transformers",
),
pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
],
),
pytest.param(
Expand Down
11 changes: 5 additions & 6 deletions test/transformers/test_fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from test.transformers.test_cross_entropy import CrossEntropyWithZLoss
from test.utils import assert_verbose_allclose
from test.utils import set_seed
from typing import Optional

import pytest
import torch

from test.transformers.test_cross_entropy import CrossEntropyWithZLoss
from test.utils import assert_verbose_allclose
from test.utils import set_seed

from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
Expand Down Expand Up @@ -105,7 +104,7 @@ def forward(self, x, y):
@pytest.mark.parametrize(
"B, T, H, V",
[
(8, 128, 1024, 4096),
pytest.param(8, 128, 1024, 4096, marks=pytest.mark.skipif(device == "xpu", reason="skip for XPU")),
(4, 47, 31, 123), # random shape
],
)
Expand Down Expand Up @@ -287,7 +286,7 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, ce_weight, atol
@pytest.mark.parametrize(
"B, T, H, V",
[
(8, 128, 1024, 4096),
pytest.param(8, 128, 1024, 4096, marks=pytest.mark.skipif(device == "xpu", reason="skip for XPU")),
(4, 47, 31, 123), # random shape
],
)
Expand Down
10 changes: 5 additions & 5 deletions test/transformers/test_rms_norm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os

import pytest
import torch
import torch.nn as nn

from test.utils import assert_verbose_allclose
from test.utils import set_seed
from test.utils import supports_bfloat16

import pytest
import torch
import torch.nn as nn

from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.transformers.functional import liger_rms_norm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
Expand Down Expand Up @@ -103,7 +103,7 @@ def forward(self, x):
[
(LlamaRMSNorm, 0.0, "llama"),
(GemmaRMSNorm, 1.0, "gemma"),
(BaseRMSNorm, 0.0, "none"),
pytest.param(BaseRMSNorm, 0.0, "none", marks=pytest.mark.skipif(device == "xpu", reason="skip for XPU")),
],
)
@pytest.mark.parametrize(
Expand Down
Loading