Skip to content

Commit

Permalink
fix benchmark tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 15, 2024
1 parent e6bb55b commit facf2e3
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,9 +671,9 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
)
# take max over actions
state_action_value = state_action_value.reshape(
torch.Size([self.num_qvalue_nets])
+ tensordict.shape
+ torch.Size([self.num_random, -1])
torch.Size(
[*self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
)
).max(-2)[0]
# take min over qvalue nets
next_state_value = state_action_value.min(0)[0]
Expand Down Expand Up @@ -730,19 +730,13 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
"This could be caused by calling cql_loss method before q_loss method."
)

random_actions_tensor = (
torch.FloatTensor(
tensordict.shape[:-1]
+ torch.Size(
[
tensordict.shape[-1] * self.num_random,
tensordict[self.tensor_keys.action].shape[-1],
]
)
random_actions_tensor = pred_q1.new_empty(
(
*tensordict.shape[:-1],
tensordict.shape[-1] * self.num_random,
tensordict[self.tensor_keys.action].shape[-1],
)
.uniform_(-1, 1)
.to(tensordict.device)
)
).uniform_(-1, 1)
curr_actions_td, curr_log_pis = self._get_policy_actions(
tensordict.copy(),
self.actor_network_params,
Expand Down Expand Up @@ -1076,9 +1070,9 @@ def __init__(
self.loss_function = loss_function
if action_space is None:
# infer from value net
try:
if hasattr(value_network, "action_space"):
action_space = value_network.spec
except AttributeError:
else:
# let's try with action_space then
try:
action_space = value_network.action_space
Expand Down Expand Up @@ -1201,8 +1195,6 @@ def value_loss(
with torch.no_grad():
td_error = (pred_val_index - target_value).pow(2)
td_error = td_error.unsqueeze(-1)
if tensordict.device is not None:
td_error = td_error.to(tensordict.device)

tensordict.set(
self.tensor_keys.priority,
Expand Down

0 comments on commit facf2e3

Please sign in to comment.