diff --git a/megatron/data/online_dataset.py b/megatron/data/online_dataset.py
index a5fc3bd82..9a12c1875 100644
--- a/megatron/data/online_dataset.py
+++ b/megatron/data/online_dataset.py
@@ -23,7 +23,7 @@
 import torch.utils.data
 import socket
 import pickle
-from megatron.mpu.initialize import get_data_parallel_src_rank
+from megatron.mpu.initialize import get_data_parallel_rank


 class OnlineDataset(torch.utils.data.Dataset):
@@ -37,7 +37,7 @@ def __init__(
         dataserver_ports: Union[int, List[int]] = 10000,
     ):
         self.num_samples = num_samples
-        self.global_rank = get_data_parallel_src_rank()
+        self.global_rank = get_data_parallel_rank()
         self.leave_one_out = leave_one_out
         self.reward_buffer = []
         self.online_batching_data = []
@@ -62,6 +62,7 @@ def update_online_batches(self):
         else:
             # in case we want to use different ports for different ranks, e.g. per machine sampling
             port = self.dataserver_ports[self.global_rank]
+        print(f"Connecting to {ipaddr}:{port}")
         s.connect((ipaddr, port))
         s.send(self.data_split.encode())
         data = b""
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index dc696bd8e..e54c06ee0 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -1119,9 +1119,9 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     Whether to cast logits to fp32 for Reinforce loss calculation.
     """

-    use_full_kl: bool = True
+    kl_impl: Literal["abs", "mse", "kl", "full"] = "mse"
     """
-    Use full KL divergence in Reinforce loss calculation.
+    KL divergence implementation, can be one of "abs", "mse", "kl", or "full"
     """

     kl_div_beta: float = 0.1
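
Note on the new option: the four kl_impl choices penalize divergence from the reference model in different ways. "full" uses a distribution-level KL over the whole vocabulary, while "abs", "mse", and "kl" operate on the per-token log-probability ratio of the sampled tokens. A minimal standalone sketch of that arithmetic, with hypothetical tensor names rather than the trainer's variables:

import torch
import torch.nn.functional as F

def kl_penalty(logp, ref_logp, per_token_logp, ref_per_token_logp, kl_impl="mse"):
    # logp / ref_logp: (batch, seq, vocab) log-probabilities, only needed for "full"
    # per_token_logp / ref_per_token_logp: (batch, seq) log-probs of the sampled tokens
    if kl_impl == "full":
        # distribution KL over the vocabulary, following the TRL PPO trainer
        return F.kl_div(ref_logp, logp, log_target=True, reduction="none").sum(-1)
    diff = per_token_logp - ref_per_token_logp  # log-ratio on the sampled tokens
    if kl_impl == "abs":
        return diff.abs()            # absolute log-ratio
    if kl_impl == "mse":
        return 0.5 * diff.square()   # squared log-ratio
    return diff                      # "kl": plain log-ratio
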
diff --git a/megatron/training.py b/megatron/training.py
index afe4b6c51..2efd94f61 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -205,7 +205,7 @@ def pretrain(neox_args):
     )
     timers("model and optimizer").stop()

-    if neox_args.serve_weights:
+    if neox_args.serve_model_weights:
         start_server(model)
     # sync...
     torch.distributed.barrier()
@@ -784,7 +784,7 @@ def forward_step(
             if type(ref_outputs) is tuple:
                 ref_outputs, _ = ref_outputs
             ref_outputs = ref_outputs
-            if neox_args.use_full_kl:
+            if neox_args.kl_impl == "full":
                 # Have to do the loss over all tokens...
                 ref_outputs = gather_from_model_parallel_region(ref_outputs)
                 if neox_args.fp32_reinforce:
@@ -801,7 +801,7 @@ def forward_step(
     outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
     if type(outputs) is tuple:
         outputs, _ = outputs
-    if neox_args.use_full_kl:
+    if neox_args.kl_impl == "full":
         # Have to do the loss over all tokens...
         outputs = gather_from_model_parallel_region(outputs)
         if neox_args.fp32_reinforce:
@@ -818,14 +818,20 @@ def forward_step(
         metrics["reward_std"] = raw_rewards.clone().detach().std()
         loss_mask_sum = loss_mask.sum()
         if reference_model is not None:
-            if neox_args.use_full_kl:
+            if neox_args.kl_impl == "full":
                 # Following along with
-                # https://github.com/huggingface/trl/blob/104a02d207b63a4a062882aaff68f2d275493399/trl/trainer/ppo_trainer.py#L1120
+                # https://github.com/huggingface/trl/blob/104a02d207b63a4a062882aaff68f2d275493399/trl/trainer/ppo_trainer.py#L1109
                 kl = F.kl_div(ref_logp, logp, log_target=True, reduction="none").sum(-1)
-                metrics["kl"] = kl.clone().detach().mean()
             else:
-                kl = (per_token_logp - ref_per_token_logp).sum(-1)
-                metrics["kl"] = kl.clone().detach()
+                kl = per_token_logp - ref_per_token_logp
+                if neox_args.kl_impl == "abs":
+                    kl = kl.abs()
+                elif neox_args.kl_impl == "mse":
+                    kl = 0.5 * (kl).square()
+                elif neox_args.kl_impl == "kl":
+                    pass
+            with torch.no_grad():
+                metrics["kl"] = kl.clone().detach().mean()
             loss = (-per_token_logp * rewards) + (neox_args.kl_div_beta * kl)
             loss = (loss * loss_mask).sum(-1) / loss_mask_sum
             loss = loss.mean()
@@ -1145,7 +1151,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
             (neox_args.train_impl == "kto") and (neox_args.precompute_model_name is None)
         )
-        or ((neox_args.train_type == "reinforce") and (neox_args.kl_div_beta > 0.0))
+        or ((neox_args.train_impl == "reinforce") and (neox_args.kl_div_beta > 0.0))
     )
     model = get_model(neox_args=neox_args, use_cache=use_cache)
     if needs_reference_model:
@@ -1278,7 +1284,6 @@ def train_step(
     reference_model=None,
 ):
     """Single training step."""
-
     # Pipeline parallelism schedules forward/backward/step
     if neox_args.is_pipe_parallel:
         reduced_loss = train_step_pipe(
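
For reference, the reshaped forward_step keeps a single per-token KL tensor regardless of kl_impl and folds it into the REINFORCE loss before masking. A small numeric sketch of that loss shaping on dummy tensors (shapes, names, and values are illustrative assumptions, not the trainer's API):

import torch

per_token_logp = torch.randn(2, 4)                      # 2 sequences, 4 tokens each
kl = torch.rand(2, 4)                                   # per-token penalty from kl_impl
rewards = torch.tensor([[0.5], [-0.2]]).expand(2, 4)    # sequence reward broadcast per token
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                          [1.0, 1.0, 0.0, 0.0]])        # 0 marks padded tokens
kl_div_beta = 0.1

loss = (-per_token_logp * rewards) + (kl_div_beta * kl)  # policy-gradient term + KL penalty
loss = (loss * loss_mask).sum(-1) / loss_mask.sum()      # ignore padded tokens
loss = loss.mean()
print(loss)
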
+ "seq_length": 1024, + "max_position_embeddings": 1024, + "pos_emb": "rotary", + "rotary_pct": 1, + "rotary_emb_base": 500000, + "rope_fusion": true, + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + + "attention_config": [[["flash"], 32]], + + "scaled_upper_triang_masked_softmax_fusion": true, + "bias_gelu_fusion": false, + "use_bias_in_norms": false, + "use_bias_in_attn_linear": false, + "use_bias_in_mlp": false, + "use_flashattn_swiglu": true, + "activation": "swiglu", + "intermediate_size": 14336, + "mlp_multiple_of": 14336, + + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00001, + "betas": [0.9, 0.95], + "eps": 1.0e-8 + } + }, + "min_lr": 0.000001, + + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 1260000000, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 1260000000, + "contiguous_gradients": true, + "cpu_offload": false + }, + + "train_impl": "reinforce", + "dataset_impl": "online", + "reinforce_leave_one_out": true, + "fp32_reinforce": true, + "kl_impl": "abs", + "online_dataserver_ports": [10000, 10001], + "serve_model_weights": true, + "train_label_data_paths": [ "data/sft/llama3_train_messages_label_document" ], + "test_label_data_paths": [ "data/sft/llama3_test_messages_label_document" ], + "valid_label_data_paths": [ "data/sft/llama3_train_messages_label_document" ], + "train_data_paths": [ "data/sft/llama3_train_messages_document" ], + "test_data_paths": [ "data/sft/llama3_test_messages_document" ], + "valid_data_paths": [ "data/sft/llama3_train_messages_document" ], + + "train_micro_batch_size_per_gpu": 8, + "gradient_accumulation_steps": 4, + "data_impl": "mmap", + "pack_impl": "unpacked", + "num_workers": 1, + + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + "precision": "bfloat16", + "fp32_allreduce": true, + "bf16": { + "enabled": true + }, + "data_types": { + "grad_accum_dtype": "fp32" + }, + + "train_iters": 477, + "lr_decay_iters": 477, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.1, + "checkpoint_factor": 1000, + "eval_interval": 100, + "eval_iters": 10, + + "log_interval": 1, + "steps_per_print": 1, + "wall_clock_breakdown": true, + + + "save": "checkpoints/reinforce/llama3/llama3-8b-instruct", + #"load": "", # once run is started, to restart from intermediate ckpt use "load" = "save" + "load": "checkpoints/neox_converted/llama3-8b-instruct", + "vocab-file": "checkpoints/neox_converted/llama3-8b-instruct/tokenizer/tokenizer.json", + "use_wandb": true, + "wandb_group": "llama3-8b-instruct", + "wandb_project": "reinforce-test", + "finetune": true, # set to false once resuming from intermediate finetuning step + "tokenizer_type": "HFTokenizer", +} diff --git a/post-training/online_data_example_llama3.py b/post-training/online_data_example_llama3.py new file mode 100644 index 000000000..bdd902512 --- /dev/null +++ b/post-training/online_data_example_llama3.py @@ -0,0 +1,177 @@ +import socket +import threading +import datasets +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +import requests +import pickle +from collections import defaultdict +import time + + +def get_positive_score(scores): + "Extract value associated with a positive sentiment from 
pipeline's output" + return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + + +def http_bot(url, pload): + for i in range(10): + try: + headers = {"User-Agent": "vLLM Client"} + response = requests.post(url, headers=headers, json=pload, stream=True) + data = response.json() + return data + except Exception as e: + # give it a few seconds to recover + time.sleep(5) + print(e) + continue + raise Exception("Failed to connect to server") + + +def threaded_data_gatherer( + prefix, + max_completion_len, + tokenizer, + model_name, + num_completions, + i, + dp_idx, + data_to_send, + rm_pipeline, +): + pload = { + "temperature": 1.0, + "max_tokens": 0, + "stop": "<|eot_id|>", + "stream": False, + "model": model_name, + "prompt": "", + "n": num_completions, + } + # Grab tokens... + prefix_tokens = tokenizer.encode(prefix) + prompt = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": "Please write a mildly negative movie review starting with " + + prefix, + } + ], + add_generation_prompt=True, + tokenize=False, + ) + prompt_tokens = tokenizer.encode(prompt) + pload["max_tokens"] = max_completion_len - len(prefix_tokens) + pload["prompt"] = prompt + prefix + completions = http_bot(f"http://localhost:{8000+dp_idx}/v1/completions", pload) + completions = [completion["text"].strip() for completion in completions["choices"]] + + def reward_fn(samples, **kwargs): + sentiments = list(map(get_positive_score, rm_pipeline(samples))) + return sentiments + + rewards = reward_fn([prefix + " " + completion for completion in completions]) + if i == 0 and dp_idx == 0: + print(completions) + completions = [ + tokenizer.encode(completion + "<|eot_id|>") for completion in completions + ] + data_to_send.append( + {"prefix": prompt_tokens, "completions": completions, "rewards": rewards} + ) + + +def data_generator( + bs_per_dp, + dataset, + tokenizer, + model_name, + max_prefix_len, + max_completion_len, + num_completions, + dp_idx, + dp_size, + tp_size, + rm_pipeline, +): + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.bind( + ("localhost", 10000 + dp_idx) + ) # only one data loader per data parallel group + split_counter = defaultdict(lambda: dp_idx) + while True: + server.listen(1) + conn, addr = server.accept() + split = conn.recv(4096).decode() + if split == "valid": + split = "unsupervised" + data_to_send = list() + threads = list() + for i in range(bs_per_dp): + prefix = " ".join( + dataset[split][split_counter[split]]["text"].split()[:5] + ) # grab a few words to prompt it... + split_counter[split] = (split_counter[split] + dp_size) % len( + dataset[split] + ) + threads.append( + threading.Thread( + target=threaded_data_gatherer, + args=( + prefix, + max_completion_len, + tokenizer, + model_name, + num_completions, + i, + dp_idx, + data_to_send, + rm_pipeline, + ), + ) + ) + threads[-1].start() + for thread in threads: + thread.join() + conn.send(pickle.dumps(data_to_send)) + conn.close() + print( + f"Sent data to {dp_idx} for {split} split at iter {split_counter[split]}..." 
diff --git a/post-training/online_example.sh b/post-training/online_example.sh
new file mode 100644
index 000000000..abe601faa
--- /dev/null
+++ b/post-training/online_example.sh
@@ -0,0 +1,7 @@
+# Launch vllm
+CUDA_VISIBLE_DEVICES=0,1,2,3 conda run --no-capture-output -n vllm python -m vllm.entrypoints.openai.api_server --model=meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --from-remote-program --tensor-parallel-size=4 --enforce-eager --gpu-memory-utilization=0.2 --port 8000 --max-model-len=1024 --max-num-seqs=512 &
+
+CUDA_VISIBLE_DEVICES=4,5,6,7 conda run --no-capture-output -n vllm python -m vllm.entrypoints.openai.api_server --model=meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --from-remote-program --tensor-parallel-size=4 --enforce-eager --gpu-memory-utilization=0.2 --port 8001 --max-model-len=1024 --max-num-seqs=512 &
+
+# Launch training
+conda run --no-capture-output -n neox python deepy.py train.py post-training/configs/llama3-8b-reinforce.yml
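
Since training starts pulling completions as soon as it boots, it can help to confirm both vLLM servers are answering before deepy.py is launched. An optional readiness check, an assumption layered on top of this script rather than part of the patch, polling the standard OpenAI-compatible /v1/models route:

import time
import requests

for port in (8000, 8001):
    for _ in range(60):
        try:
            requests.get(f"http://localhost:{port}/v1/models", timeout=2)
            break                      # server is up
        except requests.exceptions.RequestException:
            time.sleep(5)              # keep waiting while vLLM loads the model
    else:
        raise RuntimeError(f"vLLM server on port {port} did not come up")
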