diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index 70c2b068f..49a68ea29 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -488,6 +488,14 @@ def build_train_valid_test_data_iterators(neox_args):
     return train_data_iterator, valid_data_iterator, test_data_iterator
 
 
+def build_streaming_dataset(split, neox_args):
+    """build a StreamingTextDataset (see megatron/data/streaming_dataset.py)"""
+    # deferred import so the `streaming` dependency is only required when use_streaming is set
+    from megatron.data.streaming_dataset import build_streaming_dataset as _build
+
+    return _build(split=split, neox_args=neox_args)
+
+
 def build_train_valid_test_data_iterators_streaming(neox_args):
     """as above, but builds Mosaic StreamingDatasets instead"""
 
@@ -546,7 +554,10 @@ def build_train_valid_test_data_iterators_streaming(neox_args):
             remote_dir = 's3://{data_path[0]}'
             # Local directory where dataset is cached during operation
             local_dir = '/tmp/cache-{data_path[0]}/{split}'
-            ds.append(StreamingDataset(local=local_dir, remote=remote_dir, split=None, shuffle=True)) # TODO: sampler from megatron handles shuffle, right? check this
+
+            # TODO: switch to StreamingTextDataset from llm-foundry
+            new_ds = build_streaming_dataset(split=split, neox_args=neox_args)
+            ds.append(new_ds)
 
         #Load mosaic streaming datasets from train_data_paths, valid_data_paths, test_data_paths
         train_ds, valid_ds, test_ds = ds
@@ -588,7 +599,7 @@ def build_train_valid_test_data_iterators_streaming(neox_args):
         neox_args.do_valid = flags[1].item()
         neox_args.do_test = flags[2].item()
 
-    # Shift the start iterations. TODO: how to do this with streamingdatasets?
+    # Shift the start iterations. TODO: how to do this with streamingdatasets? might be the same if we still use our megatron sampler
     if train_dataloader is not None:
         train_dataloader.batch_sampler.start_iter = (
             neox_args.iteration * neox_args.gradient_accumulation_steps
diff --git a/megatron/data/streaming_dataset.py b/megatron/data/streaming_dataset.py
new file mode 100644
index 000000000..d80e5aa0d
--- /dev/null
+++ b/megatron/data/streaming_dataset.py
@@ -0,0 +1,213 @@
+import os
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+import numpy as np
+import torch
+from streaming import Stream, StreamingDataset
+
+
+# TAKEN FROM MOSAICML LLM-FOUNDRY
+# https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/data/text_data.py#L23C1-L192C28
+class StreamingTextDataset(StreamingDataset):
+    """Generic text dataset using MosaicML's StreamingDataset.
+
+    Args:
+        max_seq_len (int): The max sequence length of each sample.
+        streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
+            which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
+            ``remote``/``local``. Defaults to ``None``.
+        remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
+            its data must exist locally. StreamingDataset uses either ``streams`` or
+            ``remote``/``local``. Defaults to ``None``.
+        local (str, optional): Local working directory to download shards to. This is where shards
+            are cached while they are being used. Uses a temp directory if not set.
+            StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``.
+        split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
+            the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
+        download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
+        download_timeout (float): Number of seconds to wait for a shard to download before raising
+            an exception. Defaults to ``60``.
+        validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
+            shards. Defaults to ``None``.
+        keep_zip (bool): Whether to keep or delete the compressed form when decompressing
+            downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
+            ``False``.
+        epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
+            streams. If ``None``, takes its value from the total number of underlying samples.
+            Provide this field if you are weighting streams relatively to target a larger or
+            smaller epoch size. Defaults to ``None``.
+        predownload (int, optional): Target number of samples ahead to download the shards of while
+            iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
+        cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
+            shard cache. Before downloading a shard, the least recently used resident shard(s) may
+            be evicted (deleted from the local cache) in order to stay under the limit. Set to None
+            to disable shard eviction. Supports integer bytes as well as string human-readable
+            bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
+        partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``.
+        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
+            resumption. If ``None``, this is interpreted as 64 times the number of physical
+            nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
+            number of physical nodes of the initial run otherwise. Defaults to ``None``.
+        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
+            partitioned over the workers. Defaults to ``None``.
+        shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
+            ``False``.
+        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
+        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
+        shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
+            into blocks of this size, and samples within each block are shuffled. If ``None``, its
+            value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to
+            ``None``.
+        sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
+            Defaults to ``balanced``.
+        sampling_granularity (int): When picking samples for a stream's final partial repeat,
+            how many samples to pick from the same shard at a time (``1`` for evenly balanced
+            across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
+            Defaults to ``1``.
+        batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
+            ``per_stream``. Defaults to ``random``.
+    """
+
+    def __init__(self,
+                 max_seq_len: int,
+                 streams: Optional[Sequence[Stream]] = None,
+                 remote: Optional[str] = None,
+                 local: Optional[str] = None,
+                 split: Optional[str] = None,
+                 download_retry: int = 2,
+                 download_timeout: float = 60,
+                 validate_hash: Optional[str] = None,
+                 keep_zip: bool = False,
+                 epoch_size: Optional[Union[int, str]] = None,
+                 predownload: Optional[int] = None,
+                 cache_limit: Optional[Union[int, str]] = None,
+                 partition_algo: str = 'relaxed',
+                 num_canonical_nodes: Optional[int] = None,
+                 batch_size: Optional[int] = None,
+                 shuffle: bool = False,
+                 shuffle_algo: str = 'py1e',
+                 shuffle_seed: int = 9176,
+                 shuffle_block_size: Optional[int] = None,
+                 sampling_method: str = 'balanced',
+                 sampling_granularity: int = 1,
+                 batching_method: str = 'random',
+                 **kwargs: Any):
+
+        group_method = kwargs.pop('group_method', None)
+        if group_method is not None:
+            raise NotImplementedError(
+                'group_method is deprecated and has been removed.\nTo ' +
+                'concatenate, use the --concat_tokens ' +
+                'argument when creating your MDS dataset with concat_c4.py')
+
+        if len(kwargs) > 0:
+            raise ValueError(
+                f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}'
+            )
+
+        if local is not None and (remote is None or (local == remote)):
+            if os.path.isdir(local):
+                contents = set(os.listdir(local))
+                if split not in contents:
+                    raise ValueError(
+                        f'local directory {local} does not contain split {split}'
+                    )
+
+        # TODO: find where the yaml values are being converted incorrectly; this cast is a temporary workaround
+        if isinstance(shuffle_block_size, float):
+            shuffle_block_size = int(shuffle_block_size)
+
+        # Build Dataset
+        super().__init__(
+            streams=streams,
+            remote=remote,
+            local=local,
+            split=split,
+            download_retry=download_retry,
+            download_timeout=download_timeout,
+            validate_hash=validate_hash,
+            keep_zip=keep_zip,
+            epoch_size=epoch_size,
+            predownload=predownload,
+            cache_limit=cache_limit,
+            partition_algo=partition_algo,
+            num_canonical_nodes=num_canonical_nodes,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            shuffle_algo=shuffle_algo,
+            shuffle_seed=shuffle_seed,
+            shuffle_block_size=shuffle_block_size,
+            sampling_method=sampling_method,
+            sampling_granularity=sampling_granularity,
+            batching_method=batching_method,
+        )
+
+        self.max_seq_len = max_seq_len
+
+    def _read_binary_tokenized_sample(self, sample: Dict[str,
+                                                         Any]) -> torch.Tensor:
+        return torch.from_numpy(
+            np.frombuffer(sample['tokens'],
+                          dtype=np.int64)[:self.max_seq_len].copy())
+
+    # How to process a sample
+    def __getitem__(self,
+                    idx: int) -> Union[Dict[str, List[int]], torch.Tensor]:
+        sample = super().__getitem__(idx)
+        if 'tokens' in sample:
+            token_sample = self._read_binary_tokenized_sample(sample)
+        else:
+            raise RuntimeError(
+                'StreamingTextDataset needs samples to have a `tokens` column'
+            )
+        return token_sample
+
+
+def build_streaming_dataset(split, neox_args=None):
+    """build a StreamingTextDataset for the requested split"""
+
+    assert split in ["train", "valid", "test"]
+
+    train_iters = neox_args.train_iters
+    eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
+    test_iters = neox_args.eval_iters
+    train_val_test_num_samples = {
+        "train": train_iters * neox_args.train_batch_size,
+        "valid": eval_iters * neox_args.train_batch_size,
+        "test": test_iters * neox_args.train_batch_size,
+    }
+
+    data_paths = {
+        "train": neox_args.train_data_paths,
+        "valid": neox_args.valid_data_paths,
+        "test": neox_args.test_data_paths
+    }[split]
+
+    data_weights = {
+        "train": neox_args.train_data_weights,
+        "valid": neox_args.valid_data_weights,
+        "test": neox_args.test_data_weights,
+    }[split]
+
+    if data_weights:
+        # normalize proportions so the stream weights sum to 1
+        total_weight = sum(data_weights)
+        data_weights = [weight / total_weight for weight in data_weights]
+
+    streams = []
+    for i, path in enumerate(data_paths):
+        streams.append(
+            Stream(
+                remote=path if "s3://" in path else None,
+                local=path,  # TODO: right now, only local datasets are fully supported.
+                proportion=data_weights[i] if data_weights else None,  # support for upsampling
+            )
+        )
+
+    return StreamingTextDataset(
+        max_seq_len=neox_args.seq_length + 1,
+        streams=streams,
+        split=None,
+        epoch_size=train_val_test_num_samples[split],
+    )
+
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 957960832..a6232463d 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -788,6 +788,14 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     Warm up mmap files.
     """
 
+    use_streaming: bool = False
+    """
+    Whether to bypass Megatron's builtin dataset implementations and use
+    MosaicML StreamingDatasets instead.
+
+    Must be used with the train/valid/test_data_paths splits as opposed to data_path.
+    """
+
     save: str = None
     """
     Output directory to save checkpoints to.
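
For reference, a minimal usage sketch of the streaming path introduced by this patch, assuming an MDS-converted dataset already exists at the listed paths. The SimpleNamespace below merely stands in for a parsed NeoXArgs config; every path, weight, and size shown is a placeholder rather than something defined by the diff.

# Hypothetical usage sketch (not part of the patch); values are placeholders.
from types import SimpleNamespace

from megatron.data.streaming_dataset import build_streaming_dataset

# Stand-in for a NeoXArgs object built from a YAML config with `use_streaming: true`.
neox_args = SimpleNamespace(
    train_iters=1000,
    eval_interval=100,
    eval_iters=10,
    train_batch_size=32,
    seq_length=2048,
    train_data_paths=["/data/mds/train"],  # or an "s3://..." URI to stream remotely
    valid_data_paths=["/data/mds/valid"],
    test_data_paths=["/data/mds/test"],
    train_data_weights=[1.0],
    valid_data_weights=[1.0],
    test_data_weights=[1.0],
)

# epoch_size is derived from train_iters / eval_iters and train_batch_size, as in
# build_streaming_dataset above.
train_ds = build_streaming_dataset(split="train", neox_args=neox_args)
valid_ds = build_streaming_dataset(split="valid", neox_args=neox_args)

# Each sample is a 1-D int64 tensor of up to seq_length + 1 tokens, decoded from the
# shard's `tokens` column by StreamingTextDataset.__getitem__.
first_sample = train_ds[0]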