diff --git a/deeppavlov/configs/generative_qa/sbersquad_fid.json b/deeppavlov/configs/generative_qa/sbersquad_fid.json new file mode 100755 index 0000000000..55308d61b3 --- /dev/null +++ b/deeppavlov/configs/generative_qa/sbersquad_fid.json @@ -0,0 +1,114 @@ +{ + "dataset_reader": { + "class_name": "json_reader", + "data_path": "{DATASET_PATH}/qa-ru-long-big-answers.json" + }, + "dataset_iterator": { + "class_name": "data_learning_iterator", + "seed": 228, + "shuffle": true + }, + "chainer": { + "in": ["question", "contexts"], + "in_y": ["gold_answers"], + "pipe": [ + { + "class_name": "fid_input_preprocessor", + "vocab_file": "{TRANSFORMER}", + "max_seq_length": 200, + "in": ["question", "contexts"], + "out": ["input_ids", "attention_mask"] + }, + { + "class_name": "fid_target_preprocessor", + "vocab_file": "{TRANSFORMER}", + "answer_maxlength" : 50, + "in": ["gold_answers"], + "out": ["target_ids"] + }, + { + "class_name": "torch_generative_qa_fid", + "pretrained_transformer": "{TRANSFORMER}", + "save_path": "{MODEL_PATH}/save", + "load_path": "{MODEL_PATH}/load", + "optimizer": "AdamW", + "optimizer_parameters": { + "lr": 3e-04, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-08 + }, + "learning_rate_drop_patience": 24, + "learning_rate_drop_div": 2, + "min_learning_rate": 1e-5, + "generate_max_length" : 50, + "in": ["input_ids", "attention_mask"], + "in_y": ["target_ids"], + "out": ["model_answer"] + } + ], + "out": ["model_answer"] + }, + "train": { + "show_examples": false, + "evaluation_targets": [ + "valid" + ], + "log_every_n_batches": 50, + "val_every_n_batches": 4000, + "batch_size": 15, + "validation_patience": 100, + "metrics": [ + { + "name": "squad_v1_f1", + "inputs": ["gold_answers", "model_answer"] + }, + { + "name": "sacrebleu", + "inputs": ["gold_answers", "model_answer"] + } + + ], + "class_name": "torch_trainer" + }, + "metadata": { + "variables": { + "ROOT_PATH": "/home/admin/.deeppavlov", + "DOWNLOADS_PATH": 
"{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", + "MODEL_PATH": "{MODELS_PATH}/generative_qa/fusion_in_decoder/sber_squad", + "TRANSFORMER": "{MODEL_PATH}/ruT5-base", + "DATASET_PATH": "{DOWNLOADS_PATH}/sber_squad" + }, + "download": [ + { + "url": "https://files.deeppavlov.ai/deeppavlov_data/generative_question_answering_new/qa-ru-long-big-answers.json", + "subdir": "{DATASET_PATH}" + }, + { + "url": "https://files.deeppavlov.ai/deeppavlov_data/generative_question_answering_new/ruT5-base/ruT5-base/config.json", + "subdir": "{MODEL_PATH}/ruT5-base" + }, + { + "url": "https://files.deeppavlov.ai/deeppavlov_data/generative_question_answering_new/ruT5-base/ruT5-base/pytorch_model.bin", + "subdir": "{MODEL_PATH}/ruT5-base" + }, + { + "url": "https://files.deeppavlov.ai/deeppavlov_data/generative_question_answering_new/ruT5-base/ruT5-base/spiece.model", + "subdir": "{MODEL_PATH}/ruT5-base" + }, + { + "url": "https://files.deeppavlov.ai/deeppavlov_data/generative_question_answering_new/ruT5-trained/pytorch_model.bin", + "subdir": "{MODEL_PATH}/load" + }, + { + "url": "https://files.deeppavlov.ai/deeppavlov_data/generative_question_answering_new/ruT5-trained/config.json", + "subdir": "{MODEL_PATH}/load" + }, + { + "url": "https://files.deeppavlov.ai/deeppavlov_data/generative_question_answering_new/ruT5-trained/optimizer.pth.tar", + "subdir": "{MODEL_PATH}/load" + } + ] + } +} diff --git a/deeppavlov/core/common/metrics_registry.json b/deeppavlov/core/common/metrics_registry.json index c1f1a6c7a0..a89924716e 100644 --- a/deeppavlov/core/common/metrics_registry.json +++ b/deeppavlov/core/common/metrics_registry.json @@ -31,6 +31,7 @@ "r@5": "deeppavlov.metrics.recall_at_k:r_at_5", "rank_response": "deeppavlov.models.ranking.metrics:rank_response", "roc_auc": "deeppavlov.metrics.roc_auc_score:roc_auc_score", + "sacrebleu": "deeppavlov.metrics.bleu:sacrebleu", "sets_accuracy": "deeppavlov.metrics.accuracy:sets_accuracy", "slots_accuracy": 
@register_metric('sacrebleu')
def sacrebleu(y_true: List[str], y_predicted: List[str]) -> float:
    """Compute a corpus-level BLEU score with the ``sacrebleu`` package.

    Every prediction is scored against exactly one reference answer.

    Args:
        y_true: reference answers, one string per sample.
        y_predicted: generated answers, aligned index-by-index with ``y_true``.

    Returns:
        The ``score`` attribute of the object returned by ``BLEU.corpus_score``,
        or ``0.0`` when there are no predictions to score.
    """
    # Imported locally so this metric does not rely on a module-level import of
    # the optional ``sacrebleu`` package.
    from sacrebleu.metrics import BLEU

    # A corpus score is undefined without hypotheses; report zero instead of
    # letting the scorer fail on an empty batch.
    if len(y_predicted) == 0:
        return 0.0

    # ``corpus_score`` takes a list of reference streams. With one reference per
    # sample that is a single stream holding all gold answers, so no numpy
    # transpose is needed to build it.
    references = [[str(answer) for answer in y_true]]
    return BLEU().corpus_score(y_predicted, references).score
("generative_qa/sbersquad_fid.json", "sbersquad_fid", ('IP', 'TI')): [TWO_ARGUMENTS_INFER_CHECK], + }, "odqa": { ("odqa/en_odqa_infer_wiki.json", "odqa", ('IP',)): [ONE_ARGUMENT_INFER_CHECK], ("odqa/ru_odqa_infer_wiki.json", "odqa", ('IP',)): [ONE_ARGUMENT_INFER_CHECK],