From 7c83893dd473e6350fabb3084eaaaf4c46971d8e Mon Sep 17 00:00:00 2001 From: Raza Sikander <54884406+raza-sikander@users.noreply.github.com> Date: Tue, 16 Apr 2024 01:28:56 +0530 Subject: [PATCH] Remove dtype(fp16) condition check for residual_add unit test (#5329) When the dtype is bf16 or fp32 the if condition is not satisfied and it continues execution instead of skipping when triton is not installed. Co-authored-by: Shaik Raza Sikander Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase --- tests/unit/ops/transformer/inference/test_residual_add.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/ops/transformer/inference/test_residual_add.py b/tests/unit/ops/transformer/inference/test_residual_add.py index c2952f74ff2d..91830e25fc81 100644 --- a/tests/unit/ops/transformer/inference/test_residual_add.py +++ b/tests/unit/ops/transformer/inference/test_residual_add.py @@ -77,7 +77,7 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f @pytest.mark.parametrize("use_triton_ops", [True, False]) def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, pre_attn_norm, use_triton_ops): - if not deepspeed.HAS_TRITON and use_triton_ops and dtype == torch.float16: + if not deepspeed.HAS_TRITON and use_triton_ops: pytest.skip("triton has to be installed for the test") ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())