Skip to content

Commit

Permalink
- Add KTO Post-training example
Browse files Browse the repository at this point in the history
  • Loading branch information
dmahan93 committed Sep 25, 2024
1 parent f5d7ff9 commit 1f94a2a
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 2 deletions.
8 changes: 8 additions & 0 deletions post-training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/lla
python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
```

## KTO data
```bash
python post-training/llama_dpo_data.py
python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_test_filtered.jsonl --output-prefix data/kto/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
```


## Converting back to hf
```bash
Expand Down
120 changes: 120 additions & 0 deletions post-training/configs/llama3-8b-kto.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
{
"pipe_parallel_size": 0,
"model_parallel_size": 4,
"make_vocab_size_divisible_by": 1,

# model settings
"num_layers": 32,
"hidden_size": 4096,
"num_attention_heads": 32,
"num_kv_heads": 8,
# llama3 supports more than this but this is just for testing.
"seq_length": 1024,
"max_position_embeddings": 1024,
"pos_emb": "rotary",
"rotary_pct": 1,
"rotary_emb_base": 500000,
"rope_fusion": true,
"no_weight_tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,

"attention_config": [[["flash"], 32]],

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"use_bias_in_mlp": false,
"use_flashattn_swiglu": true,
"activation": "swiglu",
"intermediate_size": 14336,
"mlp_multiple_of": 14336,

"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00001,
"betas": [0.9, 0.95],
"eps": 1.0e-8
}
},
"min_lr": 0.000001,

"zero_optimization": {
"stage": 1,
"allgather_partitions": true,
"allgather_bucket_size": 1260000000,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1260000000,
"contiguous_gradients": true,
"cpu_offload": false
},


"train_impl": "kto",
"kto_fp32": true,
"kto_beta": 0.1,
"allow_chopped": false,
"train_label_data_paths": [ "data/kto/llama3_train_messages_label_document" ],
"test_label_data_paths": [ "data/kto/llama3_test_messages_label_document" ],
"valid_label_data_paths": [ "data/kto/llama3_train_messages_label_document" ],
"train_data_paths": [ "data/kto/llama3_train_messages_document" ],
"test_data_paths": [ "data/kto/llama3_test_messages_document" ],
"valid_data_paths": [ "data/kto/llama3_train_messages_document" ],
"train_reward_data_paths": [ "data/kto/llama3_train_messages_reward_document" ],
"test_reward_data_paths": [ "data/kto/llama3_test_messages_reward_document" ],
"valid_reward_data_paths": [ "data/kto/llama3_train_messages_reward_document" ],

"train_micro_batch_size_per_gpu": 32,
"gradient_accumulation_steps": 2,
"data_impl": "mmap",
"pack_impl": "unpacked",
"num_workers": 1,

"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

"precision": "bfloat16",
"fp32_allreduce": true,
"bf16": {
"enabled": true
},
"data_types": {
"grad_accum_dtype": "fp32"
},

"train_iters": 477,
"lr_decay_iters": 477,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.1,
"checkpoint_factor": 1000,
"eval_interval": 100,
"eval_iters": 10,

"log_interval": 1,
"steps_per_print": 1,
"wall_clock_breakdown": true,


"save": "checkpoints/kto/llama3/llama3-8b-instruct",
#"load": "", # once run is started, to restart from intermediate ckpt use "load" = "save"
"load": "checkpoints/neox_converted/llama3-8b-instruct",
"vocab-file": "checkpoints/neox_converted/llama3-8b-instruct/tokenizer/tokenizer.json",
"use_wandb": true,
"wandb_group": "llama3-8b-instruct",
"wandb_project": "ultrafeedback-kto",
"finetune": true, # set to false once resuming from intermediate finetuning step
"tokenizer_type": "HFTokenizer",
}
15 changes: 13 additions & 2 deletions post-training/llama_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
) as f:
writer = jsonlines.Writer(f)
for item in raw_datasets[split]:
# add empty system messages
item["chosen"] = item["chosen"]
item["rejected"] = item["rejected"]
writer.write(item)
Expand All @@ -33,6 +32,18 @@
) as f:
writer = jsonlines.Writer(f)
for item in raw_datasets[split]:
# add empty system messages
item["messages"] = item["chosen"]
writer.write(item)
os.makedirs(os.path.join("data", "kto"), exist_ok=True)
for split in ["train", "test"]:
with open(
os.path.join("data", "kto", f"llama3_kto_{split}_filtered.jsonl"), "w"
) as f:
writer = jsonlines.Writer(f)
for item in raw_datasets[split]:
item["messages"] = item["chosen"]
item["reward"] = 1
writer.write(item)
item["messages"] = item["rejected"]
item["reward"] = -1
writer.write(item)

0 comments on commit 1f94a2a

Please sign in to comment.