Commit

fix categorical action in cql loss
BY571 committed Nov 8, 2023
1 parent 9af9f99 commit 53ef411
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torchrl/objectives/cql.py
@@ -1152,6 +1152,9 @@ def cql_loss(self, tensordict):

         logsumexp = torch.logsumexp(qvalues, dim=-1, keepdim=True)
         if self.action_space == "categorical":
+            if current_action.shape != qvalues.shape:
+                # unsqueeze the action if it lacks a trailing singleton dim
+                current_action = current_action.unsqueeze(-1)
             q_a = qvalues.gather(-1, current_action)
         else:
             q_a = (qvalues * current_action).sum(dim=-1, keepdim=True)
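
The shape handling in this hunk can be illustrated outside of TorchRL. The following is a minimal sketch, not the library's code: `select_q` is a hypothetical helper, and the tensors are made-up values. It assumes categorical actions arrive as integer indices of shape `[batch]` while `qvalues` has shape `[batch, n_actions]`, which is the case where `gather` needs the extra trailing dimension.

```python
import torch
import torch.nn.functional as F


def select_q(qvalues: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the categorical branch of the diff."""
    if action.shape != qvalues.shape:
        # gather requires the index to have as many dims as the input,
        # so add the missing trailing singleton dim: [batch] -> [batch, 1]
        action = action.unsqueeze(-1)
    return qvalues.gather(-1, action)


# Illustrative values only: 2 states, 3 discrete actions.
qvalues = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
action = torch.tensor([2, 0])  # integer indices, shape [2]

q_a = select_q(qvalues, action)  # shape [2, 1]

# The non-categorical branch of the diff gives the same result
# when the action is one-hot encoded instead.
one_hot = F.one_hot(action, num_classes=3)
q_onehot = (qvalues * one_hot).sum(dim=-1, keepdim=True)
```

Without the `unsqueeze`, `qvalues.gather(-1, action)` raises an error for a `[batch]`-shaped index because the index and input have different numbers of dimensions. Note that the check compares full shapes, so an action already shaped `[batch, 1]` would also be unsqueezed; the sketch, like the diff, assumes that case does not occur.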
