Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Feature] Extend TensorDictPrimer default_value options #2071

Merged
merged 21 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6406,17 +6406,11 @@ def test_trans_parallel_env_check(self):
finally:
env.close()

@pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
def test_trans_serial_env_check(self, spec_shape):
    """A primer spec may omit or include the serial env's batch dim (2).

    Both shapes must pass ``check_env_specs`` now that the primer expands
    un-batched specs to the parent batch size instead of raising.
    """
    # NOTE(review): the scrape interleaved the removed pre-change body and a
    # stale duplicate `TensorDictPrimer(...[2, 4])` line; this is the clean
    # post-merge version.
    env = TransformedEnv(
        SerialEnv(2, ContinuousActionVecMockEnv),
        TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)),
    )
    check_env_specs(env)
    assert "mykey" in env.reset().keys()
Expand Down Expand Up @@ -6516,6 +6510,72 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done):
r1 = env.rollout(100, break_when_any_done=break_when_any_done)
tensordict.tensordict.assert_allclose_td(r0, r1)

def test_callable_default_value(self):
    """A zero-argument callable can be used as the primer's default value."""
    primer = TensorDictPrimer(
        mykey=UnboundedContinuousTensorSpec([3]),
        default_value=lambda: torch.ones(3),
    )
    env = TransformedEnv(ContinuousActionVecMockEnv(), primer)
    check_env_specs(env)
    # The primed entry must appear both at reset and in rollout "next" frames.
    assert "mykey" in env.reset().keys()
    assert ("next", "mykey") in env.rollout(3).keys(True)

def test_dict_default_value(self):
    """TensorDictPrimer accepts a per-key dict of default values.

    Covers two variants: a dict of floats and a dict of zero-arg callables
    (the callable must return a tensor matching the corresponding spec).
    """
    # Test with a dict of float default values
    key1_spec = UnboundedContinuousTensorSpec([3])
    key2_spec = UnboundedContinuousTensorSpec([3])
    env = TransformedEnv(
        ContinuousActionVecMockEnv(),
        TensorDictPrimer(
            mykey1=key1_spec,
            mykey2=key2_spec,
            default_value={
                "mykey1": 1.0,
                "mykey2": 2.0,
            },
        ),
    )
    check_env_specs(env)
    reset_td = env.reset()
    assert "mykey1" in reset_td.keys()
    assert "mykey2" in reset_td.keys()
    rollout_td = env.rollout(3)
    assert ("next", "mykey1") in rollout_td.keys(True)
    assert ("next", "mykey2") in rollout_td.keys(True)
    assert (rollout_td.get(("next", "mykey1")) == 1.0).all()
    assert (rollout_td.get(("next", "mykey2")) == 2.0).all()

    # Test with a dict of callable default values
    key1_spec = UnboundedContinuousTensorSpec([3])
    key2_spec = DiscreteTensorSpec(3, dtype=torch.int64)
    env = TransformedEnv(
        ContinuousActionVecMockEnv(),
        TensorDictPrimer(
            mykey1=key1_spec,
            mykey2=key2_spec,
            default_value={
                "mykey1": lambda: torch.ones(3),
                "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
            },
        ),
    )
    check_env_specs(env)
    reset_td = env.reset()
    assert "mykey1" in reset_td.keys()
    assert "mykey2" in reset_td.keys()
    rollout_td = env.rollout(3)
    assert ("next", "mykey1") in rollout_td.keys(True)
    assert ("next", "mykey2") in rollout_td.keys(True)
    # BUG FIX: the original accessed ``.all`` without calling it; a bound
    # method is always truthy, so both assertions were vacuous. Call it.
    assert (rollout_td.get(("next", "mykey1")) == torch.ones(3)).all()
    assert (
        rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
    ).all()


