Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 19, 2024
2 parents 735f4b6 + 49558c3 commit 2a8625c
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch.nn
import torch.optim
from tensordict import TensorDict, TensorDictParams
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn.distributions import NormalParamExtractor

Expand Down Expand Up @@ -216,12 +217,19 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
in_keys=["loc", "scale"],
spec=action_spec,
distribution_class=TanhNormal,
distribution_kwargs={
"low": torch.as_tensor(action_spec.space.low, device=device),
"high": torch.as_tensor(action_spec.space.high, device=device),
"tanh_loc": False,
"safe_tanh": not cfg.compile.compile,
},
# Wrapping the kwargs in a TensorDictParams such that these items are
# sent to the device when necessary
distribution_kwargs=TensorDictParams(
TensorDict(
{
"low": torch.as_tensor(action_spec.space.low, device=device),
"high": torch.as_tensor(action_spec.space.high, device=device),
"tanh_loc": False,
"safe_tanh": not cfg.compile.compile,
}
),
no_convert=True,
),
default_interaction_type=ExplorationType.RANDOM,
)

Expand Down

0 comments on commit 2a8625c

Please sign in to comment.