Skip to content

Commit

Permalink
Update
Browse files: browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 19, 2024
1 parent 7a7c1aa commit 40390f5
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 15 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
)

# copy action from the input tensordict to the output
transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec))
transformed_env.append_transform(TensorDictPrimer(base_env.full_action_spec))

transformed_env.append_transform(DoubleToFloat())
obsnorm = ObservationNorm(
Expand Down
45 changes: 36 additions & 9 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7408,7 +7408,7 @@ def make_env():
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
env = TransformedEnv(
maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=Unbounded([2, 4])),
TensorDictPrimer(mykey=Unbounded([4])),
)
try:
check_env_specs(env)
Expand All @@ -7423,11 +7423,39 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
pass

@pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
def test_trans_serial_env_check(self, spec_shape):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=Unbounded(spec_shape)),
)
@pytest.mark.parametrize("expand_specs", [True, False, None])
def test_trans_serial_env_check(self, spec_shape, expand_specs):
if expand_specs is None:
with pytest.warns(FutureWarning, match=""):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
mykey=Unbounded(spec_shape), expand_specs=expand_specs
),
)
env.observation_spec
elif expand_specs is True:
shape = spec_shape[:-1]
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
Composite(mykey=Unbounded(spec_shape), shape=shape),
expand_specs=expand_specs,
),
)
else:
# Without expanding the specs, the [4] shape cannot be used: reading the env specs raises a ValueError (checked below)
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
mykey=Unbounded(spec_shape), expand_specs=expand_specs
),
)
if spec_shape == [4]:
with pytest.raises(ValueError):
env.observation_spec
return

check_env_specs(env)
assert "mykey" in env.reset().keys()
r = env.rollout(3)
Expand Down Expand Up @@ -10310,9 +10338,8 @@ def _make_transform_env(self, out_key, base_env):
transform = KLRewardTransform(actor, out_keys=out_key)
return Compose(
TensorDictPrimer(
primers={
"sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1])
}
sample_log_prob=Unbounded(shape=base_env.action_spec.shape[:-1]),
shape=base_env.shape,
),
transform,
)
Expand Down
19 changes: 16 additions & 3 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4984,6 +4984,7 @@ def __init__(
| Dict[NestedKey, float]
| Dict[NestedKey, Callable] = None,
reset_key: NestedKey | None = None,
expand_specs: bool = None,
**kwargs,
):
self.device = kwargs.pop("device", None)
Expand All @@ -4995,8 +4996,10 @@ def __init__(
)
kwargs = primers
if not isinstance(kwargs, Composite):
kwargs = Composite(kwargs)
kwargs = Composite(**kwargs)
self.primers = kwargs
self.expand_specs = expand_specs

if random and default_value:
raise ValueError(
"Setting random to True and providing a default_value are incompatible."
Expand Down Expand Up @@ -5089,15 +5092,25 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
)

if self.primers.shape != observation_spec.shape:
if self.primers.shape == () and self.parent.batch_size != ():
if self.expand_specs:
self.primers = self._expand_shape(self.primers)
else:
elif self.expand_specs is None:
warnings.warn(
f"expand_specs wasn't specified in the {type(self).__name__} constructor. "
f"The current behaviour is that the transform will attempt to set the shape of the composite "
f"spec, and if this can't be done it will be expanded. "
f"From v0.8, a mismatched shape between the spec of the transform and the env's batch_size "
f"will raise an exception.",
category=FutureWarning,
)
try:
# We try to set the primer shape to the observation spec shape
self.primers.shape = observation_spec.shape
except ValueError:
# If we fail, we expand them to that shape
self.primers = self._expand_shape(self.primers)
else:
self.primers.shape = observation_spec.shape

device = observation_spec.device
observation_spec.update(self.primers.clone().to(device))
Expand Down
6 changes: 4 additions & 2 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,8 @@ def make_tuple(key):
{
in_key1: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)),
in_key2: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)),
}
},
expand_specs=True,
)

@property
Expand Down Expand Up @@ -1467,7 +1468,8 @@ def make_tuple(key):
return TensorDictPrimer(
{
in_key1: Unbounded(shape=(self.gru.num_layers, self.gru.hidden_size)),
}
},
expand_specs=True,
)

@property
Expand Down

0 comments on commit 40390f5

Please sign in to comment.