
Commit 59a933e
clean up tests and add separate tolerances for fwd and bwd
dianaml0 committed Dec 23, 2022
1 parent 71df679 commit 59a933e
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions gpu_tests/test_sequence_parallel_transformer_layer.py
@@ -52,18 +52,23 @@ def test_xformers_parity(self):
         if "xformers" not in sys.modules:
             raise unittest.SkipTest("xformers not available, skipping test")
 
-        atol = 4e-3
-        rtol = 4e-4
+        fw_atol = 4e-3
+        fw_rtol = 4e-4
+
+        bw_atol = 9e-2
+        bw_rtol = 2e-2
 
         _distributed_init()
         tensor_model_parallel_size_ = 1
         initialize_model_parallel(tensor_model_parallel_size_)
 
+        S, B, E = 8, 16, 64
+        H = 2
         args = SimpleNamespace(
             sequence_parallel=True,
-            decoder_embed_dim=64,
+            decoder_embed_dim=E,
             dropout=0.0,
-            decoder_attention_heads=2,
+            decoder_attention_heads=H,
             decoder_ffn_embed_dim=64,
             decoder_layers=1,
             attention_dropout=0.0,
@@ -96,7 +101,7 @@ def test_xformers_parity(self):
         result = decoder(x_)
 
         torch.distributed.barrier()
-        _assert_allclose(xf_result, result, atol=atol, rtol=rtol)
+        _assert_allclose(xf_result, result, atol=fw_atol, rtol=fw_rtol)
 
         # Test Backwards
         reset_seeds()
@@ -105,7 +110,7 @@ def test_xformers_parity(self):
         result.backward(torch.ones_like(x_))
 
         torch.distributed.barrier()
-        _assert_allclose(x.grad, x_.grad, atol=atol, rtol=rtol)
+        _assert_allclose(x.grad, x_.grad, atol=bw_atol, rtol=bw_rtol)
 
         # Reset groups
         destroy_model_parallel()
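
The change keeps the forward comparison at the original tight tolerances (atol=4e-3, rtol=4e-4) and loosens them for the gradient comparison (atol=9e-2, rtol=2e-2), presumably because the backward pass accumulates more floating-point error. Below is a minimal, self-contained sketch of that forward/backward tolerance split using plain torch.testing.assert_close; the repo's _assert_allclose helper is not shown in this diff, and check_parity and the toy modules are hypothetical stand-ins, not the sequence-parallel layers under test.

import torch

def check_parity(ref_module, test_module, x,
                 fw_atol=4e-3, fw_rtol=4e-4,
                 bw_atol=9e-2, bw_rtol=2e-2):
    # Hypothetical sketch: compare a reference and a test module with
    # separate forward and backward tolerances, mirroring the fw_*/bw_* split above.
    x_ref = x.clone().requires_grad_(True)
    x_test = x.clone().requires_grad_(True)

    # Forward parity: tight tolerances on the outputs.
    out_ref = ref_module(x_ref)
    out_test = test_module(x_test)
    torch.testing.assert_close(out_test, out_ref, atol=fw_atol, rtol=fw_rtol)

    # Backward parity: looser tolerances on the input gradients.
    out_ref.backward(torch.ones_like(out_ref))
    out_test.backward(torch.ones_like(out_test))
    torch.testing.assert_close(x_test.grad, x_ref.grad, atol=bw_atol, rtol=bw_rtol)

# Example usage with two identically initialized layers (trivially passes):
torch.manual_seed(0)
ref = torch.nn.Linear(64, 64)
test = torch.nn.Linear(64, 64)
test.load_state_dict(ref.state_dict())
check_parity(ref, test, torch.randn(8, 16, 64))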
