Skip to content

Commit

Permalink
Merge branch 'discrete_CQL' of https://github.com/BY571/rl into discrete_CQL
Browse files Browse the repository at this point in the history
  • Loading branch information
BY571 committed Nov 9, 2023
2 parents 1bf7d0d + 6bcf240 commit 4885005
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -5319,7 +5319,7 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est):
loss = loss_fn(td)
assert loss_fn.tensor_keys.priority in td.keys(True)

sum([item for key, item in loss.items() if key.startswith('loss')]).backward()
sum([item for key, item in loss.items() if key.startswith("loss")]).backward()
assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

# Check param update effect on targets
Expand Down Expand Up @@ -5383,14 +5383,18 @@ def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)
if n == 0:
assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
_loss = sum([item for key, item in loss.items() if key.startswith("loss_")])
_loss_ms = sum([item for key, item in loss_ms.items() if key.startswith("loss_")])
_loss_ms = sum(
[item for key, item in loss_ms.items() if key.startswith("loss_")]
)
assert (
abs(_loss - _loss_ms) < 1e-3
), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
else:
with pytest.raises(AssertionError):
assert_allclose_td(loss, loss_ms)
sum([item for key, item in loss_ms.items() if key.startswith('loss_')]).backward()
sum(
[item for key, item in loss_ms.items() if key.startswith("loss_")]
).backward()
assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

# Check param update effect on targets
Expand Down

0 comments on commit 4885005

Please sign in to comment.