
Commit

fixes 9
Almaz Dautov committed Nov 26, 2024
1 parent 2baa486 commit 7ce5374
Showing 9 changed files with 65 additions and 48 deletions.
2 changes: 1 addition & 1 deletion turbo_alignment/cli/train.py
@@ -153,7 +153,7 @@ def reinforce_training(
         seed=0,
         enable_prefix_caching=False,
         enforce_eager=False,
-        max_model_len=1024,
+        max_model_len=experiment_settings.trainer_settings.actor_settings.max_model_len,
     )

     ray.get(policy_models.async_fit_actor_model(
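
For context, max_model_len caps the total sequence length (prompt plus completion tokens) a vLLM engine will accept, so reading it from the experiment settings instead of hard-coding 1024 lets long-context runs actually use their context window. A minimal sketch of the call this value feeds into, assuming a vLLM engine is built from these settings; the checkpoint name is a placeholder:

    # Sketch only: the constructor kwargs mirror the diff; the model name is illustrative.
    from vllm import LLM

    engine = LLM(
        model='Qwen/Qwen2-7B-Instruct',   # placeholder checkpoint
        seed=0,
        enable_prefix_caching=False,
        enforce_eager=False,
        max_model_len=8192,               # was hard-coded to 1024 before this fix
    )
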
41 changes: 30 additions & 11 deletions turbo_alignment/common/tf/loaders/model/model.py
@@ -1,4 +1,9 @@
 import torch
+from liger_kernel.transformers import (
+    apply_liger_kernel_to_gemma2,
+    apply_liger_kernel_to_llama,
+    apply_liger_kernel_to_qwen2,
+)
 from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -7,7 +12,6 @@
     PeftConfigRegistry,
     TransformersAutoModelRegistry,
 )
-from turbo_alignment.modeling.liger_kernels import apply_liger_kernel_to_gemma2
 from turbo_alignment.settings.model import (
     ModelForPeftSettings,
     ModelType,
@@ -46,22 +50,37 @@ def load_model(
     model_settings: PreTrainedModelSettings,
     tokenizer: PreTrainedTokenizerBase,
 ) -> PreTrainedModel:
+    if model_settings.liger_kernels_settings is not None:
+        apply_liger_kernel_to_llama(
+            rope=model_settings.liger_kernels_settings.use_rope,
+            cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy,
+            swiglu=model_settings.liger_kernels_settings.use_mlp,
+            rms_norm=model_settings.liger_kernels_settings.use_rms_norm,
+            fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy,
+        )
+
+        apply_liger_kernel_to_gemma2(
+            rope=model_settings.liger_kernels_settings.use_rope,
+            cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy,
+            geglu=model_settings.liger_kernels_settings.use_mlp,
+            rms_norm=model_settings.liger_kernels_settings.use_rms_norm,
+            fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy,
+        )
+
+        apply_liger_kernel_to_qwen2(
+            rope=model_settings.liger_kernels_settings.use_rope,
+            cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy,
+            swiglu=model_settings.liger_kernels_settings.use_mlp,
+            rms_norm=model_settings.liger_kernels_settings.use_rms_norm,
+            fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy,
+        )
+
     model = TransformersAutoModelRegistry.by_name(model_settings.model_type).from_pretrained(
         model_settings.model_path,
         **model_settings.transformers_settings.dict(exclude_none=True),
         **model_settings.model_kwargs,
         torch_dtype=torch.bfloat16,
     )
-    if model_settings.liger_kernels_settings is not None:
-        apply_liger_kernel_to_gemma2(
-            rope=model_settings.liger_kernels_settings.use_rope,
-            cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy,
-            geglu=model_settings.liger_kernels_settings.use_geglu,
-            model=model
-        )
-    import ray
-    print(f'node_id:{ray.get_runtime_context().get_node_id()}, gpu_id: {ray.get_runtime_context().get_accelerator_ids()["GPU"]}', flush=True)

     if model_settings.transformers_settings.load_in_8bit:
         model = prepare_model_for_int8_training(model)

@@ -99,4 +118,4 @@ def load_model(
     )
     model.base_model.model.score.weight.requires_grad = True

-    return model
\ No newline at end of file
+    return model
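
The relocated block now patches all three supported architectures (LLaMA, Gemma2, Qwen2) before from_pretrained runs, so the patched module classes are already in place when the model is constructed; the old per-model use_geglu flag is folded into a single use_mlp flag that maps to swiglu or geglu per architecture. A minimal standalone sketch of the same pattern, assuming the liger-kernel package is installed; the flag values and checkpoint are illustrative, not the project's:

    import torch
    from liger_kernel.transformers import apply_liger_kernel_to_llama
    from transformers import AutoModelForCausalLM

    # Patch the LLaMA module classes with fused Triton kernels, then load weights.
    apply_liger_kernel_to_llama(
        rope=True,                         # fused rotary embeddings
        swiglu=True,                       # fused SwiGLU MLP (the use_mlp flag)
        rms_norm=False,
        cross_entropy=False,               # mutually exclusive with the fused variant
        fused_linear_cross_entropy=False,
    )
    model = AutoModelForCausalLM.from_pretrained(
        'meta-llama/Llama-3.1-8B',         # placeholder checkpoint
        torch_dtype=torch.bfloat16,
    )
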
4 changes: 2 additions & 2 deletions turbo_alignment/common/tf/special_tokens_setter.py
@@ -77,7 +77,7 @@ def setSEP(self, sep_token: str | None) -> None:
         return None

     def set_all(self) -> None:
-        self.setBOS(bos_token=self._special_tokens_settings.bos_token)
+        # self.setBOS(bos_token=self._special_tokens_settings.bos_token)
         self.setEOS(eos_token=self._special_tokens_settings.eos_token)
         self.setPAD(pad_token=self._special_tokens_settings.pad_token)
         self.setUNK(unk_token=self._special_tokens_settings.unk_token)
@@ -94,7 +94,7 @@ def set_custom_tokens(self, tokens: list[str]) -> None:
         assert added_tokens == len(tokens)

     def setup_model_config(self, model: PreTrainedModel) -> None:
-        model.config.bos_token_id = self._tokenizer.bos_token_id
+        # model.config.bos_token_id = self._tokenizer.bos_token_id
         model.config.eos_token_id = self._tokenizer.eos_token_id

         if self._tokenizer.pad_token_id is not None:
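
Commenting out the BOS handling sidesteps tokenizers that define no BOS token at all, where writing bos_token_id into the model config would store None or force an id the chat template never emits. A small illustration of the failure mode being avoided (Qwen2 is used here only as a common example of a BOS-less tokenizer):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-7B-Instruct')  # placeholder
    # For BOS-less tokenizers this is None, so blindly assigning it to
    # model.config.bos_token_id would propagate None into generation code.
    print(tokenizer.bos_token_id)
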
4 changes: 2 additions & 2 deletions turbo_alignment/dataset/chat/chat.py
@@ -247,8 +247,8 @@ def _truncate_and_merge(
             input_ids = input_ids[cut_index:]

         labels = labels[-len(input_ids) :]
-        input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids))
-        labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels))
+        # input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids))
+        # labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels))

         return input_ids, labels, conversation.get_prompt_repr(left_bound, right_bound)
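
With the BOS prepend disabled, _truncate_and_merge keeps only the alignment step: after truncation, labels are re-sliced so they stay exactly as long as input_ids. A toy numpy illustration of the surviving behavior (values invented):

    import numpy as np

    input_ids = np.array([11, 12, 13, 14, 15])
    labels = np.array([-100, -100, 13, 14, 15])

    cut_index = 2                         # pretend truncation drops 2 tokens
    input_ids = input_ids[cut_index:]     # -> [13, 14, 15]
    labels = labels[-len(input_ids):]     # -> [13, 14, 15], lengths stay in sync
    assert len(input_ids) == len(labels)
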
3 changes: 2 additions & 1 deletion turbo_alignment/generators/chat.py
@@ -42,7 +42,8 @@ def __init__(
         # if transformers_settings.num_beams > 1:
         #     beam_search_params['use_beam_search'] = True
         #     beam_search_params['best_of'] = transformers_settings.num_beams
-
+        print(f'SET MAX_NEW_TOKENS={transformers_settings.max_new_tokens}', flush=True)
+
         self._sampling_params = SamplingParams(
             n=transformers_settings.num_return_sequences,
             repetition_penalty=transformers_settings.repetition_penalty,
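
The debug print confirms which completion cap reaches the generator; the value flows into vLLM's SamplingParams, where max_tokens is the analogue of max_new_tokens. A hedged sketch of the mapping (the keyword names are real SamplingParams fields; the values are illustrative):

    from vllm import SamplingParams

    sampling_params = SamplingParams(
        n=4,                      # num_return_sequences
        repetition_penalty=1.1,
        temperature=0.7,
        max_tokens=1024,          # per-completion cap, i.e. max_new_tokens
    )
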
18 changes: 15 additions & 3 deletions turbo_alignment/settings/model.py
@@ -1,6 +1,8 @@
 from enum import Enum
 from pathlib import Path

+from pydantic import model_validator
+
 from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel
 from turbo_alignment.settings.tf.model import ModelTransformersSettings
 from turbo_alignment.settings.tf.peft import PEFT_TYPE
@@ -15,11 +17,21 @@ class ModelType(str, Enum):

 class LigerKernelSettings(ExtraFieldsNotAllowedBaseModel):
     use_rope: bool = True
-    use_cross_entropy: bool = True
-    use_geglu: bool = True
+    use_cross_entropy: bool = False
     use_fused_linear_cross_entropy: bool = False
+    use_mlp: bool = True
     use_rms_norm: bool = False

+    @model_validator(mode='after')
+    def check_cross_entopy_kernels(self) -> 'LigerKernelSettings':
+        if self.use_fused_linear_cross_entropy and self.use_cross_entropy:
+            raise ValueError(
+                'You cannot use both FusedLinearCrossEntropy and CrossEntropy kernels. '
+                'FusedLinearCrossEntropy is preferred if possible.'
+            )
+
+        return self
+

 class PreTrainedModelSettings(ExtraFieldsNotAllowedBaseModel):
     model_path: Path

@@ -45,4 +57,4 @@ class ModelForPeftSettings(PreTrainedModelSettings):

 class PreTrainedMultiModalModel(PreTrainedAdaptersModelSettings):
     projections_path: Path
-    n_modality_embeddings: int
\ No newline at end of file
+    n_modality_embeddings: int
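
The new validator turns an illegal kernel combination into a configuration-time error instead of a failure deep inside kernel patching. A minimal sketch of the contract, assuming pydantic v2 semantics for model_validator:

    from pydantic import BaseModel, model_validator

    class KernelFlags(BaseModel):         # reduced stand-in for LigerKernelSettings
        use_cross_entropy: bool = False
        use_fused_linear_cross_entropy: bool = False

        @model_validator(mode='after')
        def check_exclusive(self) -> 'KernelFlags':
            if self.use_cross_entropy and self.use_fused_linear_cross_entropy:
                raise ValueError('Pick either the plain or the fused cross-entropy kernel.')
            return self

    KernelFlags(use_fused_linear_cross_entropy=True)   # fine
    # KernelFlags(use_cross_entropy=True, use_fused_linear_cross_entropy=True)  # raises
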
1 change: 1 addition & 0 deletions turbo_alignment/settings/online.py
@@ -11,6 +11,7 @@ class vLLMActorSettings(ExtraFieldsNotAllowedBaseModel):

     vllm_num_engines: int = 1
     vllm_tensor_parallel_size: int = 1
+    max_model_len: int = 8192

 class HFActorSettings(ExtraFieldsNotAllowedBaseModel):
     actor_type: Literal[ActorType.LOCAL_TRANSFORMERS] = ActorType.LOCAL_TRANSFORMERS
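
This gives the actor settings an explicit context-window cap with a usable default, which is the field the train.py change above now reads. A tiny usage sketch with a reduced stand-in model (field names come from the diff, values are illustrative):

    from pydantic import BaseModel

    class ActorSettings(BaseModel):       # reduced stand-in for vLLMActorSettings
        vllm_num_engines: int = 1
        vllm_tensor_parallel_size: int = 1
        max_model_len: int = 8192

    print(ActorSettings(max_model_len=4096).max_model_len)  # -> 4096
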
2 changes: 1 addition & 1 deletion turbo_alignment/settings/pipelines/train/reinforce.py
@@ -16,7 +16,7 @@
 from typing import Union

 class REINFORCETrainerSettings(TrainerSettings):
-    max_tokens_count: int = 1024
+    max_new_tokens: int = 1024
     stop_token: str = '<eos>'

     penalty_reward_value: float = 0.1
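
The rename is not cosmetic: in Hugging Face generation semantics, max_length bounds prompt plus completion while max_new_tokens bounds only the generated continuation, so a long prompt under a max_length cap can silently leave almost no room to generate. A short illustration:

    from transformers import GenerationConfig

    # With a 900-token prompt, max_length=1024 leaves only 124 new tokens,
    # whereas max_new_tokens=1024 always allows 1024 generated tokens.
    a = GenerationConfig(max_length=1024)
    b = GenerationConfig(max_new_tokens=1024)
    print(a.max_length, b.max_new_tokens)
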
38 changes: 11 additions & 27 deletions turbo_alignment/trainers/online/reinforce.py
@@ -47,7 +47,7 @@
 # FIXME
 @dataclass
 class REINFORCETrainingArguments(TrainingArguments):
-    max_tokens_count: int = 1024
+    max_new_tokens: int = 1024
     stop_token: str = '<eos>'

     penalty_reward_value: float = 0.1
@@ -177,7 +177,7 @@ def __init__(

         self.generator_transformers_settings = GeneratorTransformersSettings(
             temperature=args.temperature,
-            max_length=args.max_tokens_count,
+            max_new_tokens=args.max_new_tokens,
             num_return_sequences=args.num_generations,
             num_beams=args.num_generations,
             do_sample=True,
@@ -193,12 +193,10 @@ def __init__(
             num_generations=args.num_generations,
         )
         start = time.time()
-        torch.cuda.memory._dump_snapshot("before_reward_stats.pickle")
         self.print_readable_stats()
         self.norm_reward_mean, self.norm_reward_std = self.reward_stats(
             model=self.model, dataloader=self.get_train_dataloader()
         )
-        torch.cuda.memory._dump_snapshot("after_stats.pickle")
         logging.info(f'statictis in __init__ elapsed time:{time.time() - start}')
         self.print_readable_stats()

@@ -275,21 +273,19 @@ def get_answers_and_rewards(
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         import time
         import logging
-        for k, v in inputs.items():
-            if isinstance(v, torch.Tensor):
-                print(f'get_answers_and_rewards:inputs: {k}:{v.shape}', flush=True)

+        if torch.distributed.get_rank() == 0:
+            print(f'Input shape: {inputs["input_ids"].shape}', flush=True)
+            print(f'Input Example at index [0]: {self.tokenizer.batch_decode(inputs["input_ids"][0, :].unsqueeze(0))}')

         if do_broadcast:
             # TODO: move to generator
-            # print('broadcast', type(model), flush=True)
-            # assert isinstance(model, DeepSpeedEngine)
             torch.distributed.barrier()
             start = time.time()
             self._broadcast_to_vllm(model)
             torch.distributed.barrier()
             logging.info(f'broadcast elapsed time:{time.time() - start}')

         # model = self.accelerator.unwrap_model(model)
         start = time.time()
         generator = self._get_chat_generator()
@@ -301,13 +297,13 @@ def get_answers_and_rewards(
         response_ids = [ans.answer_token_ids for g in generations for ans in g.answers]
         response_attention_mask = [ans.answer_attention_mask for g in generations for ans in g.answers]

-        import logging
-        logging.info(f'Generated sample: {response_ids[0].shape}')
-        logging.info(f'Last token: {response_ids[0].squeeze(0)[-1]}')
-        logging.info(f'Decoded: {self.tokenizer.decode([response_ids[0].squeeze(0)[-1]])}')
+        if torch.distributed.get_rank() == 0:
+            print(f'Generated completion at index [0] shape: {response_ids[0].shape}', flush=True)
+            print(f'Completion decoded: {self.tokenizer.batch_decode(response_ids[0, :].unsqueeze(0))}', flush=True)

-        # Padding
-        max_length = 8192 # max(response_id.size(1) for response_id in response_ids) #45.89 GiB

+        max_length = max([response_id.size(1) for response_id in response_ids])
+        logging.info(f'{max_length=}')

         def pad_sequences(sequences, max_length, pad_value=0):
@@ -332,10 +328,6 @@ def pad_sequences(sequences, max_length, pad_value=0):
             'attention_mask': response_attention_mask,
             'position_ids': position_ids,
         }
-
-        # for k, v in rm_inputs.items():
-        #     if isinstance(v, torch.Tensor):
-        #         print(f'get_answers_and_rewards:rm_inputs: {k}, {v.device=}', flush=True)

         start = time.time()
         critic_generator = self._get_rm_generator(self.reward_model)
@@ -380,7 +372,6 @@ def get_batch_loss_metrics(
         del inputs
         gc.collect()
         torch.cuda.empty_cache()
-        torch.cuda.memory._dump_snapshot("before_inference.pickle")

         logprobs = self.get_logprobs(
             model=model,
@@ -445,14 +436,10 @@ def print_readable_stats(self):
         print(f"Reserved: {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB")

     def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch=None):
-        print(f'Before batch_loss metrics\n\n')
-        self.print_readable_stats()
-
         import logging
         import time
         import gc

-        logging.info(f'{isinstance(model, DeepSpeedEngine)=}')
         start = time.time()

         loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train')
@@ -461,9 +448,6 @@ def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in
         logging.info(f'Compute Loss elapsed time:{time.time() - start}')
         gc.collect()
         torch.cuda.empty_cache()
-        print(f'After batch_loss metrics\n\n')
-        self.print_readable_stats()
-        torch.cuda.memory._dump_snapshot("after_compute_loss.pickle")

         return (loss, metrics) if return_outputs else loss
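
The padding change is the memory-relevant fix here: the hard-coded max_length = 8192 (the inline comment records a 45.89 GiB allocation under it) is replaced by the true batch maximum. The diff shows only pad_sequences' signature, so the body below is a hedged sketch of an implementation compatible with its call sites, assuming each response is a [1, seq_len] tensor:

    import torch
    import torch.nn.functional as F

    def pad_sequences(sequences, max_length, pad_value=0):
        # Right-pad each [1, seq_len] tensor to max_length, then stack into a batch.
        padded = [
            F.pad(seq, (0, max_length - seq.size(1)), value=pad_value)
            for seq in sequences
        ]
        return torch.cat(padded, dim=0)

    response_ids = [torch.ones(1, 3, dtype=torch.long), torch.ones(1, 5, dtype=torch.long)]
    max_length = max(r.size(1) for r in response_ids)      # dynamic, not 8192
    print(pad_sequences(response_ids, max_length).shape)   # torch.Size([2, 5])
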
