Skip to content

Commit

Permalink
Simplify no warnings assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Dec 19, 2024
1 parent 7793429 commit 98051d2
Showing 1 changed file with 1 addition and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import abc
import re
import unittest
import warnings
from types import ModuleType
from typing import Any
from unittest import mock
Expand Down Expand Up @@ -500,18 +499,12 @@ def test_amortized_computation_failure_tolerance(self) -> None:
step += 1

# Case 3: amortized computation succeeds after tolerance hit (test reset) -> no error.
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
with self.assertNoLogs(level="WARNING") as cm:
self._preconditioner_list.update_preconditioners(
masked_grad_list=masked_grad_list,
step=torch.tensor(step),
perform_amortized_computation=True,
)
self.assertEqual(
len(warning_list),
0,
f"Expected no warnings but got: {warning_list}",
)
self.assertEqual(
mock_amortized_computation.call_count,
self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1),
Expand Down

0 comments on commit 98051d2

Please sign in to comment.