Skip to content

Commit

Permalink
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 511782544
  • Loading branch information
hawkinsp authored and copybara-github committed Feb 23, 2023
1 parent 52a66ae commit 785390a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jmp/_src/loss_scale_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ def test_static_loss_scale(self, cls, scale):

@parameterized.named_parameters(
("NoOpLossScale", jmp.NoOpLossScale),
("StaticLossScale", lambda: jmp.StaticLossScale(0)),
("StaticLossScale", lambda: jmp.StaticLossScale(0)), # pytype: disable=wrong-arg-types # jax-ndarray
)
def test_static_empty_trees(self, create):
loss_scale = create()
self.assertEmpty(jax.tree_util.tree_leaves(loss_scale))

def test_dynamic_loss_scale_no_warnings(self):
with warnings.catch_warnings(record=True) as logged_warnings:
jmp.DynamicLossScale(2. ** 15)
jmp.DynamicLossScale(2. ** 15) # pytype: disable=wrong-arg-types # jax-ndarray
self.assertEmpty(logged_warnings)

def test_dynamic_loss_scale_tree(self):
Expand Down

0 comments on commit 785390a

Please sign in to comment.