Draft: 🚀 Add gemma2 flash-attn #53

Draft · wants to merge 19 commits into main
255 changes: 134 additions & 121 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ keywords = ["alignment", "llm", "dpo", "ppo", "rlhf"]

 [tool.poetry.dependencies]
 python = "^3.10"
-transformers = "4.43.1"
+transformers = ">=4.44"
 peft = "0.8.2"
 typer ="^0.9.0"
 wandb ="^0.15.3"
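Note: Gemma-2 support (including the FlashAttention-2 path with attention logit softcapping) was introduced in transformers 4.42 and stabilized over the following releases, so moving from the exact 4.43.1 pin to a >=4.44 floor is presumably what unlocks the Gemma-2 flash-attn code path here.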
11 changes: 11 additions & 0 deletions turbo_alignment/cli/inference.py
@@ -10,6 +10,7 @@
 from turbo_alignment.settings.pipelines.inference.chat import (
     ChatInferenceExperimentSettings,
 )
+from turbo_alignment.settings.pipelines.inference.metrics import MetricsSettings
 from turbo_alignment.settings.pipelines.inference.multimodal import (
     MultimodalInferenceExperimentSettings,
 )
@@ -66,3 +67,13 @@ def infer_rag_entrypoint(
 ) -> None:
     inference_settings = RAGInferenceExperimentSettings.parse_file(inference_settings_path)
     pipelines.RAGInferenceStrategy().run(inference_settings)
+
+
+@app.command(name='run_metrics', help='Run metrics evaluation')
+def infer_metrics(
+    inference_settings_path: Path = typer.Option(
+        ..., '--inference_settings_path', exists=True, help='Path to inference config file'
+    )
+):
+    inference_settings = MetricsSettings.parse_file(inference_settings_path)
+    pipelines.MetricsInferenceStrategy().run(inference_settings)
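Once registered, the command follows the same pattern as the other entrypoints; an invocation would look roughly like: python -m turbo_alignment run_metrics --inference_settings_path configs/metrics.json (module name and config path are illustrative, inferred from the surrounding CLI conventions).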
14 changes: 14 additions & 0 deletions turbo_alignment/cli/train.py
@@ -33,6 +33,20 @@ def train_dpo_entrypoint(
     pipelines.TrainDPOStrategy().run(experiment_settings)


+@app.command(name='train_ref_model', help='Run DPO reference-model pipeline')
+def train_ref_entrypoint(
+    experiment_settings_path: Path = typer.Option(
+        ...,
+        '--experiment_settings_path',
+        exists=True,
+        help='Path to experiment config file',
+    )
+) -> None:
+    print('ok')
+    experiment_settings = pipeline_settings.DPOTrainExperimentSettings.parse_file(experiment_settings_path)
+    pipelines.ref_model_DPOStrategy().run(experiment_settings)
+
+
 @app.command(name='train_kto', help='Run KTO pipeline')
 def train_kto_entrypoint(
     experiment_settings_path: Path = typer.Option(
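Presumably ref_model_DPOStrategy runs the frozen reference model over a pair-preference dataset and stores per-pair log-probabilities; the ref_chosen_logps / ref_rejected_logps plumbing added below consumes exactly those values.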
4 changes: 2 additions & 2 deletions turbo_alignment/common/data/io.py
@@ -18,8 +18,8 @@ def read_jsonl(path: Path):
         return [json.loads(line) for line in f]


-def write_jsonl(records: Any, path: Path) -> None:
-    with path.open('w', encoding='utf-8') as f:
+def write_jsonl(records: Any, path: Path, mode='w') -> None:
+    with path.open(mode, encoding='utf-8') as f:
         for r in records:
             f.write(json.dumps(r, ensure_ascii=False) + '\n')

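A minimal usage sketch for the new mode parameter (paths hypothetical): the default still creates or truncates the file, while mode='a' lets a long-running pipeline flush records incrementally.

from pathlib import Path

from turbo_alignment.common.data.io import write_jsonl

write_jsonl([{'id': 0, 'reward': 1.5}], Path('out/metrics.jsonl'))            # mode='w': create/overwrite
write_jsonl([{'id': 1, 'reward': 0.7}], Path('out/metrics.jsonl'), mode='a')  # append to the same file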
11 changes: 6 additions & 5 deletions turbo_alignment/common/tf/special_tokens_setter.py
@@ -13,12 +13,12 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, special_tokens_settings:
         self._special_tokens_settings = special_tokens_settings
         self._special_tokens_already_set: bool = False

-    def setBOS(self, bos_token: str) -> None:
+    def setBOS(self, bos_token: str | None) -> None:
         if self._tokenizer.bos_token_id is None:
             logger.info('Model does not have bos_token_id')
-            self._tokenizer.add_special_tokens(special_tokens_dict={'bos_token': bos_token})
-            assert self._tokenizer.bos_token_id is not None
-            logger.info(f'Created bos_token_id = {self._tokenizer.bos_token_id}')
+            if bos_token is not None:
+                self._tokenizer.add_special_tokens(special_tokens_dict={'bos_token': bos_token})
+                logger.info(f'Created bos_token_id = {self._tokenizer.bos_token_id}')
         else:
             logger.info(f'Model has bos_token_id = {self._tokenizer.bos_token_id}')

@@ -94,9 +94,10 @@ def set_custom_tokens(self, tokens: list[str]) -> None:
         assert added_tokens == len(tokens)

     def setup_model_config(self, model: PreTrainedModel) -> None:
-        model.config.bos_token_id = self._tokenizer.bos_token_id
         model.config.eos_token_id = self._tokenizer.eos_token_id

+        if self._tokenizer.bos_token_id is not None:
+            model.config.bos_token_id = self._tokenizer.bos_token_id
         if self._tokenizer.pad_token_id is not None:
             model.config.pad_token_id = self._tokenizer.pad_token_id
         if self._tokenizer.sep_token_id is not None:
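Taken together, these two changes make BOS fully optional: setBOS(None) now degrades to a log message instead of failing an assert, and setup_model_config only writes bos_token_id into the model config when the tokenizer actually defines one.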
6 changes: 4 additions & 2 deletions turbo_alignment/dataset/chat/chat.py
@@ -247,8 +247,8 @@ def _truncate_and_merge(
             input_ids = input_ids[cut_index:]

         labels = labels[-len(input_ids) :]
-        input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids))
-        labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels))
+        # input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids))
+        # labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels))

         return input_ids, labels, conversation.get_prompt_repr(left_bound, right_bound)

@@ -299,6 +299,8 @@ def _encode(
                     inference=inference,
                     random_cut=random_cut,
                 )
+                if len(input_ids) >= 8000:
+                    raise ValueError(f'{len(input_ids)=}, which is >=8000')

             except ValueError as ex:
                 output.append(None)
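Two behavioral changes worth noting: the dataset no longer force-prepends a BOS token after truncation (presumably because the Gemma-2 tokenizer already emits one, so the old concatenation would duplicate it), and any sample reaching 8000 tokens now raises a ValueError that the surrounding except branch converts into a skipped None record rather than a crash.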
6 changes: 6 additions & 0 deletions turbo_alignment/dataset/pair_preferences/collators.py
@@ -52,4 +52,10 @@ def __call__(self, examples: list[dict[str, dict[str, Any]]]) -> dict[str, Any]:
         if 'precomputed_margin' in examples[0] and examples[0]['precomputed_margin'] is not None:
             batch['precomputed_margin'] = torch.tensor([ex['precomputed_margin'] for ex in examples])

+        if 'ref_chosen_logps' in examples[0] and examples[0]['ref_chosen_logps'] is not None:
+            batch['ref_chosen_logps'] = torch.tensor([ex['ref_chosen_logps'] for ex in examples])
+
+        if 'ref_rejected_logps' in examples[0] and examples[0]['ref_rejected_logps'] is not None:
+            batch['ref_rejected_logps'] = torch.tensor([ex['ref_rejected_logps'] for ex in examples])
+
         return batch
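For context, a minimal sketch of where these tensors typically land: with reference log-probs precomputed, the DPO loss needs no reference-model forward pass at train time. Names and signature are illustrative, not this repo's trainer API.

import torch
import torch.nn.functional as F

def dpo_loss_with_precomputed_refs(
    policy_chosen_logps: torch.Tensor,    # log p_theta(answer_w | context), shape (batch,)
    policy_rejected_logps: torch.Tensor,  # log p_theta(answer_l | context), shape (batch,)
    ref_chosen_logps: torch.Tensor,       # batch['ref_chosen_logps'] from the collator above
    ref_rejected_logps: torch.Tensor,     # batch['ref_rejected_logps'] from the collator above
    beta: float = 0.1,
) -> torch.Tensor:
    # Standard DPO objective: -log sigmoid(beta * (policy log-ratio - reference log-ratio))
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()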
2 changes: 2 additions & 0 deletions turbo_alignment/dataset/pair_preferences/models.py
@@ -12,3 +12,5 @@ class PairPreferenceRecord(DatasetRecord):
     answer_w: ChatMessage
     answer_l: ChatMessage
     precomputed_margin: float | None = None
+    ref_chosen_logps: float | None = None
+    ref_rejected_logps: float | None = None
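An input record carrying the new optional fields might then look like this (values illustrative, remaining fields elided):

{"id": "0", "answer_w": {...}, "answer_l": {...}, "ref_chosen_logps": -42.7, "ref_rejected_logps": -55.1, ...}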
4 changes: 3 additions & 1 deletion turbo_alignment/dataset/pair_preferences/pair_preference.py
@@ -72,7 +72,7 @@ def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str,
             if not (chosen_record and rejected_record):
                 continue

-            ignore_keys = ['precomputed_margin']
+            ignore_keys = ['precomputed_margin', 'ref_chosen_logps', 'ref_rejected_logps']
             if not self._add_labels:
                 ignore_keys.append('labels')

@@ -85,6 +85,8 @@ def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str,
                     'inputs_w': chosen_tokens,
                     'inputs_l': rejected_tokens,
                     'precomputed_margin': record.precomputed_margin,
+                    'ref_chosen_logps': record.ref_chosen_logps,
+                    'ref_rejected_logps': record.ref_rejected_logps,
                 }
             )
18 changes: 6 additions & 12 deletions turbo_alignment/generators/rm.py
@@ -59,21 +59,15 @@ def _generate_from_batch(
         if len(merged_inputs) == 0:
             return []

+        assert self._micro_batch == 1, 'batch size should be 1'
+
         rewards = []
         with torch.no_grad():
-            input_ids = nn.utils.rnn.pad_sequence(
-                [item['input_ids'] for item in merged_inputs],
-                padding_value=self._tokenizer.pad_token_id,
-                batch_first=True,
-            )
-            attn_mask = nn.utils.rnn.pad_sequence(
-                [item['attention_mask'] for item in merged_inputs],
-                padding_value=0,
-                batch_first=True,
-            )
+            input_ids = [item['input_ids'].unsqueeze(0) for item in merged_inputs]
+            attn_mask = [item['attention_mask'].unsqueeze(0) for item in merged_inputs]
             for i in range(0, len(input_ids), self._micro_batch):
-                input_ids_batch = input_ids[i : i + self._micro_batch].to(self.device)
-                attn_mask_batch = attn_mask[i : i + self._micro_batch].to(self.device)
+                input_ids_batch = input_ids[i].to(self.device)
+                attn_mask_batch = attn_mask[i].to(self.device)
                 rewards.extend(self._model(input_ids=input_ids_batch, attention_mask=attn_mask_batch).logits.cpu())

             rewards = torch.cat(rewards, dim=0)
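Design note: replacing the padded batch with per-sequence scoring means the reward model never sees pad tokens at all, so rewards cannot drift with batch composition (padding sensitivity has been reported for Gemma-2 under FlashAttention-2, which presumably motivates this). The new assert pins _micro_batch to 1 so the per-sample indexing stays valid.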
1 change: 1 addition & 0 deletions turbo_alignment/pipelines/inference/__init__.py
@@ -1,5 +1,6 @@
 from .chat import ChatInferenceStrategy
 from .classification import ClassificationInferenceStrategy
+from .metrics import MetricsInferenceStrategy
 from .multimodal import MultimodalInferenceStrategy
 from .rag import RAGInferenceStrategy
 from .rm import RMInferenceStrategy
2 changes: 2 additions & 0 deletions turbo_alignment/pipelines/inference/chat.py
@@ -44,6 +44,8 @@ def _get_single_inference_settings(
             dtype='bfloat16',
             tensor_parallel_size=model_inference_settings.tensor_parallel_size,
             enable_lora=enable_lora,
+            gpu_memory_utilization=0.9,
+            disable_custom_all_reduce=True,
         )

     else:
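Both added arguments are standard vLLM engine options: gpu_memory_utilization bounds the fraction of GPU memory vLLM pre-allocates for weights and KV cache (0.9 leaves a little headroom for other processes), and disable_custom_all_reduce=True falls back to NCCL for tensor-parallel all-reduce, a common workaround for hangs with the custom kernel on some multi-GPU topologies.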
176 changes: 176 additions & 0 deletions turbo_alignment/pipelines/inference/metrics.py
@@ -0,0 +1,176 @@
import gc
import json
import math
import os
from typing import Dict, List

import loguru
import torch
import vllm
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from turbo_alignment.common.tf.loaders import load_model, load_tokenizer
from turbo_alignment.dataset.base import BaseDataset
from turbo_alignment.dataset.chat import ChatDatasetRecord
from turbo_alignment.dataset.loader import DatasetLoader
from turbo_alignment.generators.vllm_chat import VLLMChatGenerator
from turbo_alignment.metrics import Metric, MetricSettingsRegistry
from turbo_alignment.metrics.registry import KLType
from turbo_alignment.metrics.utils import get_logits
from turbo_alignment.pipelines.base import BaseStrategy
from turbo_alignment.settings.datasets import DatasetStrategy
from turbo_alignment.settings.pipelines.inference.metrics import MetricsSettings


class MetricsInferenceStrategy(BaseStrategy):
    def run(self, experiment_settings: MetricsSettings) -> None:
        accelerator = Accelerator()
        set_seed(seed=0, device_specific=False)
        experiment_settings.save_path.mkdir(parents=True, exist_ok=True)

        metrics = [
            Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
            for metric in experiment_settings.metric_settings
        ]

        model_inference_settings = experiment_settings.inference_settings
        tokenizer = load_tokenizer(
            model_inference_settings.tokenizer_settings,
            model_inference_settings.model_settings,
        )

        model = load_model(model_inference_settings.model_settings, tokenizer)
        model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)

        sft = load_model(model_inference_settings.model_settings, tokenizer)  # fix
        sft = accelerator.prepare_model(sft, device_placement=True, evaluation_mode=True)

        vllm_model = vllm.LLM(
            model=model_inference_settings.model_settings.model_path.absolute().as_posix(),
            dtype='bfloat16',
            gpu_memory_utilization=0.3,
            tensor_parallel_size=model_inference_settings.tensor_parallel_size,
            # max_model_len=model_inference_settings.max_model_len,
        )

        dataset = DatasetLoader[BaseDataset](BaseDataset).load_datasets(
            experiment_settings.dataset_settings,
            tokenizer=tokenizer,
            strategy=DatasetStrategy.INFERENCE,
        )[0]

        generation_settings = experiment_settings.inference_settings.generation_settings[0]

        batch_size = model_inference_settings.batch
        generator = VLLMChatGenerator(
            model=vllm_model,
            tokenizer=tokenizer,
            transformers_settings=generation_settings.transformers_settings,
            custom_generation_settings=generation_settings.custom_settings,
            batch=batch_size,
            return_logits=True,
        )

        input_records = [dataset[idx] for idx in range(len(dataset))]
        records_batches = [
            input_records[i * batch_size : (i + 1) * batch_size] for i in range(math.ceil(len(dataset) / batch_size))
        ]

        original_records_batches: list[list[ChatDatasetRecord]] = [
            [dataset.get_original_record_by_id(r['id']) for r in batch] for batch in records_batches
        ]

        num_gens = generation_settings.transformers_settings.num_return_sequences

        metrics_kwargs = {}

        with open(os.path.join(experiment_settings.save_path, 'metrics.jsonl'), 'w') as f:
            for i, (records_batch, original_records_batch) in enumerate(
                zip(records_batches, original_records_batches)
            ):
                loguru.logger.info('batch {}/{}', i + 1, len(records_batches))

                generation_outputs = generator._generate_from_batch(
                    records_batch,
                    original_records_batch,
                    dataset.source.name,
                )

                string_answers = [[answer.content for answer in g.answers] for g in generation_outputs]
                string_labels = [[g.messages[-1].content] * len(g.answers) for g in generation_outputs]

                flattened_answers = [answer for g in generation_outputs for answer in g.answers]
                answer_tokens_ids = [answer.answer_token_ids.cpu() for answer in flattened_answers]
                input_tokens_ids = [answer.input_token_ids.cpu() for answer in flattened_answers]

                logits = get_logits(input_tokens_ids, answer_tokens_ids, model)
                sft_logits = get_logits(input_tokens_ids, answer_tokens_ids, sft)
                metrics_kwargs[KLType.SFT_MODEL] = sft_logits

                batch_metrics = [{} for _ in range(len(records_batch))]  # size to the actual batch; the last one may be shorter than batch_size

                for i in range(len(batch_metrics)):
                    batch_metrics[i]['context'] = []
                    batch_metrics[i]['label'] = []
                    for idx in range(len(original_records_batch[i].messages[:-1])):
                        batch_metrics[i]['context'].append(
                            {
                                'content': original_records_batch[i].messages[idx].content,
                                'role': original_records_batch[i].messages[idx].role,
                            }
                        )

                    batch_metrics[i]['label'] = [
                        {
                            'content': original_records_batch[i].messages[-1].content,
                            'role': original_records_batch[i].messages[-1].role,
                        }
                    ]

                    batch_metrics[i]['answers'] = [
                        {'id': idx, 'content': string_answers[i][idx]} for idx in range(len(string_answers[i]))
                    ]

                for metric in metrics:
                    metric_results = metric.compute(
                        dataset=dataset,
                        references=string_labels,
                        predictions=string_answers,
                        accelerator=accelerator,
                        tokenizer=tokenizer,
                        dataset_name=dataset.source.name,
                        answer_tokens_ids=answer_tokens_ids,
                        input_token_ids=input_tokens_ids,
                        logits=logits,
                        metrics_kwargs=metrics_kwargs,
                    )[0].element_wise_scores
                    for scores in metric_results:
                        metric_name = scores.label.split('@@', 1)[-1]
                        metric_values = scores.values
                        if metric_name == 'reward':
                            metric_values = [m[0] for m in metric_values]
                        if metric_name in ['reward', 'kl_with_sft', 'length', 'perplexity']:
                            metric_values = [
                                metric_values[i * num_gens : (i + 1) * num_gens]
                                for i in range(len(metric_values) // num_gens)
                            ]

                        for i in range(len(metric_values)):
                            batch_metrics[i][metric_name] = metric_values[i]

                for item in batch_metrics:
                    if len(item) == 0:
                        continue
                    json.dump(item, f)
                    f.write('\n')

                del batch_metrics
                del generation_outputs
                del flattened_answers
                del string_answers
                del string_labels
                del logits
                del sft_logits
                gc.collect()
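A small consumer sketch for the metrics.jsonl written above: the fixed keys (context, label, answers) always come from batch_metrics, while per-metric keys such as reward or kl_with_sft depend on the configured metrics, so they are read defensively here (path illustrative).

import json

with open('save_path/metrics.jsonl', encoding='utf-8') as f:
    for line in f:
        rec = json.loads(line)
        n_answers = len(rec['answers'])
        rewards = rec.get('reward')  # one list of num_return_sequences values per prompt, when computed
        print(n_answers, rewards)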
1 change: 1 addition & 0 deletions turbo_alignment/pipelines/train/__init__.py
@@ -4,5 +4,6 @@
 from .kto import TrainKTOStrategy
 from .multimodal import TrainMultimodalStrategy
 from .rag import TrainRAGStrategy
+from .ref_model import ref_model_DPOStrategy
 from .rm import TrainRMStrategy
 from .sft import TrainSFTStrategy