
Commit

fix splitting
pan-x-c committed Sep 6, 2024
1 parent e854127 commit b0caeaf
Showing 2 changed files with 88 additions and 68 deletions.
152 changes: 85 additions & 67 deletions data_juicer/core/ray_data.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import concurrent.futures
import os
from typing import Any, Generator, List, Union

@@ -10,6 +11,7 @@
from data_juicer import cuda_device_count
from data_juicer.core.data import DJDataset
from data_juicer.ops import Filter, Mapper
from data_juicer.ops.base_op import OP
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import DEFAULT_MAX_FILE_SIZE, Fields
from data_juicer.utils.process_utils import calculate_np
@@ -44,37 +46,41 @@ def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys):
return dict_with_paths


# TODO: check path for nestdataset
def set_dataset_to_absolute_path(dataset, dataset_path, cfg):
"""
Set all paths in the input data to absolute paths.
Checks dataset_dir and project_dir for valid paths.
"""
if not (cfg.video_key in dataset.columns() or cfg.image_key
in dataset.columns() or cfg.audio_key in dataset.columns()):
return dataset
dataset_dir = os.path.dirname(dataset_path)
dataset = dataset.map(lambda item: convert_to_absolute_paths(
item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key]))
logger.info(f"transfer {dataset.count()} sample's paths")
return dataset

class RayPreprocessOperator(OP):

def preprocess_dataset(dataset: Dataset, dataset_path, cfg) -> Dataset:
if dataset_path:
dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
columns = dataset.columns()
if Fields.stats not in columns:
logger.info(f'columns {columns}')

def process_batch_arrow(table: pa.Table) -> pa.Table:
new_column_data = [{} for _ in range(len(table))]
new_talbe = table.append_column(Fields.stats, [new_column_data])
return new_talbe
def __init__(self, dataset_path=None, cfg=None) -> None:
super().__init__()
self.dataset_path = dataset_path
self.cfg = cfg
self._name = 'RayPreprocess'

def run(self, dataset: Dataset) -> Dataset:
columns = dataset.columns()
if Fields.stats not in columns:
logger.info(f'columns {columns}')

def process_batch_arrow(table: pa.Table) -> pa.Table:
new_column_data = [{} for _ in range(len(table))]
new_table = table.append_column(Fields.stats,
[new_column_data])
return new_table

dataset = dataset.map_batches(process_batch_arrow,
batch_format='pyarrow')
if self.dataset_path:
# TODO: check path for nestdataset
if not (self.cfg.video_key in dataset.columns()
or self.cfg.image_key in dataset.columns()
or self.cfg.audio_key in dataset.columns()):
return dataset
dataset_dir = os.path.dirname(self.dataset_path)
dataset = dataset.map(lambda item: convert_to_absolute_paths(
item, dataset_dir,
[self.cfg.video_key, self.cfg.image_key, self.cfg.audio_key]))
return dataset

dataset = dataset.map_batches(process_batch_arrow,
batch_format='pyarrow')
return dataset
def use_cuda(self):
return False
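
The stats-column step above amounts to appending one empty struct per row with PyArrow. A minimal standalone sketch of that pattern, assuming plain PyArrow with no Ray and using the literal column name 'stats' in place of Fields.stats:

    import pyarrow as pa

    def add_empty_stats(table: pa.Table) -> pa.Table:
        # one empty dict per row, appended as a new struct column
        new_column_data = [{} for _ in range(len(table))]
        return table.append_column('stats', [new_column_data])

    table = pa.table({'text': ['a', 'b', 'c']})
    print(add_empty_stats(table).column_names)  # ['text', 'stats']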


def get_num_gpus(op, op_proc):
@@ -133,6 +139,11 @@ def split_jsonl(file_path: str, max_size: int,
break


def parallel_split_jsonl(file_path, max_size, output_dir) -> List[str]:
"""Wrapper function for using with ThreadPoolExecutor."""
return list(split_jsonl(file_path, max_size, output_dir))


def split_jsonl_dataset(
dataset_paths: Union[str, List[str]],
max_size: int,
@@ -152,10 +163,21 @@ def split_jsonl_dataset(
dataset_paths = [dataset_paths]

logger.info('Re-splitting dataset files...')
for path in dataset_paths:
for sub_file_path in split_jsonl(path, max_size, output_dir):
logger.info(f'Splited into {sub_file_path}')
yield sub_file_path

with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
futures = {
executor.submit(parallel_split_jsonl, path, max_size, output_dir):
path
for path in dataset_paths
}
for future in concurrent.futures.as_completed(futures):
try:
results = future.result()
for result in results:
logger.info(f'Split into {result}')
yield result
except Exception as e:
logger.error(f'Failed to split file: {e}')
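
The re-splitting loop above uses the submit/as_completed pattern so that sub-file paths are yielded as soon as each worker finishes rather than in input order, and a failure on one file does not abort the others. A toy sketch of the same pattern, not part of the commit:

    import concurrent.futures

    def _work(x):
        return x * x

    def results_as_ready(values):
        # submit every task up front, then yield each result as it completes
        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            futures = {executor.submit(_work, v): v for v in values}
            for future in concurrent.futures.as_completed(futures):
                yield future.result()

    print(sorted(results_as_ready(range(5))))  # [0, 1, 4, 9, 16]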


def get_jsonl_file_names(dataset_dir_path: str) -> List[str]:
@@ -185,8 +207,7 @@ def get_jsonl_file_names(dataset_dir_path: str) -> List[str]:

def best_file_num(cpu: int, memory: int, file_size: int) -> int:
"""Calculate the best number of files in a single batch.
Each cpu should process the same number of files (at least one),
while the total memory should be at least 2 times larger than the
The total memory should be at least 4 times larger than the
total file size.
Args:
@@ -197,9 +218,8 @@
Returns:
int: best number of files in a single batch
"""
max_files_by_memory = memory // (16 * file_size)

best_num_files = max(1, (max_files_by_memory // cpu)) * cpu
max_files_by_memory = memory // (4 * file_size)
best_num_files = min(cpu, max_files_by_memory)
logger.info(f'Best number of files in a single batch: {best_num_files}')
return best_num_files
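
A worked example of the new formula with hypothetical cluster figures (all sizes in MB):

    cpu, memory, file_size = 32, 64 * 1024, 128       # 32 CPUs, 64 GiB of memory, 128 MB files
    max_files_by_memory = memory // (4 * file_size)   # 65536 // 512 = 128
    best_num_files = min(cpu, max_files_by_memory)    # min(32, 128) = 32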

@@ -241,43 +261,36 @@ def __init__(self, datasets: Union[Dataset, Generator], cfg=None) -> None:
if cfg:
self.num_proc = cfg.np
self.output_dataset = []
self._ops = [RayPreprocessOperator(dataset_path=None, cfg=cfg)]

@classmethod
def read_jsonl(cls,
path: Union[str, List[str]],
cfg: Any = None) -> RayDataset:
files = split_jsonl_dataset(get_jsonl_file_names(path),
DEFAULT_MAX_FILE_SIZE, cfg.work_dir)
cpu = ray.cluster_resources().get('CPU', 0)
memory = ray.cluster_resources().get('memory', 0) / 1024 / 1024
logger.info(f'CPU: {cpu}, Memory: {memory}')
batch_file_num = best_file_num(cpu, memory, DEFAULT_MAX_FILE_SIZE)
return RayDataset(datasets=load_splited_json_dataset(
files, batch_file_num),
cfg=cfg)
cfg: Any = None,
resplit: bool = True) -> RayDataset:
if resplit:
resplit_dir = os.path.join(cfg.work_dir, 'resplit')
os.makedirs(resplit_dir, exist_ok=True)
files = split_jsonl_dataset(get_jsonl_file_names(path),
DEFAULT_MAX_FILE_SIZE, resplit_dir)
cpu = ray.cluster_resources().get('CPU', 0)
memory = ray.cluster_resources().get('memory', 0) / 1024 / 1024
logger.info(f'CPU: {cpu}, Memory: {memory}')
batch_file_num = best_file_num(cpu, memory, DEFAULT_MAX_FILE_SIZE)
return RayDataset(datasets=load_splited_json_dataset(
files, batch_file_num),
cfg=cfg)
else:
return RayDataset(datasets=rd.read_json(path), cfg=cfg)
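
A hypothetical call site, assuming cfg is the usual Data-Juicer config object that exposes work_dir:

    ds = RayDataset.read_jsonl('/path/to/data.jsonl', cfg, resplit=True)       # re-split oversized files under cfg.work_dir/resplit
    ds_raw = RayDataset.read_jsonl('/path/to/data.jsonl', cfg, resplit=False)  # read the files as-is via ray.data.read_json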

@classmethod
def read_item(cls, data: dict, cfg: Any = None) -> RayDataset:
return RayDataset(datasets=rd.from_items(data), cfg=cfg)

def process(self,
operators,
*,
exporter=None,
checkpointer=None,
tracer=None) -> DJDataset:
outputs = []
for dataset in self.datasets:
# todo: pass dataset path into the function
data = preprocess_dataset(dataset, dataset_path=None, cfg=self.cfg)
if operators is None:
return self
if not isinstance(operators, list):
operators = [operators]
for op in operators:
data = self._run_single_op(op, data)
outputs.append(data)
self.datasets = outputs
def process(self, operators) -> DJDataset:
if not isinstance(operators, list):
operators = [operators]
self._ops.extend(operators)
return self

def _run_single_op(self, op, dataset: Dataset) -> Dataset:
@@ -298,6 +311,8 @@ def _run_single_op(self, op, dataset: Dataset) -> Dataset:
if op.stats_export_path is not None:
dataset.write_json(op.stats_export_path, force_ascii=False)
dataset = dataset.filter(op.process)
elif isinstance(op, RayPreprocessOperator):
dataset = op.run(dataset)
else:
logger.error(
'Ray executor only support Filter and Mapper OPs for now')
@@ -317,4 +332,7 @@ def to_pandas(self) -> pd.DataFrame:

def write_json(self, path: str, force_ascii: bool = False) -> None:
for dataset in self.datasets:
dataset.write_json(path, force_ascii=force_ascii)
if len(self._ops) > 0:
for op in self._ops:
dataset = self._run_single_op(op, dataset)
dataset.write_json(path, force_ascii=force_ascii)
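
With this change the pipeline becomes lazy: process() only queues operators (on top of the implicit RayPreprocessOperator), and the queued operators are applied per shard when write_json() is called. A hedged sketch of the resulting call flow, where ops and cfg.export_path are placeholders:

    ds = RayDataset.read_jsonl(cfg.dataset_path, cfg, resplit=True)
    ds.process(ops)                  # operators are only appended to ds._ops here
    ds.write_json(cfg.export_path)   # each queued op runs via _run_single_op, then the shard is written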
4 changes: 3 additions & 1 deletion data_juicer/core/ray_executor.py
@@ -56,7 +56,9 @@ def run(self, load_data_np=None):
from data_juicer.format.formatter import FORMATTERS
dataset = FORMATTERS.modules[obj_name](**args).load_dataset()
else:
dataset = RayDataset.read_jsonl(self.cfg.dataset_path, self.cfg)
dataset = RayDataset.read_jsonl(self.cfg.dataset_path,
self.cfg,
resplit=True)
# 2. extract processes
logger.info('Preparing process operators...')
ops = load_ops(self.cfg.process, self.cfg.op_fusion)
