diff --git a/examples/video/video-from-dataset.py b/examples/video/video-from-dataset.py
new file mode 100644
index 00000000000..71fe36fe7c3
--- /dev/null
+++ b/examples/video/video-from-dataset.py
@@ -0,0 +1,48 @@
+"""Video from dataset example.
+
+This example shows how to save a video from a dataset.
+
+To run it, you will need to install the openx requirements as well as torchvision.
+"""
+
+from torchrl.data.datasets import OpenXExperienceReplay
+from torchrl.record import CSVLogger, VideoRecorder
+
+# Create a logger that saves videos as mp4
+logger = CSVLogger("./dump", video_format="mp4")
+
+
+# We use the VideoRecorder transform to register the images coming from the batch.
+t = VideoRecorder(
+    logger=logger, tag="pixels", in_keys=[("next", "observation", "image")]
+)
+# Each batch of data will have 10 consecutive videos of at most 200 frames each (since strict_length=False)
+dataset = OpenXExperienceReplay(
+    "cmu_stretch",
+    batch_size=2000,
+    slice_len=200,
+    download=True,
+    strict_length=False,
+    transform=t,
+)
+
+# Get a batch of data and visualize it
+for _ in dataset:
+    # The transform has seen the data since it's in the replay buffer
+    t.dump()
+    break
+
+# Alternatively, we can build the dataset without the VideoRecorder and call it manually:
+dataset = OpenXExperienceReplay(
+    "cmu_stretch",
+    batch_size=2000,
+    slice_len=200,
+    download=True,
+    strict_length=False,
+)
+
+# Get a batch of data and visualize it
+for data in dataset:
+    t(data)
+    t.dump()
+    break
diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py
index a1e8db6f782..6b4703902b0 100644
--- a/torchrl/data/datasets/openx.py
+++ b/torchrl/data/datasets/openx.py
@@ -305,9 +305,9 @@ def __init__(
         slice_len: int | None = None,
         pad: float | bool | None = None,
         replacement: bool = None,
-        streaming: bool = True,
+        streaming: bool | None = None,
         root: str | Path | None = None,
-        download: bool = False,
+        download: bool | None = None,
         sampler: Sampler | None = None,
         writer: Writer | None = None,
         collate_fn: Callable | None = None,
@@ -317,6 +317,13 @@ def __init__(
         split_trajs: bool = False,
         strict_length: bool = True,
     ):
+        if download is None and streaming is None:
+            download = False
+            streaming = True
+        elif download is None:
+            download = not streaming
+        elif streaming is None:
+            streaming = not download
        self.download = download
        self.streaming = streaming
        self.dataset_id = dataset_id
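The ``download``/``streaming`` change above turns two independent booleans into complementary defaults: when only one flag is given, the other defaults to its negation, and when neither is given the dataset streams without downloading. A minimal standalone sketch of that resolution logic (``resolve_modes`` is a hypothetical name, for illustration only):

    def resolve_modes(download=None, streaming=None):
        """Resolve complementary download/streaming flags (hypothetical helper)."""
        if download is None and streaming is None:
            # Neither flag set: stream by default, keep no local copy.
            return False, True
        if download is None:
            # Only streaming was set: download iff we are not streaming.
            return not streaming, streaming
        if streaming is None:
            # Only download was set: stream iff we are not downloading.
            return download, not download
        # Both set explicitly: trust the caller.
        return download, streaming

    assert resolve_modes() == (False, True)
    assert resolve_modes(download=True) == (True, False)
    assert resolve_modes(streaming=False) == (True, False)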
diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py
index da208519cd0..c7abe28f690 100644
--- a/torchrl/record/recorder.py
+++ b/torchrl/record/recorder.py
@@ -33,11 +33,12 @@ class VideoRecorder(ObservationTransform):
         in_keys (Sequence of NestedKey, optional): keys to be read to produce the video.
             Default is :obj:`"pixels"`.
         skip (int): frame interval in the output video.
-            Default is 2.
+            Default is ``2`` if the transform has a parent environment, and ``1`` if not.
         center_crop (int, optional): value of square center crop.
         make_grid (bool, optional): if ``True``, a grid is created assuming that a
             tensor of shape [B x W x H x 3] is provided, with B being the batch
-            size. Default is True.
+            size. Default is ``True`` if the transform has a parent environment, and ``False``
+            if not.
         out_keys (sequence of NestedKey, optional): destination keys. Defaults to
             ``in_keys`` if not provided.

@@ -66,6 +67,26 @@ class VideoRecorder(ObservationTransform):
         >>> env.transform.dump()
+
+        The transform can also be used within a dataset to save the video collected. Unlike in the environment case,
+        images will come in a batch. The ``skip`` argument makes it possible to save images only at given intervals.
+
+        >>> from torchrl.data.datasets import OpenXExperienceReplay
+        >>> from torchrl.envs import Compose
+        >>> from torchrl.record import VideoRecorder, CSVLogger
+        >>> # Create a logger that saves videos as mp4
+        >>> logger = CSVLogger("./dump", video_format="mp4")
+        >>> # We use the VideoRecorder transform to register the images coming from the batch.
+        >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")])
+        >>> # Each batch of data will have 10 consecutive videos of at most 200 frames each (since strict_length=False)
+        >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
+        ...     download=True, strict_length=False,
+        ...     transform=t)
+        >>> # Get a batch of data and visualize it
+        >>> for data in dataset:
+        ...     t.dump()
+        ...     break
+
         Our video is available under ``./cheetah_videos/cheetah/videos/run_video_0.mp4``!

    """
@@ -75,9 +96,9 @@ def __init__(
         logger: Logger,
         tag: str,
         in_keys: Optional[Sequence[NestedKey]] = None,
-        skip: int = 2,
+        skip: int | None = None,
         center_crop: Optional[int] = None,
-        make_grid: bool = True,
+        make_grid: bool | None = None,
         out_keys: Optional[Sequence[NestedKey]] = None,
         **kwargs,
     ) -> None:
@@ -102,12 +123,59 @@ def __init__(
         )
         self.obs = []
+
+    @property
+    def make_grid(self):
+        make_grid = self._make_grid
+        if make_grid is None:
+            if self.parent is not None:
+                self._make_grid = True
+                return True
+            self._make_grid = False
+            return False
+        return make_grid
+
+    @make_grid.setter
+    def make_grid(self, value):
+        self._make_grid = value
+
+    @property
+    def skip(self):
+        skip = self._skip
+        if skip is None:
+            if self.parent is not None:
+                self._skip = 2
+                return 2
+            self._skip = 1
+            return 1
+        return skip
+
+    @skip.setter
+    def skip(self, value):
+        self._skip = value
+
     def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
-        if not (observation.shape[-1] == 3 or observation.ndimension() == 2):
-            raise RuntimeError(f"Invalid observation shape, got: {observation.shape}")
-        observation_trsf = observation.clone()
         self.count += 1
         if self.count % self.skip == 0:
+            if (
+                observation.ndim >= 3
+                and observation.shape[-3] == 3
+                and observation.shape[-2] > 3
+                and observation.shape[-1] > 3
+            ):
+                # permute the channels to the last dim
+                observation_trsf = observation.permute(
+                    *range(observation.ndim - 3), -2, -1, -3
+                )
+            else:
+                observation_trsf = observation
+            if not (
+                observation_trsf.shape[-1] == 3 or observation_trsf.ndimension() == 2
+            ):
+                raise RuntimeError(
+                    f"Invalid observation shape, got: {observation.shape}"
+                )
+            observation_trsf = observation_trsf.clone()
+
             if observation.ndimension() == 2:
                 observation_trsf = observation.unsqueeze(-3)
             else:
@@ -131,7 +199,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
                 observation_trsf = center_crop_fn(
                     observation_trsf, [self.center_crop, self.center_crop]
                 )
-            if self.make_grid and observation_trsf.ndimension() == 4:
+            if self.make_grid and observation_trsf.ndimension() >= 4:
                 if not _has_tv:
                     raise ImportError(
                         "Could not import torchvision, `make_grid` not available."
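The reworked ``_apply_transform`` above no longer assumes channels-last inputs: a trailing ``(3, H, W)`` block whose height and width both exceed 3 is heuristically treated as channel-first and permuted before validation. A standalone sketch of that heuristic (``channels_last`` is a hypothetical helper, not part of the patch):

    import torch

    def channels_last(obs: torch.Tensor) -> torch.Tensor:
        # Heuristic: a trailing (3, H, W) block with H > 3 and W > 3 is taken
        # to be channel-first; move the channel dim to the last position.
        if (
            obs.ndim >= 3
            and obs.shape[-3] == 3
            and obs.shape[-2] > 3
            and obs.shape[-1] > 3
        ):
            return obs.permute(*range(obs.ndim - 3), -2, -1, -3)
        return obs

    # A batch of 10 channel-first frames becomes channels-last:
    assert channels_last(torch.zeros(10, 3, 64, 64)).shape == (10, 64, 64, 3)
    # Channels-last input is left untouched:
    assert channels_last(torch.zeros(10, 64, 64, 3)).shape == (10, 64, 64, 3)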
@@ -139,30 +207,42 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
                 )
                 from torchvision.utils import make_grid

-                observation_trsf = make_grid(observation_trsf)
-            self.obs.append(observation_trsf.to(torch.uint8))
+                observation_trsf = make_grid(observation_trsf.flatten(0, -4))
+                self.obs.append(observation_trsf.to(torch.uint8))
+            elif observation_trsf.ndimension() >= 4:
+                self.obs.extend(observation_trsf.to(torch.uint8).flatten(0, -4))
+            else:
+                self.obs.append(observation_trsf.to(torch.uint8))
         return observation

+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        return self._call(tensordict)
+
     def dump(self, suffix: Optional[str] = None) -> None:
-        """Writes the video to the self.logger attribute.
+        """Writes the video to the ``self.logger`` attribute.
+
+        Calling ``dump`` when no image has been stored is a no-op.

         Args:
             suffix (str, optional): a suffix for the video to be recorded
         """
-        if suffix is None:
-            tag = self.tag
+        if self.obs:
+            obs = torch.stack(self.obs, 0).unsqueeze(0).cpu()
         else:
-            tag = "_".join([self.tag, suffix])
-        obs = torch.stack(self.obs, 0).unsqueeze(0).cpu()
-        del self.obs
-        if self.logger is not None:
-            self.logger.log_video(
-                name=tag,
-                video=obs,
-                step=self.iter,
-                **self.video_kwargs,
-            )
-        del obs
+            obs = None
+        self.obs = []
+        if obs is not None:
+            if suffix is None:
+                tag = self.tag
+            else:
+                tag = "_".join([self.tag, suffix])
+            if self.logger is not None:
+                self.logger.log_video(
+                    name=tag,
+                    video=obs,
+                    step=self.iter,
+                    **self.video_kwargs,
+                )
         self.iter += 1
         self.count = 0
         self.obs = []
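For reference, the rewritten ``dump`` stores frames as ``uint8`` ``(C, H, W)`` tensors and, when any exist, stacks them into a ``(1, T, C, H, W)`` batch before logging; with no frames it only resets its counters. A minimal sketch of that bookkeeping (``assemble_video`` is a hypothetical helper mirroring the stacking step):

    import torch

    def assemble_video(frames):
        # No frames recorded: dumping becomes a no-op for logging.
        if not frames:
            return None
        # Stack T frames of shape (C, H, W) into (1, T, C, H, W),
        # the single-video batch layout the original dump() builds.
        return torch.stack(frames, 0).unsqueeze(0).cpu()

    assert assemble_video([]) is None
    video = assemble_video([torch.zeros(3, 64, 64, dtype=torch.uint8)] * 16)
    assert video.shape == (1, 16, 3, 64, 64)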