Skip to content

Commit

Permalink
fix linters and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Sep 3, 2024
1 parent 4f20d96 commit 323e99f
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 23 deletions.
14 changes: 2 additions & 12 deletions turbo_alignment/dataset/pair_preferences/pair_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def __init__(
def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str, Any] | None]:
chosen_chat_records: list[ChatDatasetRecord] = []
rejected_chat_records: list[ChatDatasetRecord] = []
best_decode_records: list[ChatDatasetRecord] = []

for record in records:
context = [
Expand All @@ -60,24 +59,15 @@ def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str,
chosen = ChatMessage(role=record.answer_w.role, content=record.answer_w.content)
rejected = ChatMessage(role=record.answer_l.role, content=record.answer_l.content)

if record.best_decode is not None:
best_decode = ChatMessage(role=record.best_decode.role, content=record.best_decode.content)
best_decode_records.append(ChatDatasetRecord(id=record.id, messages=context + [best_decode]))

chosen_chat_records.append(ChatDatasetRecord(id=record.id, messages=context + [chosen]))
rejected_chat_records.append(ChatDatasetRecord(id=record.id, messages=context + [rejected]))

tokenized_chosen_records = self._chat_dataset.convert_records(chosen_chat_records)
tokenized_rejected_records = self._chat_dataset.convert_records(rejected_chat_records)

if len(best_decode_records) != 0:
tokenized_best_decode_records = self._chat_dataset.convert_records(best_decode_records)
else:
tokenized_best_decode_records = [None] * len(tokenized_chosen_records)

output: list[dict[str, Any] | None] = []
for record, chosen_record, rejected_record, best_decode_record in zip(
records, tokenized_chosen_records, tokenized_rejected_records, tokenized_best_decode_records
for record, chosen_record, rejected_record in zip(
records, tokenized_chosen_records, tokenized_rejected_records
):
if not (chosen_record and rejected_record):
continue
Expand Down
3 changes: 1 addition & 2 deletions turbo_alignment/pipelines/sampling/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils.operations import gather_object

from transformers import PreTrainedModel, PreTrainedTokenizerBase

from turbo_alignment.common.tf.loaders.model.model import load_model
Expand Down Expand Up @@ -64,7 +63,7 @@ def sample(self, experiment_settings: SamplingSettingsWithRMT) -> list[SamplingD
model = load_model(experiment_settings.rm, tokenizer)
if accelerator is not None:
model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)

model.eval()

dataset = SamplingRMDataset(
Expand Down
8 changes: 5 additions & 3 deletions turbo_alignment/settings/pipelines/train/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class DPOLossesType(str, Enum):
SIGMOID = 'sigmoid'
SIGMOD_WITH_MARGIN = 'sigmoid_with_margin'
SIGMOID_WITH_MARGIN = 'sigmoid_with_margin'
HINGE = 'hinge'
IPO = 'ipo'
KTO = 'kto'
Expand All @@ -37,8 +37,10 @@ class SigmoidLossSettings(DPOLossSettings):
loss_type: Literal[DPOLossesType.SIGMOID]
label_smoothing: float = 0

class SigmoidLossWithMarginSettings(SigmoidLossSettings):
loss_type: Literal[DPOLossesType.SIGMOD_WITH_MARGIN]

class SigmoidLossWithMarginSettings(DPOLossSettings):
    """Settings for the sigmoid DPO loss variant that uses precomputed margins.

    Selected via the ``sigmoid_with_margin`` loss type; consumed by the
    ``SigmoidLossWithMargin`` implementation registered in the DPO trainer.
    """

    # Discriminator field: pins this settings class to the
    # SIGMOID_WITH_MARGIN member of the DPOLossesType enum.
    loss_type: Literal[DPOLossesType.SIGMOID_WITH_MARGIN]


class HingeLossSettings(DPOLossSettings):
loss_type: Literal[DPOLossesType.HINGE]
Expand Down
7 changes: 3 additions & 4 deletions turbo_alignment/trainers/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
KTOLossSettings,
ORPOLossSettings,
SigmoidLossSettings,
SigmoidLossWithMarginSettings,
SimPOLossSettings,
SlicHfLossSettings,
SyncRefModelSettings,
Expand Down Expand Up @@ -294,22 +295,20 @@ def compute_loss(
rejected_rewards = self.beta * policy_rejected_logps.detach()

return losses, chosen_rewards, rejected_rewards


@DPOLossRegistry.register(DPOLossesType.SIGMOID_WITH_MARGIN)
class SigmoidLossWithMargin(DPOLossRegistry):
def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
self.beta = beta
super().__ini__(*args, **kwargs)

super().__init__(*args, **kwargs)

def compute_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
policy_best_decode_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
Expand Down
2 changes: 0 additions & 2 deletions turbo_alignment/trainers/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer
from turbo_alignment.trainers.utils import concatenated_inputs


logger = logging.get_logger(__name__)


class RMTrainer(MultiGPUCherryPicksTrainer):

def concatenated_forward(self, model: nn.Module, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
all_rewards = model(
Expand Down

0 comments on commit 323e99f

Please sign in to comment.