From 60e810204b3efd0abd75e445d09a9e6d2514a2f9 Mon Sep 17 00:00:00 2001 From: RLDS Team Date: Wed, 10 May 2023 07:27:21 -0700 Subject: [PATCH] Give EpisodeWriter the ability to create multi-split datasets. PiperOrigin-RevId: 530901874 Change-Id: I8004130ee8527e0224b4d190ae9c56266616caec --- rlds/tfds/episode_writer.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/rlds/tfds/episode_writer.py b/rlds/tfds/episode_writer.py index e4caf73..ab2b852 100644 --- a/rlds/tfds/episode_writer.py +++ b/rlds/tfds/episode_writer.py @@ -14,7 +14,7 @@ # coding=utf-8 """TFDS episode writer.""" -from typing import Optional +from typing import Optional, Union from absl import logging from rlds import rlds_types @@ -32,7 +32,7 @@ def __init__( data_directory: str, ds_config: DatasetConfig, max_episodes_per_file: int = 1000, - split_name: Optional[str] = 'train', + split_name: Union[Optional[str], list[str]] = 'train', version: str = '0.0.1', overwrite: bool = True, file_format: str = 'tfrecord', @@ -67,8 +67,13 @@ def __init__( overwrite=overwrite, file_format=file_format, ) + if isinstance(split_name, list): + sequential_writer_splits = split_name + else: + sequential_writer_splits = [split_name] + self._split_name = split_name - self._sequential_writer.initialize_splits([split_name], + self._sequential_writer.initialize_splits(sequential_writer_splits, fail_if_exists=overwrite) logging.info('Creating dataset in: %r', self._data_directory) @@ -78,7 +83,25 @@ def add_episode(self, episode: rlds_types.Episode) -> None: Args: episode: episode to add to the dataset. """ + if isinstance(self._split_name, list): + raise ValueError( + 'This EpisodeWriter was configured as a multi split ' + 'writer. Please use add_episode_to_split() method.' 
+ ) self._sequential_writer.add_examples({self._split_name: [episode]}) + def add_episode_to_split( + self, episode: rlds_types.Episode, split: str + ) -> None: + """Adds the episode to the dataset to a given split. + + Args: + episode: episode to add to the dataset. + split: The split to add this episode to. + """ + if isinstance(self._split_name, list) and split not in self._split_name: + raise ValueError('Unknown split %s' % split) + self._sequential_writer.add_examples({split: [episode]}) + def close(self) -> None: self._sequential_writer.close_all()