class TestTimeMaxPool(TransformBase):
@pytest.mark.parametrize("T", [2, 4])
Expand Down
119 changes: 97 additions & 22 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4431,8 +4431,12 @@ class TensorDictPrimer(Transform):
random (bool, optional): if ``True``, the values will be drawn randomly from
the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed.
Defaults to `False`.
default_value (float, optional): if non-random filling is chosen, this
value will be used to populate the tensors. Defaults to `0.0`.
default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random
filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float,
all elements of the tensors will be set to that value. If it is a callable, this callable is expected to
return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value`
is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will
be used to generate the corresponding tensors. Defaults to `0.0`.
reset_key (NestedKey, optional): the reset key to be used as partial
reset indicator. Must be unique. If not provided, defaults to the
only reset key of the parent environment (if it has only one)
Expand Down Expand Up @@ -4489,8 +4493,11 @@ class TensorDictPrimer(Transform):
def __init__(
self,
primers: dict | CompositeSpec = None,
random: bool = False,
default_value: float = 0.0,
random: bool | None = None,
default_value: float
| Callable
| Dict[NestedKey, float]
| Dict[NestedKey, Callable] = 0.0,
reset_key: NestedKey | None = None,
**kwargs,
):
Expand All @@ -4505,8 +4512,23 @@ def __init__(
if not isinstance(kwargs, CompositeSpec):
kwargs = CompositeSpec(kwargs)
self.primers = kwargs
if (random is not None) and isinstance(default_value, (dict, Callable)):
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Setting random to True and providing a default_value are incompatible."
)
self.random = random
if isinstance(default_value, dict):
if len(default_value) != len(self.primers) and set(dict.keys()) != set(
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
self.primers.keys(True, True)
):
raise ValueError(
"If a default_value dictionary is provided, it must match the primers keys."
)
default_value = {
key: default_value[key] for key in self.primers.keys(True, True)
}
self.default_value = default_value
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
self._validated = False
self.reset_key = reset_key

# sanity check
Expand Down Expand Up @@ -4559,6 +4581,9 @@ def to(self, *args, **kwargs):
self.primers = self.primers.to(device)
return super().to(*args, **kwargs)

def _try_expand_shape(self, spec):
    # Prepend the parent env's batch dims to the spec's shape so a per-env
    # primer spec broadcasts across a batched (Serial/Parallel) env.
    # NOTE(review): if ``self.parent`` is None this raises AttributeError;
    # the caller appears to rely on catching that — confirm before changing.
    return spec.expand((*self.parent.batch_size, *spec.shape))

def transform_observation_spec(
self, observation_spec: CompositeSpec
) -> CompositeSpec:
Expand All @@ -4568,15 +4593,20 @@ def transform_observation_spec(
)
for key, spec in self.primers.items():
if spec.shape[: len(observation_spec.shape)] != observation_spec.shape:
raise RuntimeError(
f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. "
f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}."
)
try:
expanded_spec = self._try_expand_shape(spec)
except AttributeError:
raise RuntimeError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When will this be reached?

Copy link
Contributor Author

@albertbou92 albertbou92 Apr 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if for any reason self.parent is None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when would transform_observation_spec be called when parent is None?

f"The leading shape of the primer specs ({self.__class__}) should match the one of the "
f"parent env. Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's "
f"shape is {expanded_spec.shape}."
)
spec = expanded_spec
try:
device = observation_spec.device
except RuntimeError:
device = self.device
observation_spec[key] = spec.to(device)
observation_spec[key] = self.primers[key] = spec.to(device)
return observation_spec

def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
Expand All @@ -4589,8 +4619,25 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
def _batch_size(self):
return self.parent.batch_size

def _validate_value_tensor(self, value, spec):
    """Check that ``value`` conforms to ``spec``.

    Args:
        value: a tensor produced by a default-value callable.
        spec: the TensorSpec the value must satisfy.

    Returns:
        ``True`` when every check passes.

    Raises:
        RuntimeError: if the value's shape, dtype or device differs from the
            spec's, or if the value lies outside the spec's domain.
    """
    # NOTE: the scraped diff had review-UI text embedded here; the code
    # below is the same checks with that pollution removed.
    if value.shape != spec.shape:
        raise RuntimeError(
            f"Value shape ({value.shape}) does not match the spec shape ({spec.shape})."
        )
    if value.dtype != spec.dtype:
        raise RuntimeError(
            f"Value dtype ({value.dtype}) does not match the spec dtype ({spec.dtype})."
        )
    if value.device != spec.device:
        raise RuntimeError(
            f"Value device ({value.device}) does not match the spec device ({spec.device})."
        )
    if not spec.is_in(value):
        raise RuntimeError(f"Value ({value}) is not in the spec domain ({spec}).")
    return True

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
for key, spec in self.primers.items():
for key, spec in self.primers.items(True, True):
if spec.shape[: len(tensordict.shape)] != tensordict.shape:
raise RuntimeError(
"The leading shape of the spec must match the tensordict's, "
Expand All @@ -4601,11 +4648,22 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.random:
value = spec.rand()
else:
value = torch.full_like(
spec.zero(),
self.default_value,
)
if isinstance(self.default_value, dict):
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
value = self.default_value[key]
else:
value = self.default_value
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full_like(
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
spec.zero(),
value,
)
tensordict.set(key, value)
if not self._validated:
self._validated = True
return tensordict

def _step(
Expand Down Expand Up @@ -4634,22 +4692,39 @@ def _reset(
)
_reset = _get_reset(self.reset_key, tensordict)
if _reset.any():
for key, spec in self.primers.items():
for key, spec in self.primers.items(True, True):
if self.random:
value = spec.rand(shape)
else:
value = torch.full_like(
spec.zero(shape),
self.default_value,
)
prev_val = tensordict.get(key, 0.0)
value = torch.where(expand_as_right(_reset, value), value, prev_val)
if isinstance(self.default_value, dict):
value = self.default_value[key]
else:
value = self.default_value
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full_like(
spec.zero(shape),
value,
)
prev_val = tensordict.get(key, 0.0)
value = torch.where(
expand_as_right(_reset, value), value, prev_val
)
tensordict_reset.set(key, value)
self._validated = True
return tensordict_reset

def __repr__(self) -> str:
    """Return a readable summary of the primer transform.

    A float ``default_value`` is shown verbatim; callables and dicts would
    be noisy, so only their class name is printed.
    """
    class_name = self.__class__.__name__
    default_value = (
        self.default_value
        if isinstance(self.default_value, float)
        else self.default_value.__class__.__name__
    )
    return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})"


class PinMemoryTransform(Transform):
Expand Down
Loading