Commit: concat completions

Almaz Dautov committed Nov 28, 2024
1 parent ac888d4 commit 1892559
Showing 3 changed files with 76 additions and 17 deletions.
22 changes: 12 additions & 10 deletions turbo_alignment/generators/chat.py
@@ -31,8 +31,9 @@ def __init__(
if isinstance(transformers_settings.stop_strings, list):
raise ValueError('You should use only 1 eos token with VLLM')

eos_token_id: list[int] = self._tokenizer.encode(transformers_settings.stop_strings, add_special_tokens=False)

#TODO separate stop_strings and eos_token
self.eos_token_id: list[int] = self._tokenizer.encode(transformers_settings.stop_strings, add_special_tokens=False)
assert len(self.eos_token_id) == 1, 'Currently stop_strings == stop_token'
#TODO beam search was deprecated in vllm 0.6.2

# beam_search_params: dict[str, Any] = {
@@ -45,7 +46,7 @@ def __init__(
print(f'Generation Params:{transformers_settings.stop_strings=}\n{transformers_settings.num_return_sequences=}\n\
{transformers_settings.repetition_penalty=}\n{transformers_settings.temperature=}\n\
{transformers_settings.top_p=}\n{transformers_settings.top_k=}\n\
{custom_generation_settings.skip_special_tokens=}\n{eos_token_id=}\n{transformers_settings.max_new_tokens=}', flush=True)
{custom_generation_settings.skip_special_tokens=}\n{self.eos_token_id=}\n{transformers_settings.max_new_tokens=}', flush=True)

self._sampling_params = SamplingParams(
n=transformers_settings.num_return_sequences,
@@ -54,7 +55,7 @@ def __init__(
top_p=transformers_settings.top_p,
top_k=transformers_settings.top_k,
skip_special_tokens=custom_generation_settings.skip_special_tokens,
stop_token_ids=eos_token_id,
stop_token_ids=self.eos_token_id,
max_tokens=transformers_settings.max_new_tokens,
# **beam_search_params,
)
@@ -70,7 +71,7 @@ def generate_from_batch_records(
self,
dataset_name: str,
records: list[dict[str, Any]],
original_records: list[ChatDatasetRecord] | None = None,
original_records: bool = False,
) -> list[ChatInferenceOutput]:
import os
#TODO Make sure that records are already split between ranks (assuming micro_rollout_batch_size is equal to micro_batch_size)
@@ -85,7 +86,8 @@ def generate_from_batch_records(
answers = []
for a in request_output.outputs:
answer_token_ids = torch.tensor(a.token_ids).unsqueeze(0)

#TODO assuming len(eos_token_id) == 1
answer_token_ids[:, -1] = self.eos_token_id[0]
ans_msg = AnswerMessage(
id=str(a.index),
content=a.text,
@@ -96,14 +98,14 @@ def generate_from_batch_records(

answers.append(ans_msg)
if original_records:
original_record = original_records[i]
outputs.append(
ChatInferenceOutput(
input_token_ids=torch.tensor(request_output.prompt_token_ids).unsqueeze(0),
id=original_record.id,
input_attention_mask=records['attention_mask'][i, :].unsqueeze(0).cpu(),
id=None,
dataset_name=dataset_name,
messages=original_record.messages,
label=original_record.label,
messages=None,
label=None,
answers=answers,
)
)
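The per-answer EOS handling added above can be illustrated in isolation; the token ids below are made up for the example:

# Illustrative sketch: force the final generated token to the configured EOS id,
# mirroring `answer_token_ids[:, -1] = self.eos_token_id[0]` in the hunk above.
import torch

eos_token_id = [2]                      # assumed single-element EOS id list
token_ids = [15496, 995, 13]            # assumed generated token ids
answer_token_ids = torch.tensor(token_ids).unsqueeze(0)  # shape (1, seq_len)
answer_token_ids[:, -1] = eos_token_id[0]
print(answer_token_ids)                 # tensor([[15496,   995,     2]])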
47 changes: 46 additions & 1 deletion turbo_alignment/pipelines/train/reinforce.py
@@ -37,6 +37,51 @@

logger = get_project_logger()

class ReinforceDataCollator(DataCollatorForTokenClassification):
def torch_call(self, features):
import torch
from transformers.data.data_collator import pad_without_fast_tokenizer_warning

for _ in features:
print(f'{_.keys()=}')

label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

no_labels_features = [{k: v for k, v in feature.items() if k != label_name and isinstance(v, torch.Tensor)} for feature in features]

batch = pad_without_fast_tokenizer_warning(
self.tokenizer,
no_labels_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)

if labels is None:
return batch

sequence_length = batch["input_ids"].shape[1]
padding_side = self.tokenizer.padding_side

def to_list(tensor_or_iterable):
if isinstance(tensor_or_iterable, torch.Tensor):
return tensor_or_iterable.tolist()
return list(tensor_or_iterable)

if padding_side == "right":
batch[label_name] = [
to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
]
else:
batch[label_name] = [
[self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
]

batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
return batch

@ray.remote(num_gpus=1)
class TrainREINFORCEStrategy(BaseTrainStrategy[REINFORCETrainExperimentSettings], DistributedTorchRayActor):
def __init__(self, world_size, rank, local_rank, master_addr, master_port):
@@ -68,7 +113,7 @@ def _get_data_collator(
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> Callable:
return DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
return ReinforceDataCollator(tokenizer, pad_to_multiple_of=8)

@staticmethod
def _get_cherry_pick_callback(
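A rough usage sketch of the new collator; the tokenizer, the feature values, and the import path are assumptions for illustration, not code from the repository:

# Illustrative sketch: ReinforceDataCollator pads tensor-valued features via the
# tokenizer and right-pads labels manually, dropping non-tensor keys.
import torch
from transformers import AutoTokenizer
# Assumed import path, based on where the class is defined in this commit.
from turbo_alignment.pipelines.train.reinforce import ReinforceDataCollator

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # assumed tokenizer
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = ReinforceDataCollator(tokenizer, pad_to_multiple_of=8)
features = [
    {'input_ids': torch.tensor([1, 2, 3]), 'attention_mask': torch.tensor([1, 1, 1]),
     'labels': [1, 2, 3], 'meta': 'non-tensor keys are dropped before padding'},
    {'input_ids': torch.tensor([4, 5]), 'attention_mask': torch.tensor([1, 1]),
     'labels': [4, 5]},
]
batch = collator.torch_call(features)
# input_ids/attention_mask are padded by the tokenizer; labels are right-padded
# with label_pad_token_id (-100) to the same sequence length.
print(batch['input_ids'].shape, batch['labels'].shape)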
24 changes: 18 additions & 6 deletions turbo_alignment/trainers/online/reinforce.py
@@ -174,7 +174,9 @@ def __init__(

self.stop_generation_token_id = tokenizer.encode(args.stop_token, add_special_tokens=False)
assert len(self.stop_generation_token_id) == 1, self.stop_generation_token_id


#TODO separate stop_strings and eos_token

self.generator_transformers_settings = GeneratorTransformersSettings(
temperature=args.temperature,
max_new_tokens=args.max_new_tokens,
@@ -291,16 +293,26 @@ def get_answers_and_rewards(
generator = self._get_chat_generator()

generations: list[ChatInferenceOutput] = generator.generate_from_batch_records(
dataset_name='online', records=inputs
dataset_name='online', records=inputs, original_records=True
)
logging.info(f'generations elapsed time:{time.time() - start}')

response_ids = [ans.answer_token_ids for g in generations for ans in g.answers]
response_attention_mask = [ans.answer_attention_mask for g in generations for ans in g.answers]
for g in generations:
for ans in g.answers:
print(f'{g.input_token_ids.device=}, {ans.answer_token_ids.device=}', flush=True)
print(f'{g.input_attention_mask.device=}, {ans.answer_attention_mask.device=}', flush=True)
break
break
response_ids = [torch.cat([g.input_token_ids, ans.answer_token_ids], dim=1) for g in generations for ans in g.answers]
response_attention_mask = [torch.cat([g.input_attention_mask, ans.answer_attention_mask], dim=1) for g in generations for ans in g.answers]

import random
ind = random.randint(0, len(response_ids) - 1)
assert response_ids[ind].shape == response_attention_mask[ind].shape

if torch.distributed.get_rank() == 0:
print(f'Generated completion at index [0] shape: {response_ids[0].shape}', flush=True)
print(f'Completion decoded: {self.tokenizer.batch_decode(response_ids[0])}', flush=True)
print(f'Prompt with completion at index [0] shape: {response_ids[0].shape}', flush=True)
print(f'Prompt with completion decoded: {self.tokenizer.batch_decode(response_ids[0])}', flush=True)

# Padding
max_length = max([response_id.size(1) for response_id in response_ids])
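The core of this commit, concatenating each prompt with its sampled completion before padding, can be sketched in isolation; all tensors below are made up, and the final padding step is an assumption since the diff is truncated after "# Padding":

# Illustrative sketch: one prompt and one answer, each shaped (1, seq_len).
import torch
import torch.nn.functional as F

input_token_ids = torch.tensor([[11, 12, 13]])
input_attention_mask = torch.ones_like(input_token_ids)
answer_token_ids = torch.tensor([[21, 22]])
answer_attention_mask = torch.ones_like(answer_token_ids)

# Prompt + completion, as in get_answers_and_rewards above.
response_ids = [torch.cat([input_token_ids, answer_token_ids], dim=1)]
response_attention_mask = [torch.cat([input_attention_mask, answer_attention_mask], dim=1)]

# Pad every sequence to the longest one in the batch (padding side and value assumed).
max_length = max(r.size(1) for r in response_ids)
padded_ids = [F.pad(r, (0, max_length - r.size(1)), value=0) for r in response_ids]
print(padded_ids[0])  # tensor([[11, 12, 13, 21, 22]])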
