Skip to content

Commit

Permalink
Give EpisodeWriter the ability to create multi-split datasets.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 530901874
Change-Id: I8004130ee8527e0224b4d190ae9c56266616caec
  • Loading branch information
RLDS Team authored and copybara-github committed May 10, 2023
1 parent f5de5aa commit 60e8102
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions rlds/tfds/episode_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# coding=utf-8
"""TFDS episode writer."""
from typing import Optional
from typing import Optional, Union

from absl import logging
from rlds import rlds_types
Expand All @@ -32,7 +32,7 @@ def __init__(
data_directory: str,
ds_config: DatasetConfig,
max_episodes_per_file: int = 1000,
split_name: Optional[str] = 'train',
split_name: Union[Optional[str], list[str]] = 'train',
version: str = '0.0.1',
overwrite: bool = True,
file_format: str = 'tfrecord',
Expand Down Expand Up @@ -67,8 +67,13 @@ def __init__(
overwrite=overwrite,
file_format=file_format,
)
if isinstance(split_name, list):
sequential_writer_splits = split_name
else:
sequential_writer_splits = [split_name]

self._split_name = split_name
self._sequential_writer.initialize_splits([split_name],
self._sequential_writer.initialize_splits(sequential_writer_splits,
fail_if_exists=overwrite)
logging.info('Creating dataset in: %r', self._data_directory)

Expand All @@ -78,7 +83,25 @@ def add_episode(self, episode: rlds_types.Episode) -> None:
Args:
episode: episode to add to the dataset.
"""
if isinstance(self._split_name, list):
raise ValueError(
'This EpisodeWriter was configured as a multi split '
'writer. Please use add_episode_to_split() method.'
)
self._sequential_writer.add_examples({self._split_name: [episode]})

def add_episode_to_split(
    self, episode: rlds_types.Episode, split: str
) -> None:
  """Writes a single episode into the requested split of the dataset.

  Args:
    episode: episode to append to the dataset.
    split: name of the split that receives the episode.

  Raises:
    ValueError: if this writer was configured with a list of splits and
      `split` is not one of them.
  """
  configured_splits = self._split_name
  # Membership is only checked here when the writer was set up with a list
  # of splits; otherwise the request is forwarded to the writer unchanged.
  is_multi_split = isinstance(configured_splits, list)
  if is_multi_split and split not in configured_splits:
    raise ValueError('Unknown split %s' % split)
  examples_by_split = {split: [episode]}
  self._sequential_writer.add_examples(examples_by_split)

def close(self) -> None:
  """Closes the underlying sequential writer via its `close_all()` method."""
  writer = self._sequential_writer
  writer.close_all()

0 comments on commit 60e8102

Please sign in to comment.