Add CI test for trl rewarding and ppo, fix backward failure in ppo ca… (#1020)

Signed-off-by: Wang, Yi A <[email protected]>
sywangyi authored Jul 12, 2024
1 parent 270d150 commit 27a8f27
Showing 5 changed files with 128 additions and 27 deletions.
8 changes: 4 additions & 4 deletions examples/trl/README.md
@@ -197,8 +197,8 @@ There are three main steps to the PPO training process:
2. Reward modeling using dialog pairs from the SE dataset on the llama-v2-7b-se to create llama-v2-7b-se-rm
```
python ../gaudi_spawn.py --world_size 8 --use_mpi reward_modeling.py \
--model_name=./sft/final_merged_checkpoint \
--tokenizer_name=meta-llama/Llama-2-7b-hf \
--model_name_or_path=./sft/final_merged_checkpoint \
--tokenizer_name_or_path=meta-llama/Llama-2-7b-hf \
--output_dir=./rm
```
To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL:
@@ -210,9 +210,9 @@ There are three main steps to the PPO training process:
3. RL fine-tuning of llama-v2-7b-se with the llama-v2-7b-se-rm reward model:
```
python ../gaudi_spawn.py --world_size 8 --use_mpi ppo.py \
--model_name=./sft/final_merged_checkpoint \
--model_name_or_path=./sft/final_merged_checkpoint \
--reward_model_name=./rm_merged_checkpoint \
--tokenizer_name=meta-llama/Llama-2-7b-hf \
--tokenizer_name_or_path=meta-llama/Llama-2-7b-hf \
--adafactor=False \
--output_max_length=128 \
--batch_size=8 \
27 changes: 18 additions & 9 deletions examples/trl/ppo.py
@@ -1,4 +1,6 @@
# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/rl_training.py, enable it for Gaudi2
import json
import time
from dataclasses import dataclass, field
from typing import List, Optional

@@ -26,8 +28,10 @@ class ScriptArguments:

# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
# models like gpt-neo* models are more suitable.
model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
tokenizer_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the tokenizer name"})
model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
tokenizer_name_or_path: Optional[str] = field(
default="meta-llama/Llama-2-7b-hf", metadata={"help": "the tokenizer name"}
)
reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
@@ -83,7 +87,7 @@ class ScriptArguments:
dataset_name = "lvwerra/stack-exchange-paired"
config = GaudiPPOConfig(
steps=script_args.steps,
model_name=script_args.model_name,
model_name=script_args.model_name_or_path,
learning_rate=script_args.learning_rate,
log_with=script_args.log_with,
batch_size=script_args.batch_size,
@@ -119,7 +123,7 @@ class ScriptArguments:
sent_kwargs["padding"] = "max_length"
sent_kwargs["max_length"] = script_args.input_max_length + script_args.output_max_length

tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name)
tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
# GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token.
# only for this model.

@@ -133,6 +137,7 @@ class ScriptArguments:
def build_dataset(
tokenizer,
dataset_name="lvwerra/stack-exchange-paired",
input_max_length=512,
):
"""
Build dataset for training. This builds the dataset from `load_dataset`, one should
@@ -168,14 +173,14 @@ def preprocess_function(examples):
num_proc=num_proc,
remove_columns=original_columns,
)
ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False)
ds = ds.filter(lambda x: len(x["input_ids"]) < input_max_length, batched=False)

ds.set_format(type="torch")
return ds


# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(tokenizer)
dataset = build_dataset(tokenizer, input_max_length=script_args.input_max_length)


def collator(data):
@@ -232,7 +237,6 @@ def collator(data):
data_collator=collator,
optimizer=optimizer,
)

# We then build the sentiment analysis pipeline using our reward model, passing the
# model name and the sentiment analysis pipeline arguments. Let's also make sure to
# set the device to the same device as the PPOTrainer.
@@ -283,12 +287,13 @@ def collator(data):
output_length_sampler = LengthSampler(output_min_length, output_max_length)
else:
output_length_sampler = LengthSampler(output_max_length, output_max_length + 1)
s0 = time.time()
sample = 0
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
if epoch >= config.total_ppo_epochs:
break

question_tensors = batch["input_ids"]

sample = sample + len(question_tensors)
response_tensors = ppo_trainer.generate(
question_tensors,
return_prompt=False,
@@ -308,5 +313,9 @@ def collator(data):

if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")
s1 = time.time()

ppo_trainer.save_pretrained(script_args.output_dir)
metrics = {"train_runtime": s1 - s0, "train_samples_per_second": sample / (s1 - s0)}
with open(f"{script_args.output_dir}/all_results.json", mode="w") as file:
json.dump(metrics, file)
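
The new `time`/`json` additions make `ppo.py` report its own throughput: `train_runtime` and `train_samples_per_second` are written to `all_results.json`, the file the example tests read back. Below is a minimal sketch of how such a report can be checked against a baseline; the `check_throughput` helper and the 10% tolerance are illustrative assumptions, not code from this commit.

```python
import json


def check_throughput(output_dir: str, baseline: dict, tolerance: float = 0.10) -> None:
    """Hypothetical regression check over the metrics ppo.py writes to all_results.json."""
    with open(f"{output_dir}/all_results.json") as fp:
        results = json.load(fp)

    # Runtime should not grow past the baseline by more than the (assumed) tolerance ...
    assert results["train_runtime"] <= baseline["train_runtime"] * (1 + tolerance)
    # ... and throughput should not drop below it by more than the same margin.
    assert results["train_samples_per_second"] >= baseline["train_samples_per_second"] * (1 - tolerance)


# Example with the multi-card trl-ppo reference values added in tests/baselines/llama_7b.json:
# check_throughput("./ppo_output", {"train_runtime": 62, "train_samples_per_second": 0.50})
```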
21 changes: 13 additions & 8 deletions examples/trl/reward_modeling.py
@@ -43,13 +43,13 @@ class ScriptArguments:
gradient_accumulation_steps: Optional[int] = field(default=1)
learning_rate: Optional[float] = field(default=2e-5)
weight_decay: Optional[float] = field(default=0.001)
model_name: Optional[str] = field(
model_name_or_path: Optional[str] = field(
default="meta-llama/Llama-2-7b-hf",
metadata={
"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
},
)
tokenizer_name: Optional[str] = field(
tokenizer_name_or_path: Optional[str] = field(
default="meta-llama/Llama-2-7b-hf",
metadata={
"help": "The tokenizer for your model, if left empty will use the default for your model",
@@ -156,10 +156,12 @@ class ScriptArguments:
)

# Load the value-head model and tokenizer.
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
tokenizer_name = (
script_args.tokenizer_name_or_path
if script_args.tokenizer_name_or_path is not None
else script_args.model_name_or_path
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=script_args.token)
tokenizer.pad_token = tokenizer.eos_token

peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
@@ -171,14 +173,14 @@ class ScriptArguments:
)
torch.autograd.set_detect_anomaly(True)
model = AutoModelForSequenceClassification.from_pretrained(
script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16
script_args.model_name_or_path, num_labels=1, torch_dtype=torch.bfloat16
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Need to do this for gpt2, because it doesn't have an official pad token.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
model.config.pad_token_id = tokenizer.eos_token_id
model.config.use_cache = not script_args.gradient_checkpointing
model.config.use_fused_rope = False
@@ -268,7 +270,10 @@ def on_step_end(self, args, state, control, **kwargs):

trainer.add_callback(EvaluateFirstStepCallback())

trainer.train(script_args.resume_from_checkpoint)
train_result = trainer.train(script_args.resume_from_checkpoint)
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

print("Saving last checkpoint of the model")
trainer.save_model(script_args.output_dir)
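
The field renames above matter because the script's CLI flags must match the `ScriptArguments` field names exactly, so `--model_name_or_path` and `--tokenizer_name_or_path` in the README now line up with the dataclass. A minimal sketch of that parsing pattern, using the usual `transformers.HfArgumentParser` flow with the two renamed fields copied from the diff (everything else trimmed for brevity):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    # Field names now mirror the --model_name_or_path / --tokenizer_name_or_path flags
    # used in examples/trl/README.md.
    model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf")
    tokenizer_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf")


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
print(script_args.model_name_or_path, script_args.tokenizer_name_or_path)
```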
54 changes: 54 additions & 0 deletions tests/baselines/llama_7b.json
@@ -231,6 +231,60 @@
}
}
},
"trl-reward": {
"num_train_epochs": 1,
"eval_batch_size": 1,
"distribution": {
"multi_card": {
"learning_rate": 5e-4,
"train_batch_size": 1,
"train_runtime": 250,
"train_samples_per_second": 1.6,
"extra_arguments": [
"--logging_steps 1",
"--lora_r 8",
"--lora_alpha 16",
"--lora_dropout 0.05",
"--lora_target_modules q_proj v_proj k_proj out_proj fc_in fc_out wte",
"--max_length 1024",
"--eval_steps 200",
"--lr_scheduler_type cosine",
"--weight_decay 0.05",
"--gradient_accumulation_steps 4",
"--train_subset 500",
"--eval_subset 100"
]
}
}
},
"trl-ppo": {
"num_train_epochs": 1,
"eval_batch_size": 1,
"distribution": {
"multi_card": {
"learning_rate": 5e-4,
"train_batch_size": 8,
"train_runtime": 62,
"train_samples_per_second": 0.50,
"extra_arguments": [
"--lora_r 8",
"--lora_alpha 16",
"--lora_dropout 0.05",
"--reward_model_name HuggingFaceH4/tiny-random-LlamaForSequenceClassification",
"--lora_target_modules q_proj v_proj k_proj out_proj fc_in fc_out wte",
"--max_train_samples 1000",
"--use_habana",
"--ppo_epochs 1",
"--batched_gen True",
"--mini_batch_size 1",
"--output_max_length 128",
"--input_max_length 128",
"--learning_rate 1.4e-5",
"--early_stopping"
]
}
}
},
"prompt-tuning": {
"num_train_epochs": 20,
"eval_batch_size": 1,
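
Each new baseline entry pairs the reference metrics (`train_runtime`, `train_samples_per_second`) with the `extra_arguments` the test harness appends to the command it builds in `_create_command_line` (see `tests/test_examples.py` below). A rough sketch of how one of these entries is consumed; the values are copied from the `trl-ppo` entry above, while the surrounding JSON nesting and the launcher prefix are simplified:

```python
# Subset of the "trl-ppo" multi-card baseline added to tests/baselines/llama_7b.json
# (the enclosing JSON structure is omitted here).
trl_ppo_baseline = {
    "train_batch_size": 8,
    "train_runtime": 62,
    "train_samples_per_second": 0.50,
    "extra_arguments": [
        "--reward_model_name HuggingFaceH4/tiny-random-LlamaForSequenceClassification",
        "--ppo_epochs 1",
        "--input_max_length 128",
        "--output_max_length 128",
    ],
}

# The harness appends each extra-argument string to the command line it has built,
# so multi-token flags stay intact until the final command is assembled.
cmd_line = ["python", "../gaudi_spawn.py", "--world_size 8", "--use_mpi", "ppo.py",
            f"--batch_size {trl_ppo_baseline['train_batch_size']}"]
cmd_line += trl_ppo_baseline["extra_arguments"]
print(" ".join(cmd_line))
```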
45 changes: 39 additions & 6 deletions tests/test_examples.py
File mode changed: 100755 → 100644
@@ -171,6 +171,16 @@ def is_valid_model_type(model_type: str) -> bool:
MODEL_FOR_CAUSAL_LM_MAPPING,
["llama"],
),
"reward_modeling": _get_supported_models_for_script(
MODELS_TO_TEST_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
["llama"],
),
"ppo": _get_supported_models_for_script(
MODELS_TO_TEST_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
["llama"],
),
"run_prompt_tuning_clm": _get_supported_models_for_script(
MODELS_TO_TEST_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -215,6 +225,8 @@ def to_test(
elif (
"sft" in example_name
or "dpo" in example_name
or "reward_modeling" in example_name
or "ppo" in example_name
or "prompt_tuning" in example_name
or example_name == "run_sequence_classification"
) and not IS_GAUDI2:
@@ -451,7 +463,6 @@ def test(self):

with open(Path(tmp_dir) / "all_results.json") as fp:
results = json.load(fp)

# Ensure performance requirements (accuracy, training time) are met
self.assert_no_regression(results, baseline.get("distribution").get(distribution), model_name)

@@ -521,7 +532,7 @@ def _create_command_line(
"--num_gpus 8",
"--no_local_rank",
]
if self.EXAMPLE_NAME == "dpo":
if self.EXAMPLE_NAME in ["dpo", "reward_modeling"]:
cmd_line += [
f"{script}",
f"--model_name_or_path {model_name}",
@@ -530,6 +541,14 @@
f"--per_device_train_batch_size {train_batch_size}",
f"--per_device_eval_batch_size {eval_batch_size}",
]
elif self.EXAMPLE_NAME == "ppo":
cmd_line += [
f"{script}",
f"--model_name_or_path {model_name}",
f"--tokenizer_name_or_path {model_name}",
f"--output_dir {output_dir}",
f"--batch_size {train_batch_size}",
]
else:
cmd_line += [
f"{script}",
@@ -550,10 +569,10 @@

if "compile" in task:
cmd_line += ["--use_lazy_mode False"]
elif self.EXAMPLE_NAME != "dpo":
elif self.EXAMPLE_NAME not in ["dpo", "ppo", "reward_modeling"]:
cmd_line += ["--use_lazy_mode"]

if "bloom" not in model_name and self.EXAMPLE_NAME != "dpo":
if "bloom" not in model_name and self.EXAMPLE_NAME not in ["dpo", "ppo", "reward_modeling"]:
cmd_line.append("--do_eval")

if extra_command_line_arguments is not None:
@@ -587,10 +606,12 @@ def assert_no_regression(self, results: Dict, baseline: Dict, model_name: str):
for metric_name in self.REGRESSION_METRICS.keys():
if metric_name in baseline and metric_name in results:
metrics_to_assess.append(metric_name)

# There is no accuracy metric for `run_clip.py`, `run_bridgetower.py` and BLOOM
min_number_metrics = 3
if self.EXAMPLE_NAME in ["run_clip", "run_bridgetower", "sft", "dpo"] or "bloom" in model_name:
if (
self.EXAMPLE_NAME in ["run_clip", "run_bridgetower", "sft", "dpo", "ppo", "reward_modeling"]
or "bloom" in model_name
):
min_number_metrics = 2

# Check that at least 3 metrics are assessed:
@@ -795,6 +816,18 @@ class MultiCardDPOExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, ex
DATASET_NAME = "lvwerra/stack-exchange-paired"


class MultiCardRewardExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="reward_modeling", multi_card=True
):
TASK_NAME = "trl-reward"
DATASET_NAME = "lvwerra/stack-exchange-paired"


class MultiCardPPOExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, example_name="ppo", multi_card=True):
TASK_NAME = "trl-ppo"
DATASET_NAME = "lvwerra/stack-exchange-paired"


class MultiCardProteinFoldingClassificationTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_sequence_classification", multi_card=True
):
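
With the two tester classes above registered, `_create_command_line` takes the new `ppo` branch. A sketch of just the ppo-specific arguments it adds (placeholder model name, output directory, and batch size; the launcher prefix is omitted):

```python
# Illustrative values; the test fills these in from the baseline and a temporary directory.
model_name = "meta-llama/Llama-2-7b-hf"
output_dir = "/tmp/trl_ppo_test"
train_batch_size = 8

cmd_line = [
    "ppo.py",
    f"--model_name_or_path {model_name}",
    f"--tokenizer_name_or_path {model_name}",
    f"--output_dir {output_dir}",
    f"--batch_size {train_batch_size}",
]

# Per the relaxed conditions above, neither "--use_lazy_mode" nor "--do_eval" is appended
# for ppo (or reward_modeling), unlike the default branch.
print(" ".join(cmd_line))
```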
