Skip to content

Commit

Permalink
- Finish testing reinforce
Browse files Browse the repository at this point in the history
- Add example files
- TODO: Add the whole online loop instructions once vllm fork is pushed
  • Loading branch information
dmahan93 committed Sep 24, 2024
1 parent 053f67e commit b921170
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 14 deletions.
5 changes: 3 additions & 2 deletions megatron/data/online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch.utils.data
import socket
import pickle
from megatron.mpu.initialize import get_data_parallel_src_rank
from megatron.mpu.initialize import get_data_parallel_rank


class OnlineDataset(torch.utils.data.Dataset):
Expand All @@ -37,7 +37,7 @@ def __init__(
dataserver_ports: Union[int, List[int]] = 10000,
):
self.num_samples = num_samples
self.global_rank = get_data_parallel_src_rank()
self.global_rank = get_data_parallel_rank()
self.leave_one_out = leave_one_out
self.reward_buffer = []
self.online_batching_data = []
Expand All @@ -62,6 +62,7 @@ def update_online_batches(self):
else:
# in case we want to use different ports for different ranks, e.g. per machine sampling
port = self.dataserver_ports[self.global_rank]
print(f"Connecting to {ipaddr}:{port}")
s.connect((ipaddr, port))
s.send(self.data_split.encode())
data = b""
Expand Down
4 changes: 2 additions & 2 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,9 +1119,9 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Whether to cast logits to fp32 for Reinforce loss calculation.
"""

use_full_kl: bool = True
kl_impl: Literal["abs", "mse", "kl", "full"] = "mse"
"""
Use full KL divergence in Reinforce loss calculation.
KL divergence implementation, can be one of "abs", "mse", "kl", or "full"
"""

kl_div_beta: float = 0.1
Expand Down
25 changes: 15 additions & 10 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def pretrain(neox_args):
)
timers("model and optimizer").stop()

if neox_args.serve_weights:
if neox_args.serve_model_weights:
start_server(model)
# sync...
torch.distributed.barrier()
Expand Down Expand Up @@ -784,7 +784,7 @@ def forward_step(
if type(ref_outputs) is tuple:
ref_outputs, _ = ref_outputs
ref_outputs = ref_outputs
if neox_args.use_full_kl:
if neox_args.kl_impl == "full":
# Have to do the loss over all tokens...
ref_outputs = gather_from_model_parallel_region(ref_outputs)
if neox_args.fp32_reinforce:
Expand All @@ -801,7 +801,7 @@ def forward_step(
outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
if type(outputs) is tuple:
outputs, _ = outputs
if neox_args.use_full_kl:
if neox_args.kl_impl == "full":
# Have to do the loss over all tokens...
outputs = gather_from_model_parallel_region(outputs)
if neox_args.fp32_reinforce:
Expand All @@ -818,14 +818,20 @@ def forward_step(
metrics["reward_std"] = raw_rewards.clone().detach().std()
loss_mask_sum = loss_mask.sum()
if reference_model is not None:
if neox_args.use_full_kl:
if neox_args.kl_impl == "full":
# Following along with
# https://github.com/huggingface/trl/blob/104a02d207b63a4a062882aaff68f2d275493399/trl/trainer/ppo_trainer.py#L1120
# https://github.com/huggingface/trl/blob/104a02d207b63a4a062882aaff68f2d275493399/trl/trainer/ppo_trainer.py#L1109
kl = F.kl_div(ref_logp, logp, log_target=True, reduction="none").sum(-1)
metrics["kl"] = kl.clone().detach().mean()
else:
kl = (per_token_logp - ref_per_token_logp).sum(-1)
metrics["kl"] = kl.clone().detach()
kl = per_token_logp - ref_per_token_logp
if neox_args.kl_impl == "abs":
kl = kl.abs()
elif neox_args.kl_impl == "mse":
kl = 0.5 * (kl).square()
elif neox_args.kl_impl == "kl":
pass
with torch.no_grad():
metrics["kl"] = kl.clone().detach().mean()
loss = (-per_token_logp * rewards) + (neox_args.kl_div_beta * kl)
loss = (loss * loss_mask).sum(-1) / loss_mask_sum
loss = loss.mean()
Expand Down Expand Up @@ -1145,7 +1151,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
(neox_args.train_impl == "kto")
and (neox_args.precompute_model_name is None)
)
or ((neox_args.train_type == "reinforce") and (neox_args.kl_div_beta > 0.0))
or ((neox_args.train_impl == "reinforce") and (neox_args.kl_div_beta > 0.0))
)
model = get_model(neox_args=neox_args, use_cache=use_cache)
if needs_reference_model:
Expand Down Expand Up @@ -1278,7 +1284,6 @@ def train_step(
reference_model=None,
):
"""Single training step."""

# Pipeline parallelism schedules forward/backward/step
if neox_args.is_pipe_parallel:
reduced_loss = train_step_pipe(
Expand Down
119 changes: 119 additions & 0 deletions post-training/configs/llama3-8b-reinforce.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
{
"pipe_parallel_size": 0,
"model_parallel_size": 4,
"make_vocab_size_divisible_by": 1,

# model settings
"num_layers": 32,
"hidden_size": 4096,
"num_attention_heads": 32,
"num_kv_heads": 8,
# llama3 supports more than this but this is just for testing.
"seq_length": 1024,
"max_position_embeddings": 1024,
"pos_emb": "rotary",
"rotary_pct": 1,
"rotary_emb_base": 500000,
"rope_fusion": true,
"no_weight_tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,

"attention_config": [[["flash"], 32]],

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"use_bias_in_mlp": false,
"use_flashattn_swiglu": true,
"activation": "swiglu",
"intermediate_size": 14336,
"mlp_multiple_of": 14336,

"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00001,
"betas": [0.9, 0.95],
"eps": 1.0e-8
}
},
"min_lr": 0.000001,

"zero_optimization": {
"stage": 1,
"allgather_partitions": true,
"allgather_bucket_size": 1260000000,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1260000000,
"contiguous_gradients": true,
"cpu_offload": false
},

"train_impl": "reinforce",
"dataset_impl": "online",
"reinforce_leave_one_out": true,
"fp32_reinforce": true,
"kl_impl": "abs",
"online_dataserver_ports": [10000, 10001],
"serve_model_weights": true,
"train_label_data_paths": [ "data/sft/llama3_train_messages_label_document" ],
"test_label_data_paths": [ "data/sft/llama3_test_messages_label_document" ],
"valid_label_data_paths": [ "data/sft/llama3_train_messages_label_document" ],
"train_data_paths": [ "data/sft/llama3_train_messages_document" ],
"test_data_paths": [ "data/sft/llama3_test_messages_document" ],
"valid_data_paths": [ "data/sft/llama3_train_messages_document" ],

"train_micro_batch_size_per_gpu": 8,
"gradient_accumulation_steps": 4,
"data_impl": "mmap",
"pack_impl": "unpacked",
"num_workers": 1,

"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

"precision": "bfloat16",
"fp32_allreduce": true,
"bf16": {
"enabled": true
},
"data_types": {
"grad_accum_dtype": "fp32"
},

"train_iters": 477,
"lr_decay_iters": 477,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.1,
"checkpoint_factor": 1000,
"eval_interval": 100,
"eval_iters": 10,

"log_interval": 1,
"steps_per_print": 1,
"wall_clock_breakdown": true,


"save": "checkpoints/reinforce/llama3/llama3-8b-instruct",
#"load": "", # once run is started, to restart from intermediate ckpt use "load" = "save"
"load": "checkpoints/neox_converted/llama3-8b-instruct",
"vocab-file": "checkpoints/neox_converted/llama3-8b-instruct/tokenizer/tokenizer.json",
"use_wandb": true,
"wandb_group": "llama3-8b-instruct",
"wandb_project": "reinforce-test",
"finetune": true, # set to false once resuming from intermediate finetuning step
"tokenizer_type": "HFTokenizer",
}
Loading

0 comments on commit b921170

Please sign in to comment.