From 1f94a2a2190289ea5fcb93c13320164d947eecca Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Wed, 25 Sep 2024 08:22:29 -0500 Subject: [PATCH] - Add KTO Post-training example --- post-training/README.md | 8 ++ post-training/configs/llama3-8b-kto.yml | 120 ++++++++++++++++++++++++ post-training/llama_data.py | 15 ++- 3 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 post-training/configs/llama3-8b-kto.yml diff --git a/post-training/README.md b/post-training/README.md index e6be7d931..1ba5cde2f 100644 --- a/post-training/README.md +++ b/post-training/README.md @@ -40,6 +40,14 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/lla python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages ``` +## KTO data +```bash +python post-training/llama_dpo_data.py +python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward +python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_test_filtered.jsonl --output-prefix data/kto/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward +python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward +``` + ## Converting back to hf ```bash diff --git a/post-training/configs/llama3-8b-kto.yml b/post-training/configs/llama3-8b-kto.yml new file mode 100644 index 000000000..e819d37cb --- /dev/null +++ b/post-training/configs/llama3-8b-kto.yml @@ -0,0 +1,120 @@ +{ + "pipe_parallel_size": 0, + "model_parallel_size": 4, + "make_vocab_size_divisible_by": 1, + + # model settings + "num_layers": 32, + "hidden_size": 4096, + "num_attention_heads": 32, + "num_kv_heads": 8, + # llama3 supports more than this but this is just for testing. + "seq_length": 1024, + "max_position_embeddings": 1024, + "pos_emb": "rotary", + "rotary_pct": 1, + "rotary_emb_base": 500000, + "rope_fusion": true, + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + + "attention_config": [[["flash"], 32]], + + "scaled_upper_triang_masked_softmax_fusion": true, + "bias_gelu_fusion": false, + "use_bias_in_norms": false, + "use_bias_in_attn_linear": false, + "use_bias_in_mlp": false, + "use_flashattn_swiglu": true, + "activation": "swiglu", + "intermediate_size": 14336, + "mlp_multiple_of": 14336, + + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00001, + "betas": [0.9, 0.95], + "eps": 1.0e-8 + } + }, + "min_lr": 0.000001, + + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 1260000000, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 1260000000, + "contiguous_gradients": true, + "cpu_offload": false + }, + + + "train_impl": "kto", + "kto_fp32": true, + "kto_beta": 0.1, + "allow_chopped": false, + "train_label_data_paths": [ "data/kto/llama3_train_messages_label_document" ], + "test_label_data_paths": [ "data/kto/llama3_test_messages_label_document" ], + "valid_label_data_paths": [ "data/kto/llama3_train_messages_label_document" ], + "train_data_paths": [ "data/kto/llama3_train_messages_document" ], + "test_data_paths": [ "data/kto/llama3_test_messages_document" ], + "valid_data_paths": [ "data/kto/llama3_train_messages_document" ], + "train_reward_data_paths": [ "data/kto/llama3_train_messages_reward_document" ], + "test_reward_data_paths": [ "data/kto/llama3_test_messages_reward_document" ], + "valid_reward_data_paths": [ "data/kto/llama3_train_messages_reward_document" ], + + "train_micro_batch_size_per_gpu": 32, + "gradient_accumulation_steps": 2, + "data_impl": "mmap", + "pack_impl": "unpacked", + "num_workers": 1, + + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + "precision": "bfloat16", + "fp32_allreduce": true, + "bf16": { + "enabled": true + }, + "data_types": { + "grad_accum_dtype": "fp32" + }, + + "train_iters": 477, + "lr_decay_iters": 477, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.1, + "checkpoint_factor": 1000, + "eval_interval": 100, + "eval_iters": 10, + + "log_interval": 1, + "steps_per_print": 1, + "wall_clock_breakdown": true, + + + "save": "checkpoints/kto/llama3/llama3-8b-instruct", + #"load": "", # once run is started, to restart from intermediate ckpt use "load" = "save" + "load": "checkpoints/neox_converted/llama3-8b-instruct", + "vocab-file": "checkpoints/neox_converted/llama3-8b-instruct/tokenizer/tokenizer.json", + "use_wandb": true, + "wandb_group": "llama3-8b-instruct", + "wandb_project": "ultrafeedback-kto", + "finetune": true, # set to false once resuming from intermediate finetuning step + "tokenizer_type": "HFTokenizer", +} diff --git a/post-training/llama_data.py b/post-training/llama_data.py index 5eef8c2d3..eab6ac9f1 100644 --- a/post-training/llama_data.py +++ b/post-training/llama_data.py @@ -22,7 +22,6 @@ ) as f: writer = jsonlines.Writer(f) for item in raw_datasets[split]: - # add empty system messages item["chosen"] = item["chosen"] item["rejected"] = item["rejected"] writer.write(item) @@ -33,6 +32,18 @@ ) as f: writer = jsonlines.Writer(f) for item in raw_datasets[split]: - # add empty system messages item["messages"] = item["chosen"] writer.write(item) +os.makedirs(os.path.join("data", "kto"), exist_ok=True) +for split in ["train", "test"]: + with open( + os.path.join("data", "kto", f"llama3_kto_{split}_filtered.jsonl"), "w" + ) as f: + writer = jsonlines.Writer(f) + for item in raw_datasets[split]: + item["messages"] = item["chosen"] + item["reward"] = 1 + writer.write(item) + item["messages"] = item["rejected"] + item["reward"] = -1 + writer.write(item)