diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py
index 0155867..440cbc0 100755
--- a/turbo_alignment/generators/chat.py
+++ b/turbo_alignment/generators/chat.py
@@ -31,8 +31,9 @@ def __init__(
         if isinstance(transformers_settings.stop_strings, list):
            raise ValueError('You should use only 1 eos token with VLLM')
 
-        eos_token_id: list[int] = self._tokenizer.encode(transformers_settings.stop_strings, add_special_tokens=False)
-
+        #TODO separate stop_strings and eos_token
+        self.eos_token_id: list[int] = self._tokenizer.encode(transformers_settings.stop_strings, add_special_tokens=False)
+        assert len(self.eos_token_id) == 1, 'Currently stop_strings == stop_token'
 
         #TODO beam search was deprecated in vllm 0.6.2
         # beam_search_params: dict[str, Any] = {
@@ -45,7 +46,7 @@ def __init__(
         print(f'Generation Params:{transformers_settings.stop_strings=}\n{transformers_settings.num_return_sequences=}\n\
            {transformers_settings.repetition_penalty=}\n{transformers_settings.temperature=}\n\
            {transformers_settings.top_p=}\n{transformers_settings.top_k=}\n\
-            {custom_generation_settings.skip_special_tokens=}\n{eos_token_id=}\n{transformers_settings.max_new_tokens=}', flush=True)
+            {custom_generation_settings.skip_special_tokens=}\n{self.eos_token_id=}\n{transformers_settings.max_new_tokens=}', flush=True)
 
         self._sampling_params = SamplingParams(
             n=transformers_settings.num_return_sequences,
@@ -54,7 +55,7 @@ def __init__(
             top_p=transformers_settings.top_p,
             top_k=transformers_settings.top_k,
             skip_special_tokens=custom_generation_settings.skip_special_tokens,
-            stop_token_ids=eos_token_id,
+            stop_token_ids=self.eos_token_id,
             max_tokens=transformers_settings.max_new_tokens,
             # **beam_search_params,
         )
@@ -70,7 +71,7 @@ def generate_from_batch_records(
         self,
         dataset_name: str,
         records: list[dict[str, Any]],
-        original_records: list[ChatDatasetRecord] | None = None,
+        original_records: bool = False,
     ) -> list[ChatInferenceOutput]:
         import os
         #TODO Make sure that records are already splitted between ranks(Assuming micro_rollout_batch_size equal to micro_batch_size)
@@ -85,7 +86,8 @@ def generate_from_batch_records(
             answers = []
             for a in request_output.outputs:
                 answer_token_ids=torch.tensor(a.token_ids).unsqueeze(0)
-
+                #TODO assuming len(eos_token_id) == 1
+                answer_token_ids[:, -1] = self.eos_token_id[0]
                 ans_msg = AnswerMessage(
                     id=str(a.index),
                     content=a.text,
@@ -96,14 +98,14 @@ def generate_from_batch_records(
                 answers.append(ans_msg)
 
             if original_records:
-                original_record = original_records[i]
                 outputs.append(
                     ChatInferenceOutput(
                         input_token_ids=torch.tensor(request_output.prompt_token_ids).unsqueeze(0),
-                        id=original_record.id,
+                        input_attention_mask=records['attention_mask'][i, :].unsqueeze(0).cpu(),
+                        id=None,
                         dataset_name=dataset_name,
-                        messages=original_record.messages,
-                        label=original_record.label,
+                        messages=None,
+                        label=None,
                         answers=answers,
                     )
                 )
diff --git a/turbo_alignment/pipelines/train/reinforce.py b/turbo_alignment/pipelines/train/reinforce.py
index 420100d..60167b3 100644
--- a/turbo_alignment/pipelines/train/reinforce.py
+++ b/turbo_alignment/pipelines/train/reinforce.py
@@ -37,6 +37,51 @@
 logger = get_project_logger()
 
 
+class ReinforceDataCollator(DataCollatorForTokenClassification):
+    def torch_call(self, features):
+        import torch
+        from transformers.data.data_collator import pad_without_fast_tokenizer_warning
+
+        for _ in features:
+            print(f'{_.keys()=}')
+
+        label_name = "label" if "label" in features[0].keys() else "labels"
+        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
+
+        no_labels_features = [{k: v for k, v in feature.items() if k != label_name and isinstance(v, torch.Tensor)} for feature in features]
+
+        batch = pad_without_fast_tokenizer_warning(
+            self.tokenizer,
+            no_labels_features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors="pt",
+        )
+
+        if labels is None:
+            return batch
+
+        sequence_length = batch["input_ids"].shape[1]
+        padding_side = self.tokenizer.padding_side
+
+        def to_list(tensor_or_iterable):
+            if isinstance(tensor_or_iterable, torch.Tensor):
+                return tensor_or_iterable.tolist()
+            return list(tensor_or_iterable)
+
+        if padding_side == "right":
+            batch[label_name] = [
+                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
+            ]
+        else:
+            batch[label_name] = [
+                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
+            ]
+
+        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
+        return batch
+
+
 @ray.remote(num_gpus=1)
 class TrainREINFORCEStrategy(BaseTrainStrategy[REINFORCETrainExperimentSettings], DistributedTorchRayActor):
     def __init__(self, world_size, rank, local_rank, master_addr, master_port):
@@ -68,7 +113,7 @@ def _get_data_collator(
         tokenizer: PreTrainedTokenizerBase,
         **kwargs,
     ) -> Callable:
-        return DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
+        return ReinforceDataCollator(tokenizer, pad_to_multiple_of=8)
 
     @staticmethod
     def _get_cherry_pick_callback(
diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py
index 34f885a..da31ba2 100644
--- a/turbo_alignment/trainers/online/reinforce.py
+++ b/turbo_alignment/trainers/online/reinforce.py
@@ -174,7 +174,9 @@ def __init__(
         self.stop_generation_token_id = tokenizer.encode(args.stop_token, add_special_tokens=False)
         assert len(self.stop_generation_token_id) == 1, self.stop_generation_token_id
-
+
+        #TODO separate stop_strings and eos_token
+
         self.generator_transformers_settings = GeneratorTransformersSettings(
             temperature=args.temperature,
             max_new_tokens=args.max_new_tokens,
@@ -291,16 +293,26 @@ def get_answers_and_rewards(
         generator = self._get_chat_generator()
 
         generations: list[ChatInferenceOutput] = generator.generate_from_batch_records(
-            dataset_name='online', records=inputs
+            dataset_name='online', records=inputs, original_records=True
         )
         logging.info(f'generations elapsed time:{time.time() - start}')
 
-        response_ids = [ans.answer_token_ids for g in generations for ans in g.answers]
-        response_attention_mask = [ans.answer_attention_mask for g in generations for ans in g.answers]
+        for g in generations:
+            for ans in g.answers:
+                print(f'{g.input_token_ids.device=}, {ans.answer_token_ids.device=}', flush=True)
+                print(f'{g.input_attention_mask.device=}, {ans.answer_attention_mask.device=}', flush=True)
+                break
+            break
+        response_ids = [torch.cat([g.input_token_ids, ans.answer_token_ids], dim=1) for g in generations for ans in g.answers]
+        response_attention_mask = [torch.cat([g.input_attention_mask, ans.answer_attention_mask], dim=1) for g in generations for ans in g.answers]
+
+        import random
+        ind = random.randint(0, len(response_ids) - 1)
+        assert response_ids[ind].shape == response_attention_mask[ind].shape
 
         if torch.distributed.get_rank() == 0:
-            print(f'Generated completion at index [0] shape: {response_ids[0].shape}', flush=True)
-            print(f'Completion decoded: {self.tokenizer.batch_decode(response_ids[0])}', flush=True)
+            print(f'Prompt with completion at index [0] shape: {response_ids[0].shape}', flush=True)
+            print(f'Prompt with completion decoded: {self.tokenizer.batch_decode(response_ids[0])}', flush=True)
 
         # Padding
         max_length = max([response_id.size(1) for response_id in response_ids])