Commit fix-lints
Alexand Zamiralov committed Nov 18, 2024
1 parent 5720536 commit 6222186

Showing 8 changed files with 31 additions and 35 deletions.
turbo_alignment/cherry_picks/multimodal.py (1 addition, 1 deletion)

@@ -1,10 +1,10 @@
 from typing import Iterable
 
 import torch
+import wandb
 from PIL import Image
 from transformers import PreTrainedTokenizerBase
 
-import wandb
 from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
 from turbo_alignment.dataset.multimodal import InferenceMultimodalDataset
 from turbo_alignment.generators.multimodal import MultimodalGenerator
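
This move, repeated in weights_and_biases.py and logging.py below, looks like isort's standard grouping at work: wandb is a third-party package, so it belongs in the third-party import block rather than in a separate block next to the first-party turbo_alignment imports. A minimal sketch of the convention (module names here are illustrative, not from the repository):

# Group 1: standard library.
import json
from collections import defaultdict

# Group 2: third-party packages installed from PyPI.
import torch
import wandb

# Group 3: first-party / local code.
from turbo_alignment.metrics.metric import Metric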

turbo_alignment/common/logging/weights_and_biases.py (1 addition, 1 deletion)

@@ -1,9 +1,9 @@
 from typing import Any
 
+import wandb
 from wandb.sdk.lib.disabled import RunDisabled
 from wandb.sdk.wandb_run import Run
 
-import wandb
 from turbo_alignment.settings.logging.weights_and_biases import WandbSettings
 
 
turbo_alignment/common/tf/callbacks/logging.py (1 addition, 1 deletion)

@@ -3,6 +3,7 @@
 
 import numpy as np
 import pandas as pd
+import wandb
 from clearml import Task
 from transformers import (
     TrainerCallback,
@@ -13,7 +14,6 @@
 from wandb.sdk.lib.disabled import RunDisabled
 from wandb.sdk.wandb_run import Run
 
-import wandb
 from turbo_alignment.common.logging import get_project_logger
 
 logger = get_project_logger()

turbo_alignment/metrics/__init__.py (1 addition, 1 deletion)

@@ -1,5 +1,6 @@
 from turbo_alignment.metrics.distinctness import DistinctnessMetric
 from turbo_alignment.metrics.diversity import DiversityMetric
+from turbo_alignment.metrics.exact_match import ExactMatchMetric
 from turbo_alignment.metrics.kl import KLMetric
 from turbo_alignment.metrics.length import LengthMetric
 from turbo_alignment.metrics.meteor import MeteorMetric
@@ -10,4 +11,3 @@
 from turbo_alignment.metrics.reward import RewardMetric
 from turbo_alignment.metrics.rouge import RougeMetric
 from turbo_alignment.metrics.self_bleu import SelfBleuMetric
-from turbo_alignment.metrics.exact_match import ExactMatchMetric

turbo_alignment/metrics/distinctness.py (1 addition, 0 deletions)

@@ -1,4 +1,5 @@
 from collections import defaultdict
+
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from turbo_alignment.metrics.metric import Metric

turbo_alignment/metrics/exact_match.py (23 additions, 29 deletions)

@@ -1,5 +1,6 @@
 import json
-from datetime import datetime, timedelta
+from datetime import datetime
+
 from turbo_alignment.dataset.chat import InferenceChatDataset
 from turbo_alignment.metrics.metric import Metric
 from turbo_alignment.metrics.registry import ExactMatchMetricSettings
@@ -23,11 +24,11 @@ def compute(self, **kwargs) -> list[MetricResults]:
             raise ValueError('predictions should not be None')
 
         (
-            ref_functions,
+            _,
             ref_arguments,
         ) = self._parse_function_with_arguments(references)
         (
-            pred_functions,
+            _,
             pred_arguments,
         ) = self._parse_function_with_arguments(predictions)
 
@@ -36,80 +37,73 @@ def compute(self, **kwargs) -> list[MetricResults]:
             if 'date' in selected_key:
                 for day_lift in self._settings.possible_day_lifts:
                     temp_hits = []
-                    for pred_argument, ref_argument in zip(
-                            pred_arguments, ref_arguments):
+                    for pred_argument, ref_argument in zip(pred_arguments, ref_arguments):
                         temp_hits.append(
                             self._full_match_argument(
                                 parsed_predicted_arguments=pred_argument,
                                 parsed_reference_arguments=ref_argument,
                                 selected_key=selected_key,
-                                day_lift=day_lift
+                                day_lift=day_lift,
                             )
                         )
-                    label = selected_key +'_'+str(day_lift)
+                    label = selected_key + '_' + str(day_lift)
                     element_wise_scores.append(ElementWiseScores(label=label, values=temp_hits))
             else:
                 temp_hits = []
-                for pred_argument, ref_argument in zip(
-                        pred_arguments, ref_arguments):
+                for pred_argument, ref_argument in zip(pred_arguments, ref_arguments):
                     temp_hits.append(
                         self._full_match_argument(
                             parsed_predicted_arguments=pred_argument,
                             parsed_reference_arguments=ref_argument,
-                            selected_key=selected_key
+                            selected_key=selected_key,
                         )
                     )
                 label = selected_key
                 element_wise_scores.append(ElementWiseScores(label=label, values=temp_hits))
         return [MetricResults(element_wise_scores=element_wise_scores, need_average=True)]
 
-
-    def _full_match_argument(self, parsed_predicted_arguments, parsed_reference_arguments,
-                             selected_key,day_lift=None) -> bool:
+    def _full_match_argument(
+        self, parsed_predicted_arguments, parsed_reference_arguments, selected_key, day_lift=None
+    ) -> bool:
         key = selected_key
         possible_day_lift = day_lift
         if 'date' in key:
             try:
                 predicted_date_obj = datetime.strptime(parsed_reference_arguments[key], '%Y-%m-%d')
                 reference_date_obj = datetime.strptime(parsed_predicted_arguments[key], '%Y-%m-%d')
-                if abs((predicted_date_obj - reference_date_obj).days) <= possible_day_lift:
-                    return True
-                else:
-                    return False
+                return abs((predicted_date_obj - reference_date_obj).days) <= possible_day_lift
             except (ValueError, TypeError):
                 return False
         else:
             try:
-                return parsed_reference_arguments[self._settings.selected_argument_key] \
-                    == parsed_predicted_arguments[self._settings.selected_argument_key]
-            except:
+                return parsed_reference_arguments[key] == parsed_predicted_arguments[key]
+            except Exception as e:
                 print(f'error in match-argument{e}')
                 return False
 
-
-    def _parse_function_with_arguments(self, inputs: InferenceChatDataset | list[list[str]])\
-            -> tuple[list, list]:
+    def _parse_function_with_arguments(self, inputs: InferenceChatDataset | list[list[str]]) -> tuple[list, list]:
         functions = []
         arguments = []
         for input_batch in inputs:
             for input in input_batch:
                 input_splitted = self._load_json(input)
                 if len(input_splitted) == 2:
-                    functions.append(input_splitted.get('name',''))
-                    arguments.append(input_splitted.get('parameters',''))
+                    functions.append(input_splitted.get('name', ''))
+                    arguments.append(input_splitted.get('parameters', ''))
                 else:
                     functions.append('')
                    arguments.append('')
         return functions, arguments
 
-
     @staticmethod
     def _load_json(data):
         try:
-            return json.loads(data)
+            return json.loads(data)
         except json.JSONDecodeError as e:
            error_position = e.pos
-            fixed_json = data[:error_position] + "null," + data[error_position + 1:]
+            fixed_json = data[:error_position] + 'null,' + data[error_position + 1 :]
             try:
-                return json.loads(fixed_json)
+                return json.loads(fixed_json)
             except json.JSONDecodeError as e2:
                 print(f'error in load-json in match-argument{e2}')
                 return {}
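
Two behaviors in this file are easy to miss in the reformatting and are worth a closer look: the date comparison passes when the two dates differ by at most day_lift days, and _load_json retries a failed parse after splicing 'null,' in at the reported error position. A minimal standalone sketch of both, with illustrative names and sample data that are not from the repository:

import json
from datetime import datetime


def dates_match(pred: str, ref: str, day_lift: int) -> bool:
    # Mirrors the refactored return statement: True when the dates are
    # within day_lift days of each other, False if parsing fails.
    try:
        pred_obj = datetime.strptime(pred, '%Y-%m-%d')
        ref_obj = datetime.strptime(ref, '%Y-%m-%d')
        return abs((pred_obj - ref_obj).days) <= day_lift
    except (ValueError, TypeError):
        return False


def load_json_with_repair(data: str) -> dict:
    # Mirrors _load_json: on a decode error, splice 'null,' in at the
    # failing position and retry once before giving up.
    try:
        return json.loads(data)
    except json.JSONDecodeError as e:
        fixed = data[: e.pos] + 'null,' + data[e.pos + 1 :]
        try:
            return json.loads(fixed)
        except json.JSONDecodeError:
            return {}


assert dates_match('2024-11-18', '2024-11-19', day_lift=1)
assert not dates_match('2024-11-18', '2024-11-22', day_lift=3)
assert load_json_with_repair('{"name": "f", "parameters": {}}') == {'name': 'f', 'parameters': {}}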

turbo_alignment/metrics/registry.py (2 additions, 1 deletion)

@@ -1,4 +1,5 @@
 from enum import Enum
+
 from pydantic import field_validator
 
 from turbo_alignment.common.registry import Registrable
@@ -103,7 +104,7 @@ class RetrievalUtilitySettings(MetricSettings):
 @MetricSettingsRegistry.register(MetricType.EXACT_MATCH)
 class ExactMatchMetricSettings(MetricSettings):
     selected_argument_keys: list[str]
-    possible_day_lifts: list[int] = [0,1,3]
+    possible_day_lifts: list[int] = [0, 1, 3]
 
     class Config:
         env_prefix: str = 'EXACT_MATCH_METRIC_'
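
As a usage sketch, assuming the MetricSettings base class requires no further fields (the argument key below is hypothetical): the defaults score a date-typed argument at three tolerances, exact, within one day, and within three days.

settings = ExactMatchMetricSettings(
    selected_argument_keys=['booking_date'],  # hypothetical key name
    possible_day_lifts=[0, 1, 3],  # same as the default
)
# compute() then emits one ElementWiseScores per (key, day_lift) pair,
# labelled 'booking_date_0', 'booking_date_1' and 'booking_date_3'.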

turbo_alignment/trainers/dpo.py (1 addition, 1 deletion)

@@ -26,9 +26,9 @@
 from turbo_alignment.settings.pipelines.train.dpo import (
     APODownLossSettings,
     APOZeroLossSettings,
+    ASFTLossSettings,
     CPOLossSettings,
     DPOLossesType,
-    ASFTLossSettings,
     HingeLossSettings,
     IPOLossSettings,
     KTOLossSettings,
