Skip to content

Commit

Permalink
Allow to propagate file_format.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 530612315
Change-Id: I08da8ecf80381894bd0ff978ef49637d777d1fe5
  • Loading branch information
RLDS Team authored and copybara-github committed May 9, 2023
1 parent 5dece1d commit f5de5aa
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions rlds/tfds/episode_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
class EpisodeWriter():
"""Class that writes trajectory data in TFDS format (and RLDS structure)."""

def __init__(self,
data_directory: str,
ds_config: DatasetConfig,
max_episodes_per_file: int = 1000,
split_name: Optional[str] = 'train',
version: str = '0.0.1',
overwrite: bool = True):
def __init__(
self,
data_directory: str,
ds_config: DatasetConfig,
max_episodes_per_file: int = 1000,
split_name: Optional[str] = 'train',
version: str = '0.0.1',
overwrite: bool = True,
file_format: str = 'tfrecord',
):
"""Constructor.
Args:
Expand All @@ -45,6 +48,8 @@ def __init__(self,
version: version (major.minor.patch) of the dataset.
overwrite: if False, and there is an existing dataset, it will append to
it.
file_format: The format of the files to write, e.g. `tfrecord` or
`array_record`
"""

self._data_directory = data_directory
Expand All @@ -54,10 +59,14 @@ def __init__(self,
data_dir=data_directory,
module_name='')
self._ds_info = tfds.rlds.rlds_base.build_info(ds_config, ds_identity)
self._ds_info.set_file_format('tfrecord')
self._ds_info.set_file_format(file_format)

self._sequential_writer = tfds.core.SequentialWriter(
self._ds_info, max_episodes_per_file, overwrite=overwrite)
self._ds_info,
max_episodes_per_file,
overwrite=overwrite,
file_format=file_format,
)
self._split_name = split_name
self._sequential_writer.initialize_splits([split_name],
fail_if_exists=overwrite)
Expand Down

0 comments on commit f5de5aa

Please sign in to comment.