From 6222186de1fe7200ce5c955a127195026d42d561 Mon Sep 17 00:00:00 2001 From: Alexand Zamiralov Date: Mon, 18 Nov 2024 15:50:09 +0300 Subject: [PATCH] fix-lints --- turbo_alignment/cherry_picks/multimodal.py | 2 +- .../common/logging/weights_and_biases.py | 2 +- .../common/tf/callbacks/logging.py | 2 +- turbo_alignment/metrics/__init__.py | 2 +- turbo_alignment/metrics/distinctness.py | 1 + turbo_alignment/metrics/exact_match.py | 52 ++++++++----------- turbo_alignment/metrics/registry.py | 3 +- turbo_alignment/trainers/dpo.py | 2 +- 8 files changed, 31 insertions(+), 35 deletions(-) diff --git a/turbo_alignment/cherry_picks/multimodal.py b/turbo_alignment/cherry_picks/multimodal.py index 80c6022..ba1c536 100755 --- a/turbo_alignment/cherry_picks/multimodal.py +++ b/turbo_alignment/cherry_picks/multimodal.py @@ -1,10 +1,10 @@ from typing import Iterable import torch +import wandb from PIL import Image from transformers import PreTrainedTokenizerBase -import wandb from turbo_alignment.cherry_picks.base import CherryPickCallbackBase from turbo_alignment.dataset.multimodal import InferenceMultimodalDataset from turbo_alignment.generators.multimodal import MultimodalGenerator diff --git a/turbo_alignment/common/logging/weights_and_biases.py b/turbo_alignment/common/logging/weights_and_biases.py index 112e585..3b5bdfd 100755 --- a/turbo_alignment/common/logging/weights_and_biases.py +++ b/turbo_alignment/common/logging/weights_and_biases.py @@ -1,9 +1,9 @@ from typing import Any +import wandb from wandb.sdk.lib.disabled import RunDisabled from wandb.sdk.wandb_run import Run -import wandb from turbo_alignment.settings.logging.weights_and_biases import WandbSettings diff --git a/turbo_alignment/common/tf/callbacks/logging.py b/turbo_alignment/common/tf/callbacks/logging.py index 8f90288..5623c14 100755 --- a/turbo_alignment/common/tf/callbacks/logging.py +++ b/turbo_alignment/common/tf/callbacks/logging.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd +import wandb from clearml import Task from transformers import ( TrainerCallback, @@ -13,7 +14,6 @@ from wandb.sdk.lib.disabled import RunDisabled from wandb.sdk.wandb_run import Run -import wandb from turbo_alignment.common.logging import get_project_logger logger = get_project_logger() diff --git a/turbo_alignment/metrics/__init__.py b/turbo_alignment/metrics/__init__.py index b82054c..6e0ea97 100755 --- a/turbo_alignment/metrics/__init__.py +++ b/turbo_alignment/metrics/__init__.py @@ -1,5 +1,6 @@ from turbo_alignment.metrics.distinctness import DistinctnessMetric from turbo_alignment.metrics.diversity import DiversityMetric +from turbo_alignment.metrics.exact_match import ExactMatchMetric from turbo_alignment.metrics.kl import KLMetric from turbo_alignment.metrics.length import LengthMetric from turbo_alignment.metrics.meteor import MeteorMetric @@ -10,4 +11,3 @@ from turbo_alignment.metrics.reward import RewardMetric from turbo_alignment.metrics.rouge import RougeMetric from turbo_alignment.metrics.self_bleu import SelfBleuMetric -from turbo_alignment.metrics.exact_match import ExactMatchMetric diff --git a/turbo_alignment/metrics/distinctness.py b/turbo_alignment/metrics/distinctness.py index a1ff2bd..4f10de2 100755 --- a/turbo_alignment/metrics/distinctness.py +++ b/turbo_alignment/metrics/distinctness.py @@ -1,4 +1,5 @@ from collections import defaultdict + from transformers.tokenization_utils_base import PreTrainedTokenizerBase from turbo_alignment.metrics.metric import Metric diff --git a/turbo_alignment/metrics/exact_match.py b/turbo_alignment/metrics/exact_match.py index 97cc0d7..eb9a94f 100644 --- a/turbo_alignment/metrics/exact_match.py +++ b/turbo_alignment/metrics/exact_match.py @@ -1,5 +1,6 @@ import json -from datetime import datetime, timedelta +from datetime import datetime + from turbo_alignment.dataset.chat import InferenceChatDataset from turbo_alignment.metrics.metric import Metric from turbo_alignment.metrics.registry import ExactMatchMetricSettings @@ -23,11 +24,11 @@ def compute(self, **kwargs) -> list[MetricResults]: raise ValueError('predictions should not be None') ( - ref_functions, + _, ref_arguments, ) = self._parse_function_with_arguments(references) ( - pred_functions, + _, pred_arguments, ) = self._parse_function_with_arguments(predictions) @@ -36,80 +37,73 @@ def compute(self, **kwargs) -> list[MetricResults]: if 'date' in selected_key: for day_lift in self._settings.possible_day_lifts: temp_hits = [] - for pred_argument, ref_argument in zip( - pred_arguments, ref_arguments): + for pred_argument, ref_argument in zip(pred_arguments, ref_arguments): temp_hits.append( self._full_match_argument( parsed_predicted_arguments=pred_argument, parsed_reference_arguments=ref_argument, selected_key=selected_key, - day_lift=day_lift + day_lift=day_lift, ) ) - label = selected_key +'_'+str(day_lift) + label = selected_key + '_' + str(day_lift) element_wise_scores.append(ElementWiseScores(label=label, values=temp_hits)) else: temp_hits = [] - for pred_argument, ref_argument in zip( - pred_arguments, ref_arguments): + for pred_argument, ref_argument in zip(pred_arguments, ref_arguments): temp_hits.append( self._full_match_argument( parsed_predicted_arguments=pred_argument, parsed_reference_arguments=ref_argument, - selected_key=selected_key + selected_key=selected_key, ) ) label = selected_key element_wise_scores.append(ElementWiseScores(label=label, values=temp_hits)) return [MetricResults(element_wise_scores=element_wise_scores, need_average=True)] - - def _full_match_argument(self, parsed_predicted_arguments, parsed_reference_arguments, - selected_key,day_lift=None) -> bool: + def _full_match_argument( + self, parsed_predicted_arguments, parsed_reference_arguments, selected_key, day_lift=None + ) -> bool: key = selected_key possible_day_lift = day_lift if 'date' in key: try: predicted_date_obj = datetime.strptime(parsed_reference_arguments[key], '%Y-%m-%d') reference_date_obj = datetime.strptime(parsed_predicted_arguments[key], '%Y-%m-%d') - if abs((predicted_date_obj - reference_date_obj).days) <= possible_day_lift: - return True - else: - return False + return abs((predicted_date_obj - reference_date_obj).days) <= possible_day_lift except (ValueError, TypeError): return False else: try: - return parsed_reference_arguments[self._settings.selected_argument_key] \ - == parsed_predicted_arguments[self._settings.selected_argument_key] - except: + return parsed_reference_arguments[key] == parsed_predicted_arguments[key] + except Exception as e: + print(f'error in match-argument{e}') return False - - def _parse_function_with_arguments(self, inputs: InferenceChatDataset | list[list[str]])\ - -> tuple[list, list]: + def _parse_function_with_arguments(self, inputs: InferenceChatDataset | list[list[str]]) -> tuple[list, list]: functions = [] arguments = [] for input_batch in inputs: for input in input_batch: input_splitted = self._load_json(input) if len(input_splitted) == 2: - functions.append(input_splitted.get('name','')) - arguments.append(input_splitted.get('parameters','')) + functions.append(input_splitted.get('name', '')) + arguments.append(input_splitted.get('parameters', '')) else: functions.append('') arguments.append('') return functions, arguments - @staticmethod def _load_json(data): try: - return json.loads(data) + return json.loads(data) except json.JSONDecodeError as e: error_position = e.pos - fixed_json = data[:error_position] + "null," + data[error_position + 1:] + fixed_json = data[:error_position] + 'null,' + data[error_position + 1 :] try: - return json.loads(fixed_json) + return json.loads(fixed_json) except json.JSONDecodeError as e2: + print(f'error in load-json in match-argument{e2}') return {} diff --git a/turbo_alignment/metrics/registry.py b/turbo_alignment/metrics/registry.py index 45cbe78..6f7c1dc 100755 --- a/turbo_alignment/metrics/registry.py +++ b/turbo_alignment/metrics/registry.py @@ -1,4 +1,5 @@ from enum import Enum + from pydantic import field_validator from turbo_alignment.common.registry import Registrable @@ -103,7 +104,7 @@ class RetrievalUtilitySettings(MetricSettings): @MetricSettingsRegistry.register(MetricType.EXACT_MATCH) class ExactMatchMetricSettings(MetricSettings): selected_argument_keys: list[str] - possible_day_lifts: list[int] = [0,1,3] + possible_day_lifts: list[int] = [0, 1, 3] class Config: env_prefix: str = 'EXACT_MATCH_METRIC_' diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index 490e6d1..e49231b 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -26,9 +26,9 @@ from turbo_alignment.settings.pipelines.train.dpo import ( APODownLossSettings, APOZeroLossSettings, + ASFTLossSettings, CPOLossSettings, DPOLossesType, - ASFTLossSettings, HingeLossSettings, IPOLossSettings, KTOLossSettings,