diff --git a/scripts/train/train.py b/scripts/train/train.py
index ab3addbb07..180b8ef22b 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -11,6 +11,8 @@
 from composer import Trainer
 from composer.core import Evaluator
 from composer.core.callback import Callback
+from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler,
+                               cyclic_schedule)
 from composer.utils import dist, get_device, reproducibility
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
@@ -458,6 +460,33 @@ def main(cfg: DictConfig) -> Trainer:
         for name, logger_cfg in logger_configs.items()
     ] if logger_configs else None
 
+    # Profiling
+    profiler: Optional[Profiler] = None
+    profiler_cfg: Optional[DictConfig] = pop_config(cfg,
+                                                    'profiler',
+                                                    must_exist=False,
+                                                    convert=False,
+                                                    default_value=None)
+    if profiler_cfg:
+        profiler_schedule_cfg: Dict = pop_config(profiler_cfg,
+                                                 'schedule',
+                                                 must_exist=True,
+                                                 convert=True)
+        profiler_schedule = cyclic_schedule(**profiler_schedule_cfg)
+        # Only support json trace handler
+        profiler_trace_handlers: List[TraceHandler] = []
+        profiler_trace_cfg: Optional[Dict] = pop_config(profiler_cfg,
+                                                        'json_trace_handler',
+                                                        must_exist=False,
+                                                        default_value=None,
+                                                        convert=True)
+        if profiler_trace_cfg:
+            profiler_trace_handlers.append(
+                JSONTraceHandler(**profiler_trace_cfg))
+        profiler = Profiler(**profiler_cfg,
+                            trace_handlers=profiler_trace_handlers,
+                            schedule=profiler_schedule)
+
     # Callbacks
     callbacks: List[Callback] = [
         build_callback(str(name), callback_cfg)
@@ -576,6 +605,7 @@ def main(cfg: DictConfig) -> Trainer:
         autoresume=autoresume,
         python_log_level=python_log_level,
         dist_timeout=dist_timeout,
+        profiler=profiler,
     )
 
     print('Logging config')
diff --git a/scripts/train/yamls/pretrain/mpt-small-cpu.yaml b/scripts/train/yamls/pretrain/mpt-small-cpu.yaml
new file mode 100644
index 0000000000..cc04f11e44
--- /dev/null
+++ b/scripts/train/yamls/pretrain/mpt-small-cpu.yaml
@@ -0,0 +1,119 @@
+data_local: ./my-copy-c4
+data_remote: # If blank, files must be present in data_local
+max_seq_len: 128
+global_seed: 17
+
+# Run Name
+run_name: mpt_causal_lm_cpu # If left blank, will be read from env var $RUN_NAME
+
+# Model
+model:
+  name: mpt_causal_lm
+  init_device: cpu
+  d_model: 16
+  n_heads: 4
+  n_layers: 4
+  expansion_ratio: 5
+  max_seq_len: ${max_seq_len}
+  vocab_size: 50368
+  attn_config:
+    attn_impl: torch
+  loss_fn: torch_crossentropy
+
+# Tokenizer
+tokenizer:
+  name: EleutherAI/gpt-neox-20b
+  kwargs:
+    model_max_length: ${max_seq_len}
+
+# Dataloaders
+train_loader:
+  name: text
+  dataset:
+    local: ${data_local}
+    remote: ${data_remote}
+    split: train
+    shuffle: true
+    max_seq_len: ${max_seq_len}
+    shuffle_seed: ${global_seed}
+  drop_last: true
+  num_workers: 2
+
+eval_loader:
+  name: text
+  dataset:
+    local: ${data_local}
+    remote: ${data_remote}
+    split: val
+    shuffle: false
+    max_seq_len: ${max_seq_len}
+    shuffle_seed: ${global_seed}
+  drop_last: false
+  num_workers: 2
+
+# Optimization
+scheduler:
+  name: cosine_with_warmup
+  t_warmup: 100ba
+  alpha_f: 0.1
+
+optimizer:
+  name: decoupled_adamw
+  lr: 6.0e-4
+  betas:
+  - 0.9
+  - 0.95
+  eps: 1.0e-08
+  weight_decay: 0.0
+
+algorithms:
+  gradient_clipping:
+    clipping_type: norm
+    clipping_threshold: 1.0
+
+max_duration: 10ba
+eval_interval: 5ba
+eval_first: false
+eval_subset_num_batches: 5
+global_train_batch_size: 256
+autoresume: false
+
+# System
+seed: ${global_seed}
+device_eval_batch_size: 16
+device_train_microbatch_size: 16
+# device_train_microbatch_size: auto
+precision: fp32
+
+# FSDP
+fsdp_config:
+  sharding_strategy: FULL_SHARD
+  mixed_precision: PURE
+  activation_checkpointing: false
+  activation_checkpointing_reentrant: false
+  activation_cpu_offload: false
+  limit_all_gathers: true
+  verbose: false
+
+# Logging
+progress_bar: false
+log_to_console: true
+console_log_interval: 1ba
+
+callbacks:
+  speed_monitor:
+    window_size: 10
+  lr_monitor: {}
+  memory_monitor: {}
+  runtime_estimator: {}
+
+# Checkpoint to local filesystem or remote object store
+save_overwrite: true
+save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
+# save_interval: 500ba
+# save_folder: ./{run_name}/checkpoints
+# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints
+
+# Load from local filesystem or remote object store
+# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt
+# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt
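
Example usage (not part of this diff): the new block in train.py pops a top-level `profiler` section from the training YAML, passes its `schedule` sub-section to `composer.profiler.cyclic_schedule`, builds a `JSONTraceHandler` from the optional `json_trace_handler` sub-section, and forwards any remaining keys to the `Profiler` constructor. A minimal sketch of such a section, which could be added to mpt-small-cpu.yaml or any other training YAML, is shown below; the exact fields accepted under `schedule` and `json_trace_handler` depend on the installed Composer version's `cyclic_schedule` and `JSONTraceHandler` signatures and should be checked against its docs.

profiler:
  schedule:            # keyword args for composer.profiler.cyclic_schedule
    skip_first: 0
    wait: 0
    warmup: 1
    active: 4
    repeat: 1
  json_trace_handler:  # keyword args for composer.profiler.JSONTraceHandler
    folder: '{run_name}/traces'

Any other keys placed under `profiler` are passed straight through to the `Profiler` constructor, so additional Composer profiler options can be set there as well.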