Skip to content

Commit

Permalink
update vmap fail case test
Browse files Browse the repository at this point in the history
  • Loading branch information
BY571 committed Jan 8, 2024
1 parent f737172 commit 1bfcc27
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools
import itertools
import operator
import re
import warnings
from copy import deepcopy
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -266,14 +267,18 @@ def forward(self, td):
# If user sets vmap randomness to a specific value
if vmap_randomness in ("different", "same") and dropout > 0.0:
loss_module.set_vmap_randomness(vmap_randomness)
loss_module(td)["loss"]
# Fail case
elif vmap_randomness == "error" and dropout > 0.0:
with pytest.raises(RuntimeError):
with pytest.raises(RuntimeError) as exc_info:
loss_module(td)["loss"]

else:
loss_module(td)["loss"]
# Accessing cause of the caught exception
cause = exc_info.value.__cause__
assert re.match(
r"vmap: called random operation while in randomness error mode", str(cause)
)
return
loss_module(td)["loss"]


class TestDQN(LossModuleTestBase):
Expand Down

0 comments on commit 1bfcc27

Please sign in to comment.