Skip to content

Commit

Permalink
[brief] Update the tests for the CuDNN context.
Browse files Browse the repository at this point in the history
[detailed]
- Ensure that it works regardless of the initial value assigned.
  • Loading branch information
marovira committed Apr 25, 2024
1 parent c769351 commit 975c13a
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,18 @@ def test_functions(self) -> None:
with pytest.raises(RuntimeError):
cuda.requires_cuda_support()

def _check_cudnn(self, val: bool) -> None:
torch.backends.cudnn.benchmark = val
assert torch.backends.cudnn.benchmark == val
with cuda.DisableCuDNNBenchmarkContext():
assert not torch.backends.cudnn.benchmark

assert torch.backends.cudnn.benchmark == val

def test_disable_cudnn_context(self) -> None:
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True

assert torch.backends.cudnn.benchmark
with cuda.DisableCuDNNBenchmarkContext():
assert not torch.backends.cudnn.benchmark
assert torch.backends.cudnn.benchmark
self._check_cudnn(True)
self._check_cudnn(False)


@dataclasses.dataclass
Expand Down

0 comments on commit 975c13a

Please sign in to comment.