From 545e64afe6733309ea328f5f39f5c0939e6ad216 Mon Sep 17 00:00:00 2001 From: 01WarpDrive Date: Thu, 18 Jul 2024 16:10:15 +0800 Subject: [PATCH] Update metric.py The input parameters of bleu are optimized. Added the ability to evaluate English data. --- src/llamafactory/train/sft/metric.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py index 6932737904..160d900b07 100644 --- a/src/llamafactory/train/sft/metric.py +++ b/src/llamafactory/train/sft/metric.py @@ -22,6 +22,7 @@ import numpy as np import torch from transformers.utils import is_jieba_available, is_nltk_available +import re from ...extras.constants import IGNORE_INDEX from ...extras.misc import numpify @@ -101,6 +102,9 @@ def __post_init__(self): self._dump() def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: + # Check whether Chinese characters exist + is_chinese = lambda x='ddd':sum([1 if u'\u4e00' <= i <= u'\u9fff' else 0 for i in x])>0 + preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids) preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) @@ -110,9 +114,17 @@ def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) for pred, label in zip(decoded_preds, decoded_labels): - hypothesis = list(jieba.cut(pred)) - reference = list(jieba.cut(label)) - + if is_chinese(label):# Chinese, Remove special characters and space + pred = re.sub('[^\w]','',pred) + label = re.sub('[^\w]','',label) + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + else: # English, Remove special characters + pred = re.sub('[^\w ]','',pred) + label = re.sub('[^\w ]','',label) + hypothesis = pred.split() + reference = label.split() + if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} else: @@ -123,7 +135,7 @@ def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> for k, v in result.items(): self.score_dict[k].append(round(v["f"] * 100, 4)) - bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + bleu_score = sentence_bleu([reference], hypothesis, smoothing_function=SmoothingFunction().method3) self.score_dict["bleu-4"].append(round(bleu_score * 100, 4)) if compute_result: