Skip to content

Commit

Permalink
Merge pull request #24427 from kaixih:tolerence_jax_sdpa
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689328282
  • Loading branch information
Google-ML-Automation committed Oct 24, 2024
2 parents 717467a + 7409bae commit c311c73
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl):

self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
self.assertAllClose(dK_ref, dK_ans, rtol=.01, atol=.01)
self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01)

@parameterized.product(
mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
Expand Down Expand Up @@ -164,10 +164,10 @@ def testDotProductAttentionMask(self, mask_mode):
self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02)
self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03)
self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01)
self.assertAllClose(dbias_ref, dbias_ans, rtol=.02, atol=.02)

@parameterized.product(
batch_size=[1, 16],
Expand Down Expand Up @@ -224,7 +224,7 @@ def bwd_ans(x, bias, mask):
else:
_, dbias_ref, _ = bwd_ref(x, bias, mask)
_, dbias_ans, _ = bwd_ans(x, bias, mask)
self.assertAllClose(dbias_ans, dbias_ref, rtol=.03, atol=.03)
self.assertAllClose(dbias_ans, dbias_ref, rtol=.02, atol=.02)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSoftplusGrad(self):
Expand Down

0 comments on commit c311c73

Please sign in to comment.