diff --git a/turbo_alignment/metrics/distinctness.py b/turbo_alignment/metrics/distinctness.py
index b17dbcb..a1ff2bd 100755
--- a/turbo_alignment/metrics/distinctness.py
+++ b/turbo_alignment/metrics/distinctness.py
@@ -1,21 +1,31 @@
 from collections import defaultdict
 
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from turbo_alignment.metrics.metric import Metric
+from turbo_alignment.metrics.registry import DistinctnessSettings
 from turbo_alignment.settings.metric import ElementWiseScores, MetricResults, MetricType
 
 
 @Metric.register(MetricType.DIST_N)
 class DistinctnessMetric(Metric):
+    def __init__(self, settings: DistinctnessSettings) -> None:
+        super().__init__(settings=settings)
+        self._settings: DistinctnessSettings = settings
+
     def compute(self, **kwargs) -> list[MetricResults]:
         predictions: list[list[str]] = kwargs.get('predictions', None)
         dataset_name: str = kwargs.get('dataset_name', '')
+        tokenizer: PreTrainedTokenizerBase = kwargs.get('tokenizer', None)
+        if tokenizer is None:
+            raise ValueError('tokenizer should not be None')
+        vocab_size: int = tokenizer.vocab_size
 
         if predictions is None:
             raise ValueError('predictions should not be None')
 
         dist_n = defaultdict(list)
         for prompt_answers in predictions:
-            ans_dist_n = self.distinctness(prompt_answers)
+            ans_dist_n = self.distinctness(prompt_answers, vocab_size, self._settings.ngram)
             for label, value in ans_dist_n.items():
                 dist_n[label].append(value)
 
@@ -31,21 +41,32 @@ def compute(self, **kwargs) -> list[MetricResults]:
         ]
 
     @staticmethod
-    def distinctness(answers: list[str]) -> dict[str, float]:
-        unigrams, bigrams, trigrams = set(), set(), set()
-        total_words = 0
+    def distinctness(answers: list[str], vocab_size: int, ngram: int) -> dict[str, float]:
+        ngram_sets: list[set] = [set() for _ in range(ngram)]
+        total_ngrams: list[int] = [0] * ngram
 
         for answer in answers:
             words = answer.split(' ')
-            total_words += len(words)
-            unigrams.update(words)
-            for i in range(len(words) - 1):
-                bigrams.add(words[i] + '_' + words[i + 1])
-            for i in range(len(words) - 2):
-                trigrams.add(words[i] + '_' + words[i + 1] + '_' + words[i + 2])
-
-        return {
-            'dist_1': len(unigrams) / total_words,
-            'dist_2': len(bigrams) / total_words,
-            'dist_3': len(trigrams) / total_words,
-        }
+            ngram_sets[0].update(words)
+            total_ngrams[0] += len(words)
+
+            for n in range(1, ngram):
+                ngrams = ['_'.join(words[i : i + n + 1]) for i in range(len(words) - n)]
+                ngram_sets[n].update(ngrams)
+                total_ngrams[n] += len(ngrams)
+
+        result = {}
+        for n in range(ngram):
+            result[f'dist_{n+1}'] = len(ngram_sets[n]) / total_ngrams[n] if total_ngrams[n] > 0 else 0
+            try:
+                result[f'ead_dist_{n+1}'] = (
+                    len(ngram_sets[n]) / (vocab_size * (1 - ((vocab_size - 1) / vocab_size) ** total_ngrams[n]))
+                    if total_ngrams[n] > 0
+                    else 0
+                )
+            except ZeroDivisionError:
+                result[f'ead_dist_{n+1}'] = 0
+
+        result['dist_mean'] = sum(result[f'dist_{n+1}'] for n in range(ngram)) / ngram
+        result['ead_dist_mean'] = sum(result[f'ead_dist_{n+1}'] for n in range(ngram)) / ngram
+        return result
diff --git a/turbo_alignment/metrics/diversity.py b/turbo_alignment/metrics/diversity.py
index 1d38b91..0af665c 100755
--- a/turbo_alignment/metrics/diversity.py
+++ b/turbo_alignment/metrics/diversity.py
@@ -6,11 +6,16 @@
 from transformers import PreTrainedTokenizerBase
 
 from turbo_alignment.metrics.metric import Metric
+from turbo_alignment.metrics.registry import DiversitySettings
 from turbo_alignment.settings.metric import ElementWiseScores, MetricResults, MetricType
 
 
 @Metric.register(MetricType.DIVERSITY)
 class DiversityMetric(Metric):
+    def __init__(self, settings: DiversitySettings) -> None:
+        super().__init__(settings=settings)
+        self._settings: DiversitySettings = settings
+
     def compute(self, **kwargs) -> list[MetricResults]:
         tokenizer: PreTrainedTokenizerBase = kwargs.get('tokenizer', None)
         predictions: list[list[str]] = kwargs.get('predictions', None)
@@ -25,7 +30,10 @@ def compute(self, **kwargs) -> list[MetricResults]:
         element_wise_diversity_scores = [
             ElementWiseScores(
                 label=dataset_name + '@@' + 'diversity',
-                values=[self.average_token_entropy(answer_group, tokenizer) for answer_group in predictions],
+                values=[
+                    self.average_token_entropy(answer_group, tokenizer, self._settings.top_k)
+                    for answer_group in predictions
+                ],
             )
         ]
 
@@ -34,15 +42,17 @@ def compute(self, **kwargs) -> list[MetricResults]:
             for need_average in self._settings.need_average
         ]
 
-    def average_token_entropy(self, answer_group: list[str], tokenizer: PreTrainedTokenizerBase) -> float:
-        entropies = [self.token_entropy(answer, tokenizer) for answer in answer_group]
+    def average_token_entropy(
+        self, answer_group: list[str], tokenizer: PreTrainedTokenizerBase, top_k: int | None
+    ) -> float:
+        entropies = [self.token_entropy(answer, tokenizer, top_k) for answer in answer_group]
         if entropies:
             return sum(entropies) / len(entropies)
 
         return np.nan
 
     @staticmethod
-    def token_entropy(sample: str, tokenizer: PreTrainedTokenizerBase) -> float:
+    def token_entropy(sample: str, tokenizer: PreTrainedTokenizerBase, top_k: int | None) -> float:
         stats: dict[int, Any] = defaultdict(int)
         num_tokens = 0
         tokens = tokenizer.encode(sample)
@@ -54,4 +64,8 @@ def token_entropy(sample: str, tokenizer: PreTrainedTokenizerBase) -> float:
         for k in stats.keys():
             stats[k] /= num_tokens
 
-        return entropy(list(stats.values()))
+        top_k_stats = list(stats.values())
+        if top_k is not None:
+            top_k_stats = sorted(top_k_stats, reverse=True)[:top_k]
+
+        return entropy(top_k_stats)
diff --git a/turbo_alignment/metrics/registry.py b/turbo_alignment/metrics/registry.py
index 0e269d6..d9f5abd 100755
--- a/turbo_alignment/metrics/registry.py
+++ b/turbo_alignment/metrics/registry.py
@@ -1,4 +1,5 @@
 from enum import Enum
 
+from pydantic import field_validator
 from turbo_alignment.common.registry import Registrable
 from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel
@@ -26,12 +27,18 @@ class MetricSettings(ExtraFieldsNotAllowedBaseModel):
 
 @MetricSettingsRegistry.register(MetricType.DIST_N)
 class DistinctnessSettings(MetricSettings):
-    ...
+    ngram: int = 5
+
+    @field_validator('ngram', mode='before')
+    def check_ngram(cls, value: int) -> int:
+        if value <= 0:
+            raise ValueError('ngram should be greater than 0')
+        return value
 
 
 @MetricSettingsRegistry.register(MetricType.DIVERSITY)
 class DiversitySettings(MetricSettings):
-    ...
+    top_k: int | None = None
 
 
 @MetricSettingsRegistry.register(MetricType.LENGTH)
@@ -76,6 +83,12 @@ class RougeSettings(MetricSettings):
 class SelfBleuSettings(MetricSettings):
     ngram: int = 3
 
+    @field_validator('ngram', mode='before')
+    def check_ngram(cls, value: int) -> int:
+        if value <= 0:
+            raise ValueError('ngram should be greater than 0')
+        return value
+
 
 @MetricSettingsRegistry.register(MetricType.TOOL_CALL_METRICS)
 class ToolMetricsSettings(MetricSettings):
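
Below is a minimal standalone sketch of the expectation-adjusted distinct (EAD) term that the new `distinctness` computes: the distinct n-gram count is divided by the number of distinct n-grams one would expect when drawing `total_ngrams` items uniformly from a vocabulary of size `vocab_size`, which corrects plain dist-n's bias against longer outputs. The helper name `ead_distinct` and the example values are illustrative, not part of the repo.

```python
# Hypothetical standalone check of the EAD denominator used in
# DistinctnessMetric.distinctness; names and numbers are made up.
def ead_distinct(num_distinct: int, total_ngrams: int, vocab_size: int) -> float:
    if total_ngrams == 0:
        return 0.0
    # Expected number of distinct n-grams when total_ngrams draws are made
    # uniformly at random from a vocabulary of vocab_size items.
    expected = vocab_size * (1 - ((vocab_size - 1) / vocab_size) ** total_ngrams)
    return num_distinct / expected


answers = ['the cat sat', 'the cat ran']
unigrams = [w for a in answers for w in a.split(' ')]
print(ead_distinct(len(set(unigrams)), len(unigrams), vocab_size=32000))  # ~0.667
```

For small n-gram totals the expectation is close to `total_ngrams` itself, so EAD and plain dist-n nearly coincide; the correction matters as `total_ngrams` grows toward `vocab_size`.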
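
The `top_k` option added to `DiversityMetric.token_entropy` truncates the token frequency distribution to its k most probable tokens before taking the Shannon entropy. Note that `scipy.stats.entropy` renormalizes its input to sum to 1, so the truncated tail mass is implicitly redistributed over the kept tokens. A rough sketch, with a hypothetical helper and made-up token ids standing in for real tokenizer output:

```python
# Sketch of top-k token entropy; token ids are invented for illustration.
from collections import Counter

from scipy.stats import entropy


def top_k_token_entropy(token_ids: list[int], top_k: int | None = None) -> float:
    counts = Counter(token_ids)
    probs = [count / len(token_ids) for count in counts.values()]
    if top_k is not None:
        # Keep only the k largest probabilities; scipy renormalizes them.
        probs = sorted(probs, reverse=True)[:top_k]
    return float(entropy(probs))


print(top_k_token_entropy([1, 2, 2, 3, 3, 3, 4], top_k=2))  # entropy of [0.6, 0.4]
```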