Skip to content

Commit

Permalink
[Feature] Some improvements to VecNorm (#2251)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 26, 2024
1 parent 849b3de commit 670a8cf
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 29 deletions.
29 changes: 29 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8005,6 +8005,35 @@ def test_to_obsnorm_multikeys(self):
td1 = transform0[0].to_observation_norm()._step(td, td.clone())
assert_allclose_td(td0, td1)

loc = transform0[0].loc
scale = transform0[0].scale
keys = list(transform0[0].in_keys)
td2 = (td.select(*keys) - loc) / (scale + torch.finfo(scale.dtype).eps)
td2.rename_key_("a", "a_avg")
td2.rename_key_(("b", "c"), ("b", "c_avg"))
assert_allclose_td(td0.select(*td2.keys(True, True)), td2)

def test_frozen(self):
transform0 = VecNorm(
in_keys=["a", ("b", "c")], out_keys=["a_avg", ("b", "c_avg")]
)
with pytest.raises(
RuntimeError, match="Make sure the VecNorm has been initialized"
):
transform0.frozen_copy()
td = TensorDict({"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4])
td0 = transform0._step(td, td.clone())
transform1 = transform0.frozen_copy()
td1 = transform1._step(td, td.clone())
assert_allclose_td(td0, td1)

td += 1
td2 = transform0._step(td, td.clone())
td3 = transform1._step(td, td.clone())
assert_allclose_td(td2, td3)
with pytest.raises(AssertionError):
assert_allclose_td(td0, td2)


def test_added_transforms_are_in_eval_mode_trivial():
base_env = ContinuousActionVecMockEnv()
Expand Down
169 changes: 140 additions & 29 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2454,6 +2454,9 @@ class ObservationNorm(ObservationTransform):
as it is done for standardization. Default is `False`.
eps (float, optional): epsilon increment for the scale in the ``standard_normal`` case.
Defaults to ``1e-6`` if not recoverable directly from the scale dtype.
Examples:
>>> torch.set_default_tensor_type(torch.DoubleTensor)
>>> r = torch.randn(100, 3)*torch.randn(3) + torch.randn(3)
Expand Down Expand Up @@ -2495,6 +2498,7 @@ def __init__(
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
standard_normal: bool = False,
eps: float | None = None,
):
if in_keys is None:
raise RuntimeError(
Expand All @@ -2517,7 +2521,13 @@ def __init__(
if not isinstance(standard_normal, torch.Tensor):
standard_normal = torch.as_tensor(standard_normal)
self.register_buffer("standard_normal", standard_normal)
self.eps = 1e-6
self.eps = (
eps
if eps is not None
else torch.finfo(scale.dtype).eps
if isinstance(scale, torch.Tensor) and scale.dtype.is_floating_point
else 1e-6
)

if loc is not None and not isinstance(loc, torch.Tensor):
loc = torch.tensor(loc, dtype=torch.get_default_dtype())
Expand Down Expand Up @@ -4815,7 +4825,10 @@ class VecNorm(Transform):
processes that share the same reference.
To use VecNorm at inference time and avoid updating the values with the new
observations, one should substitute this layer by `vecnorm.to_observation_norm()`.
observations, one should substitute this layer by :meth:`~.to_observation_norm`.
This will provide a static version of `VecNorm` which will not be updated
when the source transform is updated.
To get a frozen copy of the VecNorm layer, see :meth:`~.frozen_copy`.
Args:
in_keys (sequence of NestedKey, optional): keys to be updated.
Expand Down Expand Up @@ -4897,6 +4910,35 @@ def __init__(
self.decay = decay
self.shapes = shapes
self.eps = eps
self.frozen = False

def freeze(self) -> VecNorm:
"""Freezes the VecNorm, avoiding the stats to be updated when called.
See :meth:`~.unfreeze`.
"""
self.frozen = True
return self

def unfreeze(self) -> VecNorm:
"""Unfreezes the VecNorm.
See :meth:`~.freeze`.
"""
self.frozen = False
return self

def frozen_copy(self):
"""Returns a copy of the Transform that keeps track of the stats but does not update them."""
if self._td is None:
raise RuntimeError(
"Make sure the VecNorm has been initialized before creating a frozen copy."
)
clone = self.clone()
# replace values
clone._td = self._td.copy()
# freeze
return clone.freeze()

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
Expand Down Expand Up @@ -4980,52 +5022,78 @@ def _init(self, tensordict: TensorDictBase, key: str) -> None:
pass

def _update(self, key, value, N) -> torch.Tensor:
# TODO: we should revert this and have _td be like: TensorDict{"sum": ..., "ssq": ..., "count"...})
# to facilitate the computation of the stats using TD internals.
# Moreover, _td can be locked so these ops will be very fast on CUDA.
_sum = self._td.get(_append_last(key, "_sum"))
_ssq = self._td.get(_append_last(key, "_ssq"))
_count = self._td.get(_append_last(key, "_count"))

value_sum = _sum_left(value, _sum)
_sum *= self.decay
_sum += value_sum
self._td.set_(
_append_last(key, "_sum"),
_sum,
)

if not self.frozen:
_sum *= self.decay
_sum += value_sum
self._td.set_(
_append_last(key, "_sum"),
_sum,
)

_ssq = self._td.get(_append_last(key, "_ssq"))
value_ssq = _sum_left(value.pow(2), _ssq)
_ssq *= self.decay
_ssq += value_ssq
self._td.set_(
_append_last(key, "_ssq"),
_ssq,
)
if not self.frozen:
_ssq *= self.decay
_ssq += value_ssq
self._td.set_(
_append_last(key, "_ssq"),
_ssq,
)

_count = self._td.get(_append_last(key, "_count"))
_count *= self.decay
_count += N
self._td.set_(
_append_last(key, "_count"),
_count,
)
if not self.frozen:
_count *= self.decay
_count += N
self._td.set_(
_append_last(key, "_count"),
_count,
)

mean = _sum / _count
std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps).sqrt()
return (value - mean) / std.clamp_min(self.eps)

def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
"""Converts VecNorm into an ObservationNorm class that can be used at inference time."""
"""Converts VecNorm into an ObservationNorm class that can be used at inference time.
The :class:`~torchrl.envs.ObservationNorm` layer can be updated using the :meth:`~torch.nn.Module.state_dict`
API.
Examples:
>>> from torchrl.envs import GymEnv, VecNorm
>>> vecnorm = VecNorm(in_keys=["observation"])
>>> train_env = GymEnv("CartPole-v1", device=None).append_transform(
... vecnorm)
>>>
>>> r = train_env.rollout(4)
>>>
>>> eval_env = GymEnv("CartPole-v1").append_transform(
... vecnorm.to_observation_norm())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
>>>
>>> r = train_env.rollout(4)
>>> # Update entries with state_dict
>>> eval_env.transform.load_state_dict(
... vecnorm.to_observation_norm().state_dict())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
"""
out = []
loc = self.loc
scale = self.scale
for key, key_out in zip(self.in_keys, self.out_keys):
_sum = self._td.get(_append_last(key, "_sum"))
_ssq = self._td.get(_append_last(key, "_ssq"))
_count = self._td.get(_append_last(key, "_count"))
mean = _sum / _count
std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps).sqrt()

_out = ObservationNorm(
loc=mean,
scale=std,
loc=loc.get(key),
scale=scale.get(key),
standard_normal=True,
in_keys=key,
out_keys=key_out,
Expand All @@ -5035,6 +5103,49 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
return Compose(*out)
return _out

def _get_loc_scale(self, loc_only=False, scale_only=False):
loc = {}
scale = {}
for key in self.in_keys:
_sum = self._td.get(_append_last(key, "_sum"))
_ssq = self._td.get(_append_last(key, "_ssq"))
_count = self._td.get(_append_last(key, "_count"))
loc[key] = _sum / _count
scale[key] = (_ssq / _count - loc[key].pow(2)).clamp_min(self.eps).sqrt()
if not scale_only:
loc = TensorDict(loc)
else:
loc = None
if not loc_only:
scale = TensorDict(scale)
else:
scale = None
return loc, scale

@property
def standard_normal(self):
"""Whether the affine transform given by `loc` and `scale` follows the standard normal equation.
Similar to :class:`~torchrl.envs.ObservationNorm` standard_normal attribute.
Always returns ``True``.
"""
return True

@property
def loc(self):
"""Returns a TensorDict with the loc to be used for an affine transform."""
# We can't cache that value bc the summary stats could be updated by a different process
loc, _ = self._get_loc_scale(loc_only=True)
return loc

@property
def scale(self):
"""Returns a TensorDict with the scale to be used for an affine transform."""
# We can't cache that value bc the summary stats could be updated by a different process
_, scale = self._get_loc_scale(scale_only=True)
return scale

@staticmethod
def build_td_for_shared_vecnorm(
env: EnvBase,
Expand Down

0 comments on commit 670a8cf

Please sign in to comment.