From 53ef411b0da386ac0320a0e6c7b90e019269cd92 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 8 Nov 2023 16:59:48 +0100 Subject: [PATCH] fix categorical action in cql loss --- torchrl/objectives/cql.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index c465d8ca59a..9055e5464c6 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -1152,6 +1152,9 @@ def cql_loss(self, tensordict): logsumexp = torch.logsumexp(qvalues, dim=-1, keepdim=True) if self.action_space == "categorical": + if current_action.shape != qvalues.shape: + # unsqueeze the action if it lacks on trailing singleton dim + current_action = current_action.unsqueeze(-1) q_a = qvalues.gather(-1, current_action) else: q_a = (qvalues * current_action).sum(dim=-1, keepdim=True)