Skip to content

Commit

Permalink
Fixed shape for MultiStep returns + Distributional loss
Browse files Browse the repository at this point in the history
  • Loading branch information
roger-creus committed Jul 5, 2024
1 parent 55f0a52 commit 5546b57
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchrl/objectives/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
support = support.to("cpu")
pns_a = pns_a.to("cpu")

Tz = reward + (1 - terminated.to(reward.dtype)) * discount * support
Tz = reward + (1 - terminated.to(reward.dtype)) * discount.unsqueeze(-1) * support.repeat(batch_size, 1)
if Tz.shape != torch.Size([batch_size, atoms]):
raise RuntimeError(
"Tz shape must be torch.Size([batch_size, atoms]), "
Expand Down

0 comments on commit 5546b57

Please sign in to comment.