
Commit

Remove dtype(fp16) condition check for residual_add unit test (microsoft#5329)

When the dtype is bf16 or fp32, the skip condition is not satisfied, so the test
continues to run instead of being skipped when Triton is not installed.

Co-authored-by: Shaik Raza Sikander <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
4 people authored and umchand committed May 20, 2024
1 parent 74ee35e commit 9ad631f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/unit/ops/transformer/inference/test_residual_add.py
@@ -77,7 +77,7 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f
@pytest.mark.parametrize("use_triton_ops", [True, False])
def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size,
pre_attn_norm, use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops and dtype == torch.float16:
+    if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
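The bug the diff fixes can be illustrated without DeepSpeed or pytest. Below is a minimal sketch (the `HAS_TRITON` flag and the `should_skip_*` helper names are hypothetical stand-ins for the real module attributes): the old predicate only fires for fp16, so a bf16 or fp32 parametrization would try to run the Triton code path even when Triton is absent.

```python
# Assume Triton is not installed, as in the failure the commit describes.
HAS_TRITON = False

def should_skip_old(use_triton_ops: bool, dtype: str) -> bool:
    # Old predicate: the extra dtype check means the skip fires
    # only for the fp16 parametrization.
    return not HAS_TRITON and use_triton_ops and dtype == "float16"

def should_skip_new(use_triton_ops: bool, dtype: str) -> bool:
    # Fixed predicate: skip for every dtype whenever Triton
    # ops are requested but Triton is missing.
    return not HAS_TRITON and use_triton_ops

for dtype in ("float16", "bfloat16", "float32"):
    print(f"{dtype}: old={should_skip_old(True, dtype)} "
          f"new={should_skip_new(True, dtype)}")
```

With `use_triton_ops=True`, the old check skips only `float16`; the new check skips all three dtypes, which matches the one-line change in the diff above.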
