Commit

try to train
haileyschoelkopf committed Nov 29, 2023
1 parent 48b0dd6 commit ccd004b
Showing 3 changed files with 29 additions and 9 deletions.
11 changes: 10 additions & 1 deletion configs/49M.yml
@@ -46,7 +46,7 @@
 },

 # batch / data settings
-"train_micro_batch_size_per_gpu": 32,
+"train_micro_batch_size_per_gpu": 4,
 "gas": 1,
 "data_impl": "mmap",
 "num_workers": 1,
@@ -88,4 +88,13 @@
 "log_interval": 10,
 "steps_per_print": 10,
 "wall_clock_breakdown": true,
+
+"train_data_paths": ["/mnt/ssd-1/hailey/conditional-training/llm-foundry/scripts/data_prep/my-copy-c4/train_small"],
+"valid_data_paths": ["/mnt/ssd-1/hailey/conditional-training/llm-foundry/scripts/data_prep/my-copy-c4/val_small"],
+"test_data_paths": ["/mnt/ssd-1/hailey/conditional-training/llm-foundry/scripts/data_prep/my-copy-c4/val_small"],
+
+"use_streaming": true,
+
+"tokenizer_type": "HFTokenizer",
+"vocab-file": "/mnt/ssd-1/hailey/pythia/utils/20B_tokenizer.json"
 }
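The added keys point the config at local MDS copies of C4 and switch on the new streaming data path. As a quick orientation, here is a minimal Python sketch of reading these keys back out of the config; the file path and the assumption that the file parses as standard YAML are illustrative only, this is not the gpt-neox argument parser:

import yaml  # assumes PyYAML is available

with open("configs/49M.yml") as f:
    cfg = yaml.safe_load(f)  # gpt-neox configs are YAML-with-comments; assumed parseable here

if cfg.get("use_streaming"):
    train_paths = cfg["train_data_paths"]  # local MDS copies of C4 (see paths above)
    valid_paths = cfg["valid_data_paths"]
    test_paths = cfg["test_data_paths"]
    print(f"streaming train data from: {train_paths}")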
8 changes: 6 additions & 2 deletions megatron/data/streaming_dataset.py
@@ -3,6 +3,11 @@
 except ModuleNotFoundError:
     raise Exception("Must install `streaming` package to use StreamingDatasets!")

+from typing import Optional, Sequence, Union, Any, Dict, List
+
+import torch
+import numpy as np
+
 # TAKEN FROM MOSAICML LLM-FOUNDRY
 # https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/data/text_data.py#L23C1-L192C28
 class StreamingTextDataset(StreamingDataset):
@@ -189,7 +194,7 @@ def build_streaming_dataset(split, neox_args=None):

     if data_weights:
         # normalize proportions
-        data_weights = [weight / data_weights.sum() for weight in data_weights]
+        data_weights = [weight / sum(data_weights) for weight in data_weights]

     streams = []
     for i, path in enumerate(data_paths):
@@ -202,7 +207,6 @@
         )

     return StreamingTextDataset(
-        tokenizer=neox_args.tokenizer.tokenizer, # TODO: drop this arg from the copied-over StreamingTextDataset
         max_seq_len=neox_args.seq_length + 1,
         streams=streams,
         split=None,
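Two small fixes ride along in this file. The `sum()` change matters because `data_weights` is a plain Python list parsed from the config, not a NumPy array, so it has no `.sum()` method; the built-in `sum()` works on any iterable. A self-contained sketch with made-up weights:

data_weights = [2.0, 1.0, 1.0]  # illustrative per-dataset weights

# data_weights.sum() would raise AttributeError: 'list' object has no attribute 'sum'
total = sum(data_weights)
proportions = [w / total for w in data_weights]
print(proportions)  # [0.5, 0.25, 0.25] -- normalized sampling proportions, one per stream

The dropped `tokenizer=` argument follows its own TODO, presumably paired with removing that parameter from the copied-over StreamingTextDataset.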
19 changes: 13 additions & 6 deletions megatron/training.py
@@ -44,7 +44,7 @@
     get_params_for_weight_decay_optimization,
 )
 from megatron.checkpointing import load_checkpoint, save_checkpoint
-from megatron.data.data_utils import build_train_valid_test_data_iterators
+from megatron.data.data_utils import build_train_valid_test_data_iterators, build_train_valid_test_data_iterators_streaming
 from megatron.initialize import initialize_megatron
 from megatron.learning_rates import AnnealingLR
 from megatron.logging import tb_wandb_log, training_log
@@ -196,11 +196,18 @@ def pretrain(neox_args):

     # Data stuff.
     timers("train/valid/test data iterators").start()
-    (
-        train_data_iterator,
-        valid_data_iterator,
-        test_data_iterator,
-    ) = build_train_valid_test_data_iterators(neox_args=neox_args)
+    if neox_args.use_streaming:
+        (
+            train_data_iterator,
+            valid_data_iterator,
+            test_data_iterator,
+        ) = build_train_valid_test_data_iterators_streaming(neox_args=neox_args)
+    else:
+        (
+            train_data_iterator,
+            valid_data_iterator,
+            test_data_iterator,
+        ) = build_train_valid_test_data_iterators(neox_args=neox_args)
     timers("train/valid/test data iterators").stop()

     if neox_args.use_mup and neox_args.coord_check:
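The new branch keys off `neox_args.use_streaming`, matching the `"use_streaming": true` entry added to configs/49M.yml above; for that key to be parsed, a corresponding boolean field has to exist on the NeoX argument dataclasses, which is not shown in this diff. Purely as an illustration of the shape such a field takes (the class name, default, and docstring here are assumptions, not the repository's actual definition):

from dataclasses import dataclass

@dataclass
class NeoXArgsTrainingSketch:  # simplified stand-in for gpt-neox's argument dataclasses
    use_streaming: bool = False
    """If True, pretrain() builds its train/valid/test iterators with
    build_train_valid_test_data_iterators_streaming (MosaicML StreamingDataset)
    instead of the default mmap/indexed-dataset pipeline."""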
