From 3e7d488c35df4df873a19e057233cbbf67f6608b Mon Sep 17 00:00:00 2001 From: Maksim Savkin Date: Wed, 17 Aug 2022 17:38:55 +0300 Subject: [PATCH 1/9] add fusion in decoder --- deeppavlov/configs/kbqa/nq_fid.json | 98 +++++ deeppavlov/core/common/metrics_registry.json | 1 + deeppavlov/core/common/registry.json | 3 + .../core/common/requirements_registry.json | 9 + deeppavlov/dataset_readers/json_reader.py | 30 ++ deeppavlov/metrics/bleu.py | 13 + deeppavlov/models/kbqa/fusion_in_decoder.py | 353 ++++++++++++++++++ .../torch_transformers_preprocessor.py | 62 +++ .../models/torch_bert/torch_generative_qa.py | 200 ++++++++++ deeppavlov/requirements/sacrebleu.txt | 1 + 10 files changed, 770 insertions(+) create mode 100644 deeppavlov/configs/kbqa/nq_fid.json create mode 100644 deeppavlov/dataset_readers/json_reader.py create mode 100644 deeppavlov/models/kbqa/fusion_in_decoder.py create mode 100644 deeppavlov/models/torch_bert/torch_generative_qa.py create mode 100644 deeppavlov/requirements/sacrebleu.txt diff --git a/deeppavlov/configs/kbqa/nq_fid.json b/deeppavlov/configs/kbqa/nq_fid.json new file mode 100644 index 0000000000..7e0c6726cf --- /dev/null +++ b/deeppavlov/configs/kbqa/nq_fid.json @@ -0,0 +1,98 @@ +{ + "dataset_reader": { + "class_name": "json_reader", + "valid_size": 1000, + "data_path": "{DATASET_PATH}/natural_questions_dataset.json" + }, + "dataset_iterator": { + "class_name": "data_learning_iterator" + }, + "chainer": { + "in": ["question", "contexts"], + "in_y": ["target", "gold_answers"], + "pipe": [ + { + "class_name": "torch_transformers_fid_preprocessor", + "vocab_file": "{TRANSFORMER}", + "max_seq_length": 512, + "answer_maxlength" : 20, + "in": ["question", "contexts", "target"], + "out": ["input_ids", "attention_mask", "target_ids"] + }, + { + "class_name": "torch_generative_qa_fid", + "pretrained_transformer": "{TRANSFORMER}", + "save_path": "{MODEL_PATH}", + "load_path": "{MODEL_PATH}", + "optimizer": "AdamW", + "optimizer_parameters": { + "lr": 3e-04, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-08 + }, + "learning_rate_drop_patience": 8, + "learning_rate_drop_div": 2, + "min_learning_rate": 1e-5, + "generate_max_length" : 20, + "in": ["input_ids", "attention_mask"], + "in_y": ["target_ids"], + "out": ["model_answer"] + } + ], + "out": ["model_answer"] + }, + "train": { + "show_examples": false, + "evaluation_targets": [ + "valid" + ], + "log_every_n_batches": 100, + "val_every_n_batches": 600, + "batch_size": 8, + "validation_patience": 60, + "metrics": [ + { + "name": "squad_v2_em", + "inputs": ["gold_answers", "model_answer"] + }, + { + "name": "squad_v2_f1", + "inputs": ["gold_answers", "model_answer"] + }, + { + "name": "sacrebleu", + "inputs": ["gold_answers", "model_answer"] + } + ], + "class_name": "torch_trainer" + }, + "metadata": { + "variables": { + "TRANSFORMER": "t5-base", + "ROOT_PATH": "~/.deeppavlov", + "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", + "MODEL_PATH": "{MODELS_PATH}/generative_qa/fusion_in_decoder/natural_questions", + "DATASET_PATH": "{DOWNLOADS_PATH}/natural_questions" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/kbqa/datasets/natural_questions/natural_questions_dataset.json", + "subdir": "{DATASET_PATH}" + }, + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/kbqa/models/generative_qa/fusion_in_decoder/natural_questions/config.json", + "subdir": "{MODEL_PATH}" + }, + { + "url": 
"http://files.deeppavlov.ai/deeppavlov_data/kbqa/models/generative_qa/fusion_in_decoder/natural_questions/pytorch_model.bin", + "subdir": "{MODEL_PATH}" + }, + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/kbqa/models/generative_qa/fusion_in_decoder/natural_questions/optimizer.pth.tar", + "subdir": "{MODEL_PATH}" + } + ] + } +} diff --git a/deeppavlov/core/common/metrics_registry.json b/deeppavlov/core/common/metrics_registry.json index c1f1a6c7a0..79d0364cc8 100644 --- a/deeppavlov/core/common/metrics_registry.json +++ b/deeppavlov/core/common/metrics_registry.json @@ -31,6 +31,7 @@ "r@5": "deeppavlov.metrics.recall_at_k:r_at_5", "rank_response": "deeppavlov.models.ranking.metrics:rank_response", "roc_auc": "deeppavlov.metrics.roc_auc_score:roc_auc_score", + "sacrebleu": "deeppavlov.metrics.bleu:sacrebleu", "sets_accuracy": "deeppavlov.metrics.accuracy:sets_accuracy", "slots_accuracy": "deeppavlov.metrics.accuracy:slots_accuracy", "spearman_correlation": "deeppavlov.metrics.correlation:spearman_correlation", diff --git a/deeppavlov/core/common/registry.json b/deeppavlov/core/common/registry.json index 42f0df484e..6edd96ad2b 100644 --- a/deeppavlov/core/common/registry.json +++ b/deeppavlov/core/common/registry.json @@ -21,6 +21,7 @@ "huggingface_dataset_iterator": "deeppavlov.dataset_iterators.huggingface_dataset_iterator:HuggingFaceDatasetIterator", "huggingface_dataset_reader": "deeppavlov.dataset_readers.huggingface_dataset_reader:HuggingFaceDatasetReader", "imdb_reader": "deeppavlov.dataset_readers.imdb_reader:ImdbReader", + "json_reader": "deeppavlov.dataset_readers.json_reader:JsonReader", "kenlm_elector": "deeppavlov.models.spelling_correction.electors.kenlm_elector:KenlmElector", "line_reader": "deeppavlov.dataset_readers.line_reader:LineReader", "logit_ranker": "deeppavlov.models.doc_retrieval.logit_ranker:LogitRanker", @@ -81,6 +82,7 @@ "top1_elector": "deeppavlov.models.spelling_correction.electors.top1_elector:TopOneElector", "torch_bert_ranker": "deeppavlov.models.torch_bert.torch_bert_ranker:TorchBertRankerModel", "torch_bert_ranker_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchBertRankerPreprocessor", + "torch_generative_qa_fid": "deeppavlov.models.torch_bert.torch_generative_qa:TorchFiD", "torch_record_postprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchRecordPostprocessor", "torch_squad_transformers_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchSquadTransformersPreprocessor", "torch_text_classification_model": "deeppavlov.models.classifiers.torch_classification_model:TorchTextClassificationModel", @@ -89,6 +91,7 @@ "torch_transformers_el_ranker": "deeppavlov.models.torch_bert.torch_transformers_el_ranker:TorchTransformersElRanker", "torch_transformers_entity_ranker_infer": "deeppavlov.models.torch_bert.torch_transformers_el_ranker:TorchTransformersEntityRankerInfer", "torch_transformers_entity_ranker_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersEntityRankerPreprocessor", + "torch_transformers_fid_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersFiDPreprocessor", "torch_transformers_multiplechoice": "deeppavlov.models.torch_bert.torch_transformers_multiplechoice:TorchTransformersMultiplechoiceModel", "torch_transformers_multiplechoice_preprocessor": 
"deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersMultiplechoicePreprocessor", "torch_transformers_ner_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersNerPreprocessor", diff --git a/deeppavlov/core/common/requirements_registry.json b/deeppavlov/core/common/requirements_registry.json index d65eba771e..f29d8e7319 100644 --- a/deeppavlov/core/common/requirements_registry.json +++ b/deeppavlov/core/common/requirements_registry.json @@ -86,6 +86,11 @@ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" ], + "torch_generative_qa_fid": [ + "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", + "{DEEPPAVLOV_PATH}/requirements/transformers.txt", + "{DEEPPAVLOV_PATH}/requirements/sacrebleu.txt" + ], "torch_record_postprocessor": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" @@ -113,6 +118,10 @@ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" ], + "torch_transformers_fid_preprocessor": [ + "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", + "{DEEPPAVLOV_PATH}/requirements/transformers.txt" + ], "torch_transformers_multiplechoice": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" diff --git a/deeppavlov/dataset_readers/json_reader.py b/deeppavlov/dataset_readers/json_reader.py new file mode 100644 index 0000000000..c38418e3f6 --- /dev/null +++ b/deeppavlov/dataset_readers/json_reader.py @@ -0,0 +1,30 @@ +# Copyright 2017 Neural Networks and Deep Learning lab, MIPT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from typing import Dict, Optional + +from deeppavlov.core.common.registry import register +from deeppavlov.core.data.dataset_reader import DatasetReader + +@register('json_reader') +class JsonReader(DatasetReader): + + def read(self, data_path: str, valid_size: Optional[int] = None) -> Dict: + with open(data_path, 'r') as f: + dataset = json.load(f) + if valid_size is not None: + dataset["valid"] = dataset["valid"][:valid_size] + + return dataset diff --git a/deeppavlov/metrics/bleu.py b/deeppavlov/metrics/bleu.py index 75bfec2b79..ddbfbf6846 100644 --- a/deeppavlov/metrics/bleu.py +++ b/deeppavlov/metrics/bleu.py @@ -19,6 +19,8 @@ from deeppavlov.core.common.metrics_registry import register_metric from deeppavlov.metrics.google_bleu import compute_bleu +from sacrebleu.metrics import BLEU +import numpy as np SMOOTH = SmoothingFunction() @@ -79,3 +81,14 @@ def per_item_dialog_bleu(y_true, y_predicted): y_true = (y['text'] for dialog in y_true for y in dialog) return corpus_bleu([[y_t.lower().split()] for y_t in y_true], [y.lower().split() for y_p in y_predicted for y in y_p]) + +@register_metric('sacrebleu') +def sacrebleu(y_true: List[str], y_predicted: List[str]) -> float: + y_true_padded = [] + max_answers_cnt = max(len(answers) for answers in y_true) + for answers in y_true: + y_true_padded.append(answers + [''] * (max_answers_cnt - len(answers))) + y_true = np.transpose(y_true_padded).tolist() + + bleu = BLEU() + return bleu.corpus_score(y_predicted, y_true).score \ No newline at end of file diff --git a/deeppavlov/models/kbqa/fusion_in_decoder.py b/deeppavlov/models/kbqa/fusion_in_decoder.py new file mode 100644 index 0000000000..7f17dcd6d3 --- /dev/null +++ b/deeppavlov/models/kbqa/fusion_in_decoder.py @@ -0,0 +1,353 @@ +import types +import torch +import transformers +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss +import numpy as np + + +class FiDT5(transformers.T5ForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + self.wrap_encoder() + + def forward_(self, **kwargs): + if 'input_ids' in kwargs: + kwargs['input_ids'] = kwargs['input_ids'].view(kwargs['input_ids'].size(0), -1) + if 'attention_mask' in kwargs: + kwargs['attention_mask'] = kwargs['attention_mask'].view(kwargs['attention_mask'].size(0), -1) + + return super(FiDT5, self).forward( + **kwargs + ) + + # We need to resize as B x (N * L) instead of (B * N) x L here + # because the T5 forward method uses the input tensors to infer + # dimensions used in the decoder. + # EncoderWrapper resizes the inputs as (B * N) x L. 
+ def forward(self, input_ids=None, attention_mask=None, **kwargs): + if input_ids != None: + # inputs might have already be resized in the generate method + if input_ids.dim() == 3: + self.encoder.n_passages = input_ids.size(1) + input_ids = input_ids.view(input_ids.size(0), -1) + if attention_mask != None: + attention_mask = attention_mask.view(attention_mask.size(0), -1) + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs + ) + + # We need to resize the inputs here, as the generate method expect 2D tensors + def generate(self, input_ids, attention_mask, max_length): + self.encoder.n_passages = input_ids.size(1) + return super().generate( + input_ids=input_ids.view(input_ids.size(0), -1), + attention_mask=attention_mask.view(attention_mask.size(0), -1), + max_length=max_length + ) + + def wrap_encoder(self, use_checkpoint=False): + """ + Wrap T5 encoder to obtain a Fusion-in-Decoder model. + """ + self.encoder = EncoderWrapper(self.encoder, use_checkpoint=use_checkpoint) + + def unwrap_encoder(self): + """ + Unwrap Fusion-in-Decoder encoder, useful to load T5 weights. + """ + self.encoder = self.encoder.encoder + block = [] + for mod in self.encoder.block: + block.append(mod.module) + block = nn.ModuleList(block) + self.encoder.block = block + + def load_t5(self, state_dict): + self.unwrap_encoder() + self.load_state_dict(state_dict) + self.wrap_encoder() + + def set_checkpoint(self, use_checkpoint): + """ + Enable or disable checkpointing in the encoder. + See https://pytorch.org/docs/stable/checkpoint.html + """ + for mod in self.encoder.encoder.block: + mod.use_checkpoint = use_checkpoint + + def reset_score_storage(self): + """ + Reset score storage, only used when cross-attention scores are saved + to train a retriever. + """ + for mod in self.decoder.block: + mod.layer[1].EncDecAttention.score_storage = None + + def get_crossattention_scores(self, context_mask): + """ + Cross-attention scores are aggregated to obtain a single scalar per + passage. This scalar can be seen as a similarity score between the + question and the input passage. It is obtained by averaging the + cross-attention scores obtained on the first decoded token over heads, + layers, and tokens of the input passage. + + More details in Distilling Knowledge from Reader to Retriever: + https://arxiv.org/abs/2012.04584. + """ + scores = [] + n_passages = context_mask.size(1) + for mod in self.decoder.block: + scores.append(mod.layer[1].EncDecAttention.score_storage) + scores = torch.cat(scores, dim=2) + bsz, n_heads, n_layers, _ = scores.size() + # batch_size, n_head, n_layers, n_passages, text_maxlength + scores = scores.view(bsz, n_heads, n_layers, n_passages, -1) + scores = scores.masked_fill(~context_mask[:, None, None], 0.) + scores = scores.sum(dim=[1, 2, 4]) + ntokens = context_mask.sum(dim=[2]) * n_layers * n_heads + scores = scores/ntokens + return scores + + def overwrite_forward_crossattention(self): + """ + Replace cross-attention forward function, only used to save + cross-attention scores. + """ + for mod in self.decoder.block: + attn = mod.layer[1].EncDecAttention + attn.forward = types.MethodType(cross_attention_forward, attn) + +class EncoderWrapper(torch.nn.Module): + """ + Encoder Wrapper for T5 Wrapper to obtain a Fusion-in-Decoder model. 
+ """ + def __init__(self, encoder, use_checkpoint=False): + super().__init__() + + self.encoder = encoder + apply_checkpoint_wrapper(self.encoder, use_checkpoint) + + def forward(self, input_ids=None, attention_mask=None, **kwargs,): + # total_length = n_passages * passage_length + bsz, total_length = input_ids.shape + passage_length = total_length // self.n_passages + input_ids = input_ids.view(bsz*self.n_passages, passage_length) + attention_mask = attention_mask.view(bsz*self.n_passages, passage_length) + outputs = self.encoder(input_ids, attention_mask, **kwargs) + outputs = (outputs[0].view(bsz, self.n_passages*passage_length, -1), ) + outputs[1:] + return outputs + +class CheckpointWrapper(torch.nn.Module): + """ + Wrapper replacing None outputs by empty tensors, which allows the use of + checkpointing. + """ + def __init__(self, module, use_checkpoint=False): + super().__init__() + self.module = module + self.use_checkpoint = use_checkpoint + + def forward(self, hidden_states, attention_mask, position_bias, **kwargs): + if self.use_checkpoint and self.training: + kwargs = {k: v for k, v in kwargs.items() if v is not None} + def custom_forward(*inputs): + output = self.module(*inputs, **kwargs) + empty = torch.tensor( + [], + dtype=torch.float, + device=output[0].device, + requires_grad=True) + output = tuple(x if x is not None else empty for x in output) + return output + + output = torch.utils.checkpoint.checkpoint( + custom_forward, + hidden_states, + attention_mask, + position_bias + ) + output = tuple(x if x.size() != 0 else None for x in output) + else: + output = self.module(hidden_states, attention_mask, position_bias, **kwargs) + return output + +def apply_checkpoint_wrapper(t5stack, use_checkpoint): + """ + Wrap each block of the encoder to enable checkpointing. 
+ """ + block = [] + for mod in t5stack.block: + wrapped_mod = CheckpointWrapper(mod, use_checkpoint) + block.append(wrapped_mod) + block = nn.ModuleList(block) + t5stack.block = block + +def cross_attention_forward( + self, + input, + mask=None, + kv=None, + position_bias=None, + past_key_value_state=None, + head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + This only works for computing cross attention over the input + """ + assert(kv != None) + assert(head_mask == None) + assert(position_bias != None or self.has_relative_attention_bias) + + bsz, qlen, dim = input.size() + n_heads, d_heads = self.n_heads, self.d_kv + klen = kv.size(1) + + q = self.q(input).view(bsz, -1, n_heads, d_heads).transpose(1, 2) + if past_key_value_state == None: + k = self.k(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2) + v = self.v(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2) + else: + k, v = past_key_value_state + + scores = torch.einsum("bnqd,bnkd->bnqk", q, k) + + if mask is not None: + scores += mask + + if position_bias is None: + position_bias = self.compute_bias(qlen, klen) + scores += position_bias + + if self.score_storage is None: + self.score_storage = scores + + attn = F.softmax(scores.float(), dim=-1).type_as(scores) + attn = F.dropout(attn, p=self.dropout, training=self.training) + + output = torch.matmul(attn, v) + output = output.transpose(1, 2).contiguous().view(bsz, -1, self.inner_dim) + output = self.o(output) + + if use_cache: + output = (output,) + ((k, v),) + else: + output = (output,) + (None,) + + if output_attentions: + output = output + (attn,) + + if self.has_relative_attention_bias: + output = output + (position_bias,) + + return output + +class RetrieverConfig(transformers.BertConfig): + + def __init__(self, + indexing_dimension=768, + apply_question_mask=False, + apply_passage_mask=False, + extract_cls=False, + passage_maxlength=200, + question_maxlength=40, + projection=True, + **kwargs): + super().__init__(**kwargs) + self.indexing_dimension = indexing_dimension + self.apply_question_mask = apply_question_mask + self.apply_passage_mask = apply_passage_mask + self.extract_cls=extract_cls + self.passage_maxlength = passage_maxlength + self.question_maxlength = question_maxlength + self.projection = projection + +class Retriever(transformers.PreTrainedModel): + + config_class = RetrieverConfig + base_model_prefix = "retriever" + + def __init__(self, config, initialize_wBERT=False): + super().__init__(config) + assert config.projection or config.indexing_dimension == 768, \ + 'If no projection then indexing dimension must be equal to 768' + self.config = config + if initialize_wBERT: + self.model = transformers.BertModel.from_pretrained('bert-base-uncased') + else: + self.model = transformers.BertModel(config) + if self.config.projection: + self.proj = nn.Linear( + self.model.config.hidden_size, + self.config.indexing_dimension + ) + self.norm = nn.LayerNorm(self.config.indexing_dimension) + self.loss_fct = torch.nn.KLDivLoss() + + def forward(self, + question_ids, + question_mask, + passage_ids, + passage_mask, + gold_score=None): + question_output = self.embed_text( + text_ids=question_ids, + text_mask=question_mask, + apply_mask=self.config.apply_question_mask, + extract_cls=self.config.extract_cls, + ) + bsz, n_passages, plen = passage_ids.size() + passage_ids = passage_ids.view(bsz * n_passages, plen) + passage_mask = passage_mask.view(bsz * n_passages, plen) + passage_output = self.embed_text( + text_ids=passage_ids, + 
text_mask=passage_mask, + apply_mask=self.config.apply_passage_mask, + extract_cls=self.config.extract_cls, + ) + + score = torch.einsum( + 'bd,bid->bi', + question_output, + passage_output.view(bsz, n_passages, -1) + ) + score = score / np.sqrt(question_output.size(-1)) + if gold_score is not None: + loss = self.kldivloss(score, gold_score) + else: + loss = None + + return question_output, passage_output, score, loss + + def embed_text(self, text_ids, text_mask, apply_mask=False, extract_cls=False): + text_output = self.model( + input_ids=text_ids, + attention_mask=text_mask if apply_mask else None + ) + if type(text_output) is not tuple: + text_output.to_tuple() + text_output = text_output[0] + if self.config.projection: + text_output = self.proj(text_output) + text_output = self.norm(text_output) + + if extract_cls: + text_output = text_output[:, 0] + else: + if apply_mask: + text_output = text_output.masked_fill(~text_mask[:, :, None], 0.) + text_output = torch.sum(text_output, dim=1) / torch.sum(text_mask, dim=1)[:, None] + else: + text_output = torch.mean(text_output, dim=1) + return text_output + + def kldivloss(self, score, gold_score): + gold_score = torch.softmax(gold_score, dim=-1) + score = torch.nn.functional.log_softmax(score, dim=-1) + return self.loss_fct(score, gold_score) diff --git a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py index 8bc2daec34..7e79441730 100644 --- a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py +++ b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py @@ -421,6 +421,68 @@ def __call__(self, questions_batch: List[List[str]], rels_batch: List[List[str]] return input_features +@register('torch_transformers_fid_preprocessor') +class TorchTransformersFiDPreprocessor(Component): + def __init__(self, + vocab_file: str, + do_lower_case: bool = True, + max_seq_length: int = 512, + return_tokens: bool = False, + answer_maxlength: int = 20, + **kwargs) -> None: + self.max_seq_length = max_seq_length + self.return_tokens = return_tokens + self.answer_maxlength = answer_maxlength + + if Path(vocab_file).is_file(): + vocab_file = str(expand_path(vocab_file)) + self.tokenizer = AutoTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) + else: + self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case) + + + def encode_question_passages(self, passages_batch: List[List[str]]): + passage_ids, passage_masks = [], [] + for text_passages in passages_batch: + passages_encoding = self.tokenizer( + text_passages, + max_length=self.max_seq_length if self.max_seq_length > 0 else None, + pad_to_max_length=True, + return_tensors='pt', + truncation=True if self.max_seq_length > 0 else False, + ) + passage_ids.append(passages_encoding['input_ids'][None]) + passage_masks.append(passages_encoding['attention_mask'][None]) + + passage_ids = torch.cat(passage_ids, dim=0) + passage_masks = torch.cat(passage_masks, dim=0) + return passage_ids, passage_masks + + def encode_targets(self, targets_batch: List[str] = None): + target_ids, target_masks = None, None + if targets_batch is not None: + target_encoding = self.tokenizer( + targets_batch, + max_length=self.answer_maxlength if self.answer_maxlength > 0 else None, + pad_to_max_length=True, + return_tensors='pt', + truncation=True if self.answer_maxlength > 0 else False, + ) + target_ids = target_encoding["input_ids"] + target_mask = target_encoding["attention_mask"].bool() + 
return target_ids, target_masks + + def __call__(self, questions_batch: List[str], passages_batch: List[List[str]], targets_batch: List[str] = None): + prepare_data = lambda q, c: f"question: {q}, context: {c}" + question_passages_batch = [[prepare_data(question, passage) for passage in text_passages] + for (question, text_passages) in zip(questions_batch, passages_batch)] + + question_passages_ids, question_passage_masks = self.encode_question_passages(question_passages_batch) + + target_ids, target_masks = self.encode_targets(targets_batch) + + return question_passages_ids, question_passage_masks, target_ids + @register('torch_transformers_ner_preprocessor') class TorchTransformersNerPreprocessor(Component): """ diff --git a/deeppavlov/models/torch_bert/torch_generative_qa.py b/deeppavlov/models/torch_bert/torch_generative_qa.py new file mode 100644 index 0000000000..24cce40548 --- /dev/null +++ b/deeppavlov/models/torch_bert/torch_generative_qa.py @@ -0,0 +1,200 @@ +# Copyright 2017 Neural Networks and Deep Learning lab, MIPT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from logging import getLogger +from pathlib import Path +from typing import List, Optional, Dict + +import torch +from overrides import overrides +from transformers import T5ForConditionalGeneration, T5Tokenizer + +from deeppavlov.core.common.registry import register +from deeppavlov.core.models.torch_model import TorchModel + + +from deeppavlov.models.kbqa.fusion_in_decoder import FiDT5 + + +logger = getLogger(__name__) + + +@register('torch_generative_qa_fid') +class TorchFiD(TorchModel): + def __init__(self, + pretrained_transformer: str = "t5-base", + attention_probs_keep_prob: Optional[float] = None, + hidden_keep_prob: Optional[float] = None, + optimizer: str = "AdamW", + optimizer_parameters: Optional[dict] = None, + bert_config_file: Optional[str] = None, + learning_rate_drop_patience: int = 20, + learning_rate_drop_div: float = 2.0, + load_before_drop: bool = True, + clip_norm: Optional[float] = None, + min_learning_rate: float = 1e-06, + generate_max_length: int = 20, + **kwargs): + + if not optimizer_parameters: + optimizer_parameters = {"lr": 0.01, + "weight_decay": 0.01, + "betas": (0.9, 0.999), + "eps": 1e-6} + self.generate_max_length = generate_max_length + + self.attention_probs_keep_prob = attention_probs_keep_prob + self.hidden_keep_prob = hidden_keep_prob + self.clip_norm = clip_norm + + self.pretrained_transformer = pretrained_transformer + self.bert_config_file = bert_config_file + self.tokenizer = T5Tokenizer.from_pretrained(self.pretrained_transformer, return_dict=False) + + super().__init__(optimizer=optimizer, + optimizer_parameters=optimizer_parameters, + learning_rate_drop_patience=learning_rate_drop_patience, + learning_rate_drop_div=learning_rate_drop_div, + load_before_drop=load_before_drop, + min_learning_rate=min_learning_rate, + **kwargs) + + def train_on_batch(self, + input_ids_batch: List[List[float]], + attention_mask_batch: List[List[float]], + target_ids_batch: 
List[List[float]]) -> Dict: + input_ids_batch = torch.LongTensor(input_ids_batch).to(self.device) + attention_mask_batch = torch.LongTensor(attention_mask_batch).to(self.device) + target_ids_batch = torch.LongTensor(target_ids_batch).to(self.device) + + input_ = { + 'input_ids': input_ids_batch, + 'attention_mask': attention_mask_batch, + 'labels': target_ids_batch + } + + self.optimizer.zero_grad() + + loss = self.model(**input_)[0] + if self.is_data_parallel: + loss = loss.mean() + loss.backward() + + # Clip the norm of the gradients to 1.0. + # This is to help prevent the "exploding gradients" problem. + if self.clip_norm: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.clip_norm) + + self.optimizer.step() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + return {'loss': loss.item()} + + @property + def is_data_parallel(self) -> bool: + return isinstance(self.model, torch.nn.DataParallel) + + def __call__(self, input_ids_batch: List[List[float]], attention_mask_batch: List[List[float]]) -> List[str]: + input_ids_batch = torch.LongTensor(input_ids_batch).to(self.device) + attention_mask_batch = torch.LongTensor(attention_mask_batch).to(self.device) + input_ = { + 'input_ids': input_ids_batch, + 'attention_mask': attention_mask_batch, + } + + model = self.model.module if hasattr(self.model, "module") else self.model + with torch.no_grad(): + answer_ids_batch = model.generate(max_length=self.generate_max_length, **input_) + + answers_batch = self.tokenizer.batch_decode(answer_ids_batch, skip_special_tokens=True) + return answers_batch + + @overrides + def save(self, fname: Optional[str] = None): + if fname is None: + fname = self.save_path + os.makedirs(fname, exist_ok=True) + logger.info(f"Saving checkpoint to {fname}.") + + # Save model + model_dir_path = fname + model_to_save = self.model.module if hasattr(self.model, "module") else self.model + model_to_save.save_pretrained(model_dir_path) + + # Save optimizer and scheduler + optimizer_path = os.path.join(fname, "optimizer.pth.tar") + optimizer_state = { + "optimizer": self.optimizer.state_dict() + } + torch.save(optimizer_state, optimizer_path) + + + def init_optimizer_from_scratch(self) -> None: + self.optimizer = getattr(torch.optim, self.optimizer_name)( + self.model.parameters(), **self.optimizer_parameters) + + if self.lr_scheduler_name is not None: + self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)( + self.optimizer, **self.lr_scheduler_parameters) + + if self.opt.get("criterion", None): + self.criterion = getattr(torch.nn, self.opt.get("criterion", None))() + + def init_from_scratch(self) -> None: + logger.info(f"From pretrained {self.pretrained_transformer}.") + self.tokenizer = T5Tokenizer.from_pretrained(self.pretrained_transformer, return_dict=False) + t5 = T5ForConditionalGeneration.from_pretrained(self.pretrained_transformer) + + self.model = FiDT5(t5.config) + self.model.load_t5(t5.state_dict()) + self.model.to(self.device) + + self.init_optimizer_from_scratch() + + + def load_from_checkpoint(self, model_dir_path: str, optimizer_path: str) -> None: + logger.info(f"Loading model from {model_dir_path}.") + self.model = FiDT5.from_pretrained(model_dir_path) + self.model = self.model.to(self.device) + + logger.info(f"Loading optimizer from {optimizer_path}.") + self.init_optimizer_from_scratch() + optimizer_state = torch.load(optimizer_path, map_location=self.device) + self.optimizer.load_state_dict(optimizer_state["optimizer"]) + + @overrides + def load(self, 
fname: Optional[str] = None) -> None: + if fname is not None: + self.load_path = fname + + # Loading weights from checkpoint + if self.load_path is not None: + logger.info(f"Load path {self.load_path} is given.") + model_dir_path = self.load_path + optimizer_path = os.path.join(self.load_path, "optimizer.pth.tar") + + if Path(model_dir_path).exists() and Path(optimizer_path).exists(): + self.load_from_checkpoint(model_dir_path, optimizer_path) + else: + self.init_from_scratch() + logger.info(f"Init from scratch. Model_path: {model_dir_path} or optimizer_path: {optimizer_path} does not exist.") + else: + self.init_from_scratch() + logger.info(f"Init from scratch. Load path {self.load_path} does not exist.") + + if self.device.type == "cuda" and torch.cuda.device_count() > 1: + self.model = torch.nn.DataParallel(self.model) \ No newline at end of file diff --git a/deeppavlov/requirements/sacrebleu.txt b/deeppavlov/requirements/sacrebleu.txt new file mode 100644 index 0000000000..b71be058a6 --- /dev/null +++ b/deeppavlov/requirements/sacrebleu.txt @@ -0,0 +1 @@ +sacrebleu==2.1.0 \ No newline at end of file From 699eadfa5b1d9da2223fe262c79ed2037ed9dfba Mon Sep 17 00:00:00 2001 From: Maksim Savkin Date: Wed, 24 Aug 2022 10:43:11 +0300 Subject: [PATCH 2/9] fix model loading --- .../torch_transformers_preprocessor.py | 17 +++--- .../models/torch_bert/torch_generative_qa.py | 55 ++++++++++--------- 2 files changed, 39 insertions(+), 33 deletions(-) diff --git a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py index 7e79441730..05c5f911b7 100644 --- a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py +++ b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py @@ -441,7 +441,7 @@ def __init__(self, self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case) - def encode_question_passages(self, passages_batch: List[List[str]]): + def encode_passages(self, passages_batch: List[List[str]]): passage_ids, passage_masks = [], [] for text_passages in passages_batch: passages_encoding = self.tokenizer( @@ -472,16 +472,19 @@ def encode_targets(self, targets_batch: List[str] = None): target_mask = target_encoding["attention_mask"].bool() return target_ids, target_masks - def __call__(self, questions_batch: List[str], passages_batch: List[List[str]], targets_batch: List[str] = None): - prepare_data = lambda q, c: f"question: {q}, context: {c}" - question_passages_batch = [[prepare_data(question, passage) for passage in text_passages] - for (question, text_passages) in zip(questions_batch, passages_batch)] + def __call__(self, + questions_batch: List[str], + contexts_batch: List[List[str]], + targets_batch: List[str] = None): + prepare_data = lambda q, c,: f"question: {q} context: {c}" + passages_batch = [[prepare_data(question, context) for context in contexts] + for (question, contexts) in zip(questions_batch, contexts_batch)] - question_passages_ids, question_passage_masks = self.encode_question_passages(question_passages_batch) + passage_ids, passage_masks = self.encode_passages(passages_batch) target_ids, target_masks = self.encode_targets(targets_batch) - return question_passages_ids, question_passage_masks, target_ids + return passage_ids, passage_masks, target_ids @register('torch_transformers_ner_preprocessor') class TorchTransformersNerPreprocessor(Component): diff --git a/deeppavlov/models/torch_bert/torch_generative_qa.py 
b/deeppavlov/models/torch_bert/torch_generative_qa.py index 24cce40548..fc0bd71f82 100644 --- a/deeppavlov/models/torch_bert/torch_generative_qa.py +++ b/deeppavlov/models/torch_bert/torch_generative_qa.py @@ -45,8 +45,8 @@ def __init__(self, load_before_drop: bool = True, clip_norm: Optional[float] = None, min_learning_rate: float = 1e-06, - generate_max_length: int = 20, - **kwargs): + generate_max_length: int = 50, + **kwargs) -> None: if not optimizer_parameters: optimizer_parameters = {"lr": 0.01, @@ -71,10 +71,7 @@ def __init__(self, min_learning_rate=min_learning_rate, **kwargs) - def train_on_batch(self, - input_ids_batch: List[List[float]], - attention_mask_batch: List[List[float]], - target_ids_batch: List[List[float]]) -> Dict: + def train_on_batch(self, input_ids_batch, attention_mask_batch, target_ids_batch) -> Dict: input_ids_batch = torch.LongTensor(input_ids_batch).to(self.device) attention_mask_batch = torch.LongTensor(attention_mask_batch).to(self.device) target_ids_batch = torch.LongTensor(target_ids_batch).to(self.device) @@ -101,14 +98,14 @@ def train_on_batch(self, self.optimizer.step() if self.lr_scheduler is not None: self.lr_scheduler.step() - + return {'loss': loss.item()} @property def is_data_parallel(self) -> bool: return isinstance(self.model, torch.nn.DataParallel) - def __call__(self, input_ids_batch: List[List[float]], attention_mask_batch: List[List[float]]) -> List[str]: + def __call__(self, input_ids_batch, attention_mask_batch): input_ids_batch = torch.LongTensor(input_ids_batch).to(self.device) attention_mask_batch = torch.LongTensor(attention_mask_batch).to(self.device) input_ = { @@ -120,11 +117,11 @@ def __call__(self, input_ids_batch: List[List[float]], attention_mask_batch: Lis with torch.no_grad(): answer_ids_batch = model.generate(max_length=self.generate_max_length, **input_) - answers_batch = self.tokenizer.batch_decode(answer_ids_batch, skip_special_tokens=True) + answers_batch = self.tokenizer.batch_decode(answer_ids_batch, skip_special_tokens=True) return answers_batch @overrides - def save(self, fname: Optional[str] = None): + def save(self, fname: Optional[str] = None, *args, **kwargs): if fname is None: fname = self.save_path os.makedirs(fname, exist_ok=True) @@ -136,7 +133,7 @@ def save(self, fname: Optional[str] = None): model_to_save.save_pretrained(model_dir_path) # Save optimizer and scheduler - optimizer_path = os.path.join(fname, "optimizer.pth.tar") + optimizer_path = str(Path(fname, "optimizer.pth.tar").resolve()) optimizer_state = { "optimizer": self.optimizer.state_dict() } @@ -154,47 +151,53 @@ def init_optimizer_from_scratch(self) -> None: if self.opt.get("criterion", None): self.criterion = getattr(torch.nn, self.opt.get("criterion", None))() - def init_from_scratch(self) -> None: + def init_model_from_scratch(self) -> None: logger.info(f"From pretrained {self.pretrained_transformer}.") self.tokenizer = T5Tokenizer.from_pretrained(self.pretrained_transformer, return_dict=False) t5 = T5ForConditionalGeneration.from_pretrained(self.pretrained_transformer) self.model = FiDT5(t5.config) self.model.load_t5(t5.state_dict()) - self.model.to(self.device) - - self.init_optimizer_from_scratch() + self.model.to(self.device) # TODO: model = self.model.to(self.device) ? 
- def load_from_checkpoint(self, model_dir_path: str, optimizer_path: str) -> None: + def load_model_from_checkpoint(self, model_dir_path: str): logger.info(f"Loading model from {model_dir_path}.") self.model = FiDT5.from_pretrained(model_dir_path) self.model = self.model.to(self.device) + def load_optimizer_from_checkpoint(self, optimizer_path: str): logger.info(f"Loading optimizer from {optimizer_path}.") self.init_optimizer_from_scratch() optimizer_state = torch.load(optimizer_path, map_location=self.device) self.optimizer.load_state_dict(optimizer_state["optimizer"]) @overrides - def load(self, fname: Optional[str] = None) -> None: + def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: if fname is not None: self.load_path = fname - # Loading weights from checkpoint if self.load_path is not None: logger.info(f"Load path {self.load_path} is given.") - model_dir_path = self.load_path - optimizer_path = os.path.join(self.load_path, "optimizer.pth.tar") + model_config_path = Path(self.load_path) / "config.json" + model_weights_path = Path(self.load_path) / "pytorch_model.bin" + optimizer_path = Path(self.load_path) / "optimizer.pth.tar" - if Path(model_dir_path).exists() and Path(optimizer_path).exists(): - self.load_from_checkpoint(model_dir_path, optimizer_path) + if model_config_path.exists() and model_weights_path.exists(): + self.load_model_from_checkpoint(self.load_path) + else: + self.init_model_from_scratch() + logger.info(f"Init model from scratch: {model_config_path} or {model_weights_path} does not exist.") + + if optimizer_path.exists(): + self.load_optimizer_from_checkpoint(str(optimizer_path.resolve())) else: - self.init_from_scratch() - logger.info(f"Init from scratch. Model_path: {model_dir_path} or optimizer_path: {optimizer_path} does not exist.") + self.init_optimizer_from_scratch() + logger.info(f"Init optimizer from scratch: {optimizer_path} does not exist.") else: - self.init_from_scratch() - logger.info(f"Init from scratch. 
Load path {self.load_path} does not exist.") + self.init_model_from_scratch() + self.init_optimizer_from_scratch() + logger.info(f"Init model and optimizer from scratch: \"load_path\" and \"fname\" are not given.") if self.device.type == "cuda" and torch.cuda.device_count() > 1: self.model = torch.nn.DataParallel(self.model) \ No newline at end of file From 3e0192d02c3afdd8b138481d4d2a82133d8c67af Mon Sep 17 00:00:00 2001 From: Maksim Savkin Date: Thu, 25 Aug 2022 12:47:54 +0300 Subject: [PATCH 3/9] fix inference preprocessing --- deeppavlov/configs/kbqa/nq_fid.json | 44 ++++++------- deeppavlov/core/common/registry.json | 3 +- .../torch_transformers_preprocessor.py | 66 +++++++++---------- 3 files changed, 56 insertions(+), 57 deletions(-) diff --git a/deeppavlov/configs/kbqa/nq_fid.json b/deeppavlov/configs/kbqa/nq_fid.json index 7e0c6726cf..d6576d732b 100644 --- a/deeppavlov/configs/kbqa/nq_fid.json +++ b/deeppavlov/configs/kbqa/nq_fid.json @@ -1,23 +1,31 @@ { "dataset_reader": { "class_name": "json_reader", - "valid_size": 1000, + "valid_size": 1, "data_path": "{DATASET_PATH}/natural_questions_dataset.json" }, "dataset_iterator": { - "class_name": "data_learning_iterator" + "class_name": "data_learning_iterator", + "seed": 42, + "shuffle": true }, "chainer": { "in": ["question", "contexts"], "in_y": ["target", "gold_answers"], "pipe": [ { - "class_name": "torch_transformers_fid_preprocessor", + "class_name": "fid_input_preprocessor", "vocab_file": "{TRANSFORMER}", - "max_seq_length": 512, - "answer_maxlength" : 20, - "in": ["question", "contexts", "target"], - "out": ["input_ids", "attention_mask", "target_ids"] + "max_seq_length": 200, + "in": ["question", "contexts"], + "out": ["input_ids", "attention_mask"] + }, + { + "class_name": "fid_target_preprocessor", + "vocab_file": "{TRANSFORMER}", + "answer_maxlength" : 50, + "in": ["target"], + "out": ["target_ids"] }, { "class_name": "torch_generative_qa_fid", @@ -34,7 +42,7 @@ "learning_rate_drop_patience": 8, "learning_rate_drop_div": 2, "min_learning_rate": 1e-5, - "generate_max_length" : 20, + "generate_max_length" : 50, "in": ["input_ids", "attention_mask"], "in_y": ["target_ids"], "out": ["model_answer"] @@ -49,8 +57,8 @@ ], "log_every_n_batches": 100, "val_every_n_batches": 600, - "batch_size": 8, - "validation_patience": 60, + "batch_size": 1, + "validation_patience": 50, "metrics": [ { "name": "squad_v2_em", @@ -59,10 +67,6 @@ { "name": "squad_v2_f1", "inputs": ["gold_answers", "model_answer"] - }, - { - "name": "sacrebleu", - "inputs": ["gold_answers", "model_answer"] } ], "class_name": "torch_trainer" @@ -78,21 +82,17 @@ }, "download": [ { - "url": "http://files.deeppavlov.ai/deeppavlov_data/kbqa/datasets/natural_questions/natural_questions_dataset.json", + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/datasets/natural_questions/natural_questions_dataset.json", "subdir": "{DATASET_PATH}" }, { - "url": "http://files.deeppavlov.ai/deeppavlov_data/kbqa/models/generative_qa/fusion_in_decoder/natural_questions/config.json", - "subdir": "{MODEL_PATH}" - }, - { - "url": "http://files.deeppavlov.ai/deeppavlov_data/kbqa/models/generative_qa/fusion_in_decoder/natural_questions/pytorch_model.bin", + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/models/fusion_in_decoder/natural_questions/config.json", "subdir": "{MODEL_PATH}" }, { - "url": "http://files.deeppavlov.ai/deeppavlov_data/kbqa/models/generative_qa/fusion_in_decoder/natural_questions/optimizer.pth.tar", + "url": 
"http://files.deeppavlov.ai/deeppavlov_data/generative_qa/models/fusion_in_decoder/natural_questions/pytorch_model.bin", "subdir": "{MODEL_PATH}" } ] } -} +} \ No newline at end of file diff --git a/deeppavlov/core/common/registry.json b/deeppavlov/core/common/registry.json index 6edd96ad2b..138d056c6f 100644 --- a/deeppavlov/core/common/registry.json +++ b/deeppavlov/core/common/registry.json @@ -16,6 +16,8 @@ "entity_linker": "deeppavlov.models.entity_extraction.entity_linking:EntityLinker", "faq_reader": "deeppavlov.dataset_readers.faq_reader:FaqDatasetReader", "fasttext": "deeppavlov.models.embedders.fasttext_embedder:FasttextEmbedder", + "fid_input_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:FiDInputPreprocessor", + "fid_target_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:FiDTargetPreprocessor", "fit_trainer": "deeppavlov.core.trainers.fit_trainer:FitTrainer", "hashing_tfidf_vectorizer": "deeppavlov.models.vectorizers.hashing_tfidf_vectorizer:HashingTfIdfVectorizer", "huggingface_dataset_iterator": "deeppavlov.dataset_iterators.huggingface_dataset_iterator:HuggingFaceDatasetIterator", @@ -91,7 +93,6 @@ "torch_transformers_el_ranker": "deeppavlov.models.torch_bert.torch_transformers_el_ranker:TorchTransformersElRanker", "torch_transformers_entity_ranker_infer": "deeppavlov.models.torch_bert.torch_transformers_el_ranker:TorchTransformersEntityRankerInfer", "torch_transformers_entity_ranker_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersEntityRankerPreprocessor", - "torch_transformers_fid_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersFiDPreprocessor", "torch_transformers_multiplechoice": "deeppavlov.models.torch_bert.torch_transformers_multiplechoice:TorchTransformersMultiplechoiceModel", "torch_transformers_multiplechoice_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersMultiplechoicePreprocessor", "torch_transformers_ner_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersNerPreprocessor", diff --git a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py index 05c5f911b7..51cdd81edf 100644 --- a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py +++ b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py @@ -421,27 +421,26 @@ def __call__(self, questions_batch: List[List[str]], rels_batch: List[List[str]] return input_features -@register('torch_transformers_fid_preprocessor') -class TorchTransformersFiDPreprocessor(Component): +@register('fid_input_preprocessor') +class FiDInputPreprocessor(Component): def __init__(self, vocab_file: str, do_lower_case: bool = True, max_seq_length: int = 512, - return_tokens: bool = False, - answer_maxlength: int = 20, **kwargs) -> None: self.max_seq_length = max_seq_length - self.return_tokens = return_tokens - self.answer_maxlength = answer_maxlength if Path(vocab_file).is_file(): vocab_file = str(expand_path(vocab_file)) self.tokenizer = AutoTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) else: self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case) - - def encode_passages(self, passages_batch: List[List[str]]): + def __call__(self, questions_batch: List[str], contexts_batch: List[List[str]]): + prepare_data = lambda q, c,: f"question: 
{q} context: {c}" + passages_batch = [[prepare_data(question, context) for context in contexts] + for (question, contexts) in zip(questions_batch, contexts_batch)] + passage_ids, passage_masks = [], [] for text_passages in passages_batch: passages_encoding = self.tokenizer( @@ -456,35 +455,34 @@ def encode_passages(self, passages_batch: List[List[str]]): passage_ids = torch.cat(passage_ids, dim=0) passage_masks = torch.cat(passage_masks, dim=0) - return passage_ids, passage_masks - - def encode_targets(self, targets_batch: List[str] = None): - target_ids, target_masks = None, None - if targets_batch is not None: - target_encoding = self.tokenizer( - targets_batch, - max_length=self.answer_maxlength if self.answer_maxlength > 0 else None, - pad_to_max_length=True, - return_tensors='pt', - truncation=True if self.answer_maxlength > 0 else False, - ) - target_ids = target_encoding["input_ids"] - target_mask = target_encoding["attention_mask"].bool() - return target_ids, target_masks - - def __call__(self, - questions_batch: List[str], - contexts_batch: List[List[str]], - targets_batch: List[str] = None): - prepare_data = lambda q, c,: f"question: {q} context: {c}" - passages_batch = [[prepare_data(question, context) for context in contexts] - for (question, contexts) in zip(questions_batch, contexts_batch)] - passage_ids, passage_masks = self.encode_passages(passages_batch) + return passage_ids, passage_masks - target_ids, target_masks = self.encode_targets(targets_batch) +@register('fid_target_preprocessor') +class FiDTargetPreprocessor(Component): + def __init__(self, + vocab_file: str, + do_lower_case: bool = True, + answer_maxlength: int = 50, + **kwargs) -> None: + self.answer_maxlength = answer_maxlength + if Path(vocab_file).is_file(): + vocab_file = str(expand_path(vocab_file)) + self.tokenizer = AutoTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) + else: + self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case) + - return passage_ids, passage_masks, target_ids + def __call__(self, targets_batch: List[str]): + target_encoding = self.tokenizer( + targets_batch, + max_length=self.answer_maxlength if self.answer_maxlength > 0 else None, + pad_to_max_length=True, + return_tensors='pt', + truncation=True if self.answer_maxlength > 0 else False, + ) + target_ids = target_encoding["input_ids"] + return target_ids @register('torch_transformers_ner_preprocessor') class TorchTransformersNerPreprocessor(Component): From d38b1bce3a6174efc988031088c08be0661b2fed Mon Sep 17 00:00:00 2001 From: Maksim Savkin Date: Mon, 29 Aug 2022 17:28:07 +0300 Subject: [PATCH 4/9] add trivia_qa config --- .../{kbqa => generative_qa}/nq_fid.json | 7 +- deeppavlov/configs/generative_qa/tqa_fid.json | 97 +++++++++++++++++++ .../{kbqa => torch_bert}/fusion_in_decoder.py | 0 .../models/torch_bert/torch_generative_qa.py | 4 +- 4 files changed, 102 insertions(+), 6 deletions(-) rename deeppavlov/configs/{kbqa => generative_qa}/nq_fid.json (94%) create mode 100644 deeppavlov/configs/generative_qa/tqa_fid.json rename deeppavlov/models/{kbqa => torch_bert}/fusion_in_decoder.py (100%) diff --git a/deeppavlov/configs/kbqa/nq_fid.json b/deeppavlov/configs/generative_qa/nq_fid.json similarity index 94% rename from deeppavlov/configs/kbqa/nq_fid.json rename to deeppavlov/configs/generative_qa/nq_fid.json index d6576d732b..945ec0de95 100644 --- a/deeppavlov/configs/kbqa/nq_fid.json +++ b/deeppavlov/configs/generative_qa/nq_fid.json @@ -1,8 +1,7 @@ { "dataset_reader": { "class_name": 
"json_reader", - "valid_size": 1, - "data_path": "{DATASET_PATH}/natural_questions_dataset.json" + "data_path": "/archive/savkin/parsed_datasets/natural_questions/fid/nq_dataset.json" }, "dataset_iterator": { "class_name": "data_learning_iterator", @@ -10,7 +9,7 @@ "shuffle": true }, "chainer": { - "in": ["question", "contexts"], + "in": ["question", "contexts", "titles"], "in_y": ["target", "gold_answers"], "pipe": [ { @@ -57,7 +56,7 @@ ], "log_every_n_batches": 100, "val_every_n_batches": 600, - "batch_size": 1, + "batch_size": 8, "validation_patience": 50, "metrics": [ { diff --git a/deeppavlov/configs/generative_qa/tqa_fid.json b/deeppavlov/configs/generative_qa/tqa_fid.json new file mode 100644 index 0000000000..7ca76141f3 --- /dev/null +++ b/deeppavlov/configs/generative_qa/tqa_fid.json @@ -0,0 +1,97 @@ +{ + "dataset_reader": { + "class_name": "json_reader", + "data_path": "/archive/savkin/parsed_datasets/trivia_qa/fid/trivia_qa_dataset.json" + }, + "dataset_iterator": { + "class_name": "data_learning_iterator", + "seed": 42, + "shuffle": true + }, + "chainer": { + "in": ["question", "contexts", "titles"], + "in_y": ["target", "gold_answers"], + "pipe": [ + { + "class_name": "fid_input_preprocessor", + "vocab_file": "{TRANSFORMER}", + "max_seq_length": 200, + "in": ["question", "contexts"], + "out": ["input_ids", "attention_mask"] + }, + { + "class_name": "fid_target_preprocessor", + "vocab_file": "{TRANSFORMER}", + "answer_maxlength" : 50, + "in": ["target"], + "out": ["target_ids"] + }, + { + "class_name": "torch_generative_qa_fid", + "pretrained_transformer": "{TRANSFORMER}", + "save_path": "{MODEL_PATH}", + "load_path": "{MODEL_PATH}", + "optimizer": "AdamW", + "optimizer_parameters": { + "lr": 3e-04, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-08 + }, + "learning_rate_drop_patience": 8, + "learning_rate_drop_div": 2, + "min_learning_rate": 1e-5, + "generate_max_length" : 50, + "in": ["input_ids", "attention_mask"], + "in_y": ["target_ids"], + "out": ["model_answer"] + } + ], + "out": ["model_answer"] + }, + "train": { + "show_examples": false, + "evaluation_targets": [ + "valid" + ], + "log_every_n_batches": 100, + "val_every_n_batches": 600, + "batch_size": 8, + "validation_patience": 50, + "metrics": [ + { + "name": "squad_v2_em", + "inputs": ["gold_answers", "model_answer"] + }, + { + "name": "squad_v2_f1", + "inputs": ["gold_answers", "model_answer"] + } + ], + "class_name": "torch_trainer" + }, + "metadata": { + "variables": { + "TRANSFORMER": "t5-base", + "ROOT_PATH": "~/.deeppavlov", + "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", + "MODEL_PATH": "{MODELS_PATH}/generative_qa/fusion_in_decoder/trivia_qa", + "DATASET_PATH": "{DOWNLOADS_PATH}/trivia_qa" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/datasets/trivia_qa/trivia_qa_dataset.json", + "subdir": "{DATASET_PATH}" + }, + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/models/fusion_in_decoder/trivia_qa/config.json", + "subdir": "{MODEL_PATH}" + }, + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/models/fusion_in_decoder/trivia_qa/pytorch_model.bin", + "subdir": "{MODEL_PATH}" + } + ] + } + } \ No newline at end of file diff --git a/deeppavlov/models/kbqa/fusion_in_decoder.py b/deeppavlov/models/torch_bert/fusion_in_decoder.py similarity index 100% rename from deeppavlov/models/kbqa/fusion_in_decoder.py rename to deeppavlov/models/torch_bert/fusion_in_decoder.py diff --git 
a/deeppavlov/models/torch_bert/torch_generative_qa.py b/deeppavlov/models/torch_bert/torch_generative_qa.py
index fc0bd71f82..8fa4d593cf 100644
--- a/deeppavlov/models/torch_bert/torch_generative_qa.py
+++ b/deeppavlov/models/torch_bert/torch_generative_qa.py
@@ -25,7 +25,7 @@
 from deeppavlov.core.models.torch_model import TorchModel
 
 
-from deeppavlov.models.kbqa.fusion_in_decoder import FiDT5
+from deeppavlov.models.torch_bert.fusion_in_decoder import FiDT5
 
 
 logger = getLogger(__name__)
@@ -158,7 +158,7 @@ def init_model_from_scratch(self) -> None:
 
         self.model = FiDT5(t5.config)
         self.model.load_t5(t5.state_dict())
-        self.model.to(self.device) # TODO: model = self.model.to(self.device) ?
+        self.model.to(self.device)
 
 
     def load_model_from_checkpoint(self, model_dir_path: str):

From 70c7c0f1ba05f187b7bd5d70be9130e7f93574f6 Mon Sep 17 00:00:00 2001
From: Maksim Savkin
Date: Tue, 30 Aug 2022 11:57:17 +0300
Subject: [PATCH 5/9] add documentation

---
 docs/features/models/generative_qa.rst | 219 +++++++++++++++++++++++++
 1 file changed, 219 insertions(+)
 create mode 100644 docs/features/models/generative_qa.rst

diff --git a/docs/features/models/generative_qa.rst b/docs/features/models/generative_qa.rst
new file mode 100644
index 0000000000..ab8cb0f294
--- /dev/null
+++ b/docs/features/models/generative_qa.rst
@@ -0,0 +1,219 @@
+Generative Question Answering
+=============================
+
+Task definition
+---------------
+Generative Question Answering is the task of finding an answer to a question in given contexts (e.g., paragraphs from Wikipedia),
+where the answer to each question is **not necessarily** a segment of the context.
+
+
+**Question**:
+
+    Is it possible to have a rating above 4000 in chess?
+
+**Contexts**:
+
+    > Right now that can't really happen. Now, the highest-rated chess player is Stockfish 12, with a rating of 3515. A rating difference of 400 points means you'll beat your opponent over 90% of the time. Here we're looking at an even bigger difference than that: about 500 points.
+
+    > It's nearly impossible to measure the rating difference between two players so far apart in skill. For there to be a player with a rating of 4000, there would first have to be other players with ratings that are at least fairly close, like 3800.
+
+**Answer**:
+
+    not really possible
+
+Datasets
+--------
+We consider the following datasets:
+
+- `Natural Questions `__
+- `TriviaQA `__
+
+Specifically, we validate our model on *Natural Questions* and *TriviaQA* from: https://github.com/facebookresearch/FiD.
+
+
+Datasets format
+~~~~~~~~~~~~~~~
+.. code-block:: json
+
+    {
+        "train": [
+            [
+                [
+                    question,
+                    [contexts],
+                    [titles]
+                ],
+                [
+                    target,
+                    [answers]
+                ]
+            ],
+            [
+                ...
+            ],
+            ...
+        ],
+        "valid": [...],
+        "test": [...]
+    }
+
+Built-In Models
+---------------
+DeepPavlov's model for generative question answering is based on the Fusion-in-Decoder (FiD) base model.
+The model generates an answer based on the question and k support contexts.
+
+Currently, we provide two built-in models for generative question answering in the DeepPavlov library, fine-tuned on two datasets:
+
+- Natural Questions :config:`deeppavlov/configs/generative_qa/nq_fid.json`
+
+- TriviaQA :config:`deeppavlov/configs/generative_qa/tqa_fid.json`
+
+Architecture
+~~~~~~~~~~~~
+The FiD model uses several support passages to gather useful information from multiple knowledge sources. First, every
+passage is concatenated with the question like this: *"question: What is the capital of UK? passage: London is the capital of UK"*
+
+
+Metrics
+~~~~~~~
+Natural Questions dataset
+^^^^^^^^^^^^^^^^^^^^^^^^^
++---------------------------------------------------------+---------------------------------+---------------------------------+
+| Dataset | Natural Questions (dev) | Natural Questions (test) |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+| Model | EM | F-1 | EM | F-1 |
++=========================================================+================+================+================+================+
+| :config:`DeepPavlov FiD ` | 39.9 | 50.0 | 46.0 | 54.1 |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+| `T5`_ | 42.0 | 50.6 | 42.2 | 49.7 |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+
+
+TriviaQA dataset
+^^^^^^^^^^^^^^^^
++---------------------------------------------------------+---------------------------------+---------------------------------+
+| Dataset | TriviaQA (dev) | TriviaQA (test) |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+| Model | EM | F-1 | EM | F-1 |
++=========================================================+================+================+================+================+
+| :config:`DeepPavlov FiD ` | 61.8 | 69.6 | 63.1 | 70.0 |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+| :config:`DeepPavlov FiD ` | 51.1 | 61.3 | 52.2 | 61.9 |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+| `T5 (обученная на NQ, её лучше убрать отсюда)`_ | 46.0 | 55.0 | 46.1 | 55.3 |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+| `QANet`_ | 51.1 | 56.6 | -- | -- |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+| `M-Reader`_ | -- | -- | 46.9 | 52.9 |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+| `MEMEN`_ | 43.2 | 46.9 | -- | -- |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+| `BiDAF`_ | 40.3 | 45.7 | -- | -- |
++---------------------------------------------------------+----------------+----------------+----------------+----------------+
+
+
+.. _`M-Reader`: https://arxiv.org/abs/1705.02798
+.. _`MEMEN`: https://arxiv.org/abs/1707.09098
+.. _`QANet`: https://arxiv.org/abs/1804.09541
+.. _`BiDAF`: https://arxiv.org/abs/1611.01603
+.. _`T5`: https://arxiv.org/abs/1910.10683
+
+
+
+Prerequisites
+-------------
+
+Before using the models, make sure that all required packages are installed by running the commands:
+
+    .. code:: bash
+
+        python -m deeppavlov install nq_fid
+        python -m deeppavlov install tqa_fid
+
+
+Pretrained models are available and can be downloaded (~0.9 GB):
+
+    .. code:: bash
+
+        python -m deeppavlov download nq_fid
+        python -m deeppavlov download tqa_fid
+
+
+Model usage from Python
+-----------------------
+
+Interact
+~~~~~~~~
+    .. code:: python
+
+        from deeppavlov import build_model
+
+        model = build_model('nq_fid', download=True)
+
+        # first argument: a batch of questions; second argument: one list of
+        # support contexts per question
+        model([
+                "What is the capital of UK?",
+                "Where did the name Atari itself come from?"
+            ],
+            [
+                [
+                    "The name Britain is sometimes used to refer to the United Kingdom as a whole",
+                    "London is the capital of Great Britain"
+                ],
+                [
+                    "Bushnell and Dabney were originally going to name their company Syzygy, a term for planetary alignment, but found that it had been registered already.",
+                    "Instead, they chose a word from the Japanese game Go. The Japanese equivalent of chess, in Go Atari means something similar to \'check\'."
+                ]
+            ])
+        >>> ['london', 'the japanese game go']
+
+        model([
+                "How many points do you need to win in badminton?"
+            ],
+            [
+                [
+                    "A rally is lost if the shuttle is hit into the net, or over the net but outside of the opponent's court.",
+                    "A rally is also lost if the shuttle touches the player's clothing or body, or if it is hit before it crosses over the net",
+                    'The side winning a rally adds a point to its score',
+                    'A match consists of the best of 3 games of 21 points (games cap at 30 points)',
+                    "A rally is won when a shuttle is hit over the net and onto the floor of the opponent's court.",
+                    'At 29 all, the side scoring the 30th point, wins that game',
+                    'The side winning a game serves first in the next game',
+                    'At 20 all, the side which gains a 2 point lead first, wins that game.',
+                    'Each game starts at 0-0. If the match goes to the third game that third game will be played to 15'
+                ]
+            ])
+        >>> ['21']
+
+Train
+~~~~~
+    .. code:: python
+
+        from deeppavlov import train_model
+
+        model = train_model('nq_fid', download=True)
+
+
+Model usage from CLI
+--------------------
+
+Train
+~~~~~
+    .. code:: bash
+
+        python -m deeppavlov train nq_fid
+
+Evaluate
+~~~~~~~~
+    .. code:: bash
+
+        python -m deeppavlov evaluate nq_fid
+
+Interact
+~~~~~~~~
+
+Interact mode provides a command-line interface to an already trained model.
+
+    .. 
code:: bash + + python -m deeppavlov interact nq_fid From 1918c1370f96e22aa184a85fe0d7fef1de909478 Mon Sep 17 00:00:00 2001 From: Maksim Savkin Date: Tue, 30 Aug 2022 12:45:01 +0300 Subject: [PATCH 6/9] fixed requirements --- deeppavlov/core/common/requirements_registry.json | 10 +++++++--- deeppavlov/requirements/transformers_3.0.2.txt | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) create mode 100644 deeppavlov/requirements/transformers_3.0.2.txt diff --git a/deeppavlov/core/common/requirements_registry.json b/deeppavlov/core/common/requirements_registry.json index f29d8e7319..2ea6b7526d 100644 --- a/deeppavlov/core/common/requirements_registry.json +++ b/deeppavlov/core/common/requirements_registry.json @@ -88,7 +88,7 @@ ], "torch_generative_qa_fid": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", - "{DEEPPAVLOV_PATH}/requirements/transformers.txt", + "{DEEPPAVLOV_PATH}/requirements/transformers_3.0.2.txt", "{DEEPPAVLOV_PATH}/requirements/sacrebleu.txt" ], "torch_record_postprocessor": [ @@ -118,9 +118,13 @@ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" ], - "torch_transformers_fid_preprocessor": [ + "fid_input_preprocessor": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", - "{DEEPPAVLOV_PATH}/requirements/transformers.txt" + "{DEEPPAVLOV_PATH}/requirements/transformers_3.0.2.txt" + ], + "fid_target_preprocessor": [ + "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", + "{DEEPPAVLOV_PATH}/requirements/transformers_3.0.2.txt" ], "torch_transformers_multiplechoice": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", diff --git a/deeppavlov/requirements/transformers_3.0.2.txt b/deeppavlov/requirements/transformers_3.0.2.txt new file mode 100644 index 0000000000..70ef216ce9 --- /dev/null +++ b/deeppavlov/requirements/transformers_3.0.2.txt @@ -0,0 +1 @@ +transformers==3.0.2 \ No newline at end of file From 6d41facaa081acf2d111ce8425012fca0ae99924 Mon Sep 17 00:00:00 2001 From: Maksim Savkin Date: Tue, 30 Aug 2022 12:49:42 +0300 Subject: [PATCH 7/9] remove sacrebleu --- deeppavlov/core/common/metrics_registry.json | 1 - deeppavlov/core/common/requirements_registry.json | 3 +-- deeppavlov/metrics/bleu.py | 11 ----------- deeppavlov/requirements/sacrebleu.txt | 1 - 4 files changed, 1 insertion(+), 15 deletions(-) delete mode 100644 deeppavlov/requirements/sacrebleu.txt diff --git a/deeppavlov/core/common/metrics_registry.json b/deeppavlov/core/common/metrics_registry.json index 79d0364cc8..c1f1a6c7a0 100644 --- a/deeppavlov/core/common/metrics_registry.json +++ b/deeppavlov/core/common/metrics_registry.json @@ -31,7 +31,6 @@ "r@5": "deeppavlov.metrics.recall_at_k:r_at_5", "rank_response": "deeppavlov.models.ranking.metrics:rank_response", "roc_auc": "deeppavlov.metrics.roc_auc_score:roc_auc_score", - "sacrebleu": "deeppavlov.metrics.bleu:sacrebleu", "sets_accuracy": "deeppavlov.metrics.accuracy:sets_accuracy", "slots_accuracy": "deeppavlov.metrics.accuracy:slots_accuracy", "spearman_correlation": "deeppavlov.metrics.correlation:spearman_correlation", diff --git a/deeppavlov/core/common/requirements_registry.json b/deeppavlov/core/common/requirements_registry.json index 2ea6b7526d..ac03a5e154 100644 --- a/deeppavlov/core/common/requirements_registry.json +++ b/deeppavlov/core/common/requirements_registry.json @@ -88,8 +88,7 @@ ], "torch_generative_qa_fid": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", - "{DEEPPAVLOV_PATH}/requirements/transformers_3.0.2.txt", - "{DEEPPAVLOV_PATH}/requirements/sacrebleu.txt" + 
"{DEEPPAVLOV_PATH}/requirements/transformers_3.0.2.txt" ], "torch_record_postprocessor": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", diff --git a/deeppavlov/metrics/bleu.py b/deeppavlov/metrics/bleu.py index ddbfbf6846..9c2cfd3615 100644 --- a/deeppavlov/metrics/bleu.py +++ b/deeppavlov/metrics/bleu.py @@ -81,14 +81,3 @@ def per_item_dialog_bleu(y_true, y_predicted): y_true = (y['text'] for dialog in y_true for y in dialog) return corpus_bleu([[y_t.lower().split()] for y_t in y_true], [y.lower().split() for y_p in y_predicted for y in y_p]) - -@register_metric('sacrebleu') -def sacrebleu(y_true: List[str], y_predicted: List[str]) -> float: - y_true_padded = [] - max_answers_cnt = max(len(answers) for answers in y_true) - for answers in y_true: - y_true_padded.append(answers + [''] * (max_answers_cnt - len(answers))) - y_true = np.transpose(y_true_padded).tolist() - - bleu = BLEU() - return bleu.corpus_score(y_predicted, y_true).score \ No newline at end of file diff --git a/deeppavlov/requirements/sacrebleu.txt b/deeppavlov/requirements/sacrebleu.txt deleted file mode 100644 index b71be058a6..0000000000 --- a/deeppavlov/requirements/sacrebleu.txt +++ /dev/null @@ -1 +0,0 @@ -sacrebleu==2.1.0 \ No newline at end of file From 2af1aa4c93ce84493957388aae7e97a742ae9c5a Mon Sep 17 00:00:00 2001 From: Maksim Savkin Date: Tue, 30 Aug 2022 13:02:44 +0300 Subject: [PATCH 8/9] fix configs --- deeppavlov/configs/generative_qa/nq_fid.json | 8 ++++---- deeppavlov/configs/generative_qa/tqa_fid.json | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/deeppavlov/configs/generative_qa/nq_fid.json b/deeppavlov/configs/generative_qa/nq_fid.json index 945ec0de95..1cbbaba7f1 100644 --- a/deeppavlov/configs/generative_qa/nq_fid.json +++ b/deeppavlov/configs/generative_qa/nq_fid.json @@ -1,7 +1,7 @@ { "dataset_reader": { "class_name": "json_reader", - "data_path": "/archive/savkin/parsed_datasets/natural_questions/fid/nq_dataset.json" + "data_path": "{DATASET_PATH}/natural_questions_dataset.json" }, "dataset_iterator": { "class_name": "data_learning_iterator", @@ -38,7 +38,7 @@ "betas": [0.9, 0.999], "eps": 1e-08 }, - "learning_rate_drop_patience": 8, + "learning_rate_drop_patience": 24, "learning_rate_drop_div": 2, "min_learning_rate": 1e-5, "generate_max_length" : 50, @@ -56,8 +56,8 @@ ], "log_every_n_batches": 100, "val_every_n_batches": 600, - "batch_size": 8, - "validation_patience": 50, + "batch_size": 1, + "validation_patience": 100, "metrics": [ { "name": "squad_v2_em", diff --git a/deeppavlov/configs/generative_qa/tqa_fid.json b/deeppavlov/configs/generative_qa/tqa_fid.json index 7ca76141f3..c936307cc8 100644 --- a/deeppavlov/configs/generative_qa/tqa_fid.json +++ b/deeppavlov/configs/generative_qa/tqa_fid.json @@ -1,7 +1,7 @@ { "dataset_reader": { "class_name": "json_reader", - "data_path": "/archive/savkin/parsed_datasets/trivia_qa/fid/trivia_qa_dataset.json" + "data_path": "{DATASET_PATH}/trivia_qa_dataset.json" }, "dataset_iterator": { "class_name": "data_learning_iterator", @@ -38,7 +38,7 @@ "betas": [0.9, 0.999], "eps": 1e-08 }, - "learning_rate_drop_patience": 8, + "learning_rate_drop_patience": 24, "learning_rate_drop_div": 2, "min_learning_rate": 1e-5, "generate_max_length" : 50, @@ -56,8 +56,8 @@ ], "log_every_n_batches": 100, "val_every_n_batches": 600, - "batch_size": 8, - "validation_patience": 50, + "batch_size": 1, + "validation_patience": 100, "metrics": [ { "name": "squad_v2_em", From cf5144d55f70fdc4e904d65b19e25db078a90e54 Mon Sep 17 00:00:00 2001 From: 
Maksim Savkin Date: Tue, 30 Aug 2022 17:16:54 +0300 Subject: [PATCH 9/9] fix docs --- deeppavlov/metrics/bleu.py | 2 +- docs/features/models/generative_qa.rst | 37 +++++++++++--------------- docs/index.rst | 1 + 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/deeppavlov/metrics/bleu.py b/deeppavlov/metrics/bleu.py index 9c2cfd3615..4fcc8fec67 100644 --- a/deeppavlov/metrics/bleu.py +++ b/deeppavlov/metrics/bleu.py @@ -19,7 +19,7 @@ from deeppavlov.core.common.metrics_registry import register_metric from deeppavlov.metrics.google_bleu import compute_bleu -from sacrebleu.metrics import BLEU + import numpy as np SMOOTH = SmoothingFunction() diff --git a/docs/features/models/generative_qa.rst b/docs/features/models/generative_qa.rst index ab8cb0f294..ebf62f6546 100644 --- a/docs/features/models/generative_qa.rst +++ b/docs/features/models/generative_qa.rst @@ -33,29 +33,24 @@ Specifically, we validate our model on *Natural Questions* and *TriviaQA* from: Datasets format ~~~~~~~~~~~~~~~ -.. code-block:: json - - { - "train":[ - [ - [ - question, - [contexts], - [titles] - ], - [ - target, - [answers] - ] - ], - [ - ... +{ + "train": [ + [ + [ "question", [ "contexts" ], [ "titles" ] ], + + [ "target", [ "answers" ] ] + ], + ... + ] - "valid": [...] - "test": [...] - } + + "valid": [ ... ] + + "test": [ ... ] + +} Built-In Models --------------- @@ -102,7 +97,7 @@ TriviaQA dataset +---------------------------------------------------------+----------------+----------------+----------------+----------------+ | :config:`DeepPavlov FiD ` | 51.1 | 61.3 | 52.2 | 61.9 | +---------------------------------------------------------+----------------+----------------+----------------+----------------+ -| `T5 (обученная на NQ, её лучше убрать отсюда)`_ | 46.0 | 55.0 | 46.1 | 55.3 | +| `T5`_ | 46.0 | 55.0 | 46.1 | 55.3 | +---------------------------------------------------------+----------------+----------------+----------------+----------------+ | `QANet`_ | 51.1 | 56.6 | -- | -- | +---------------------------------------------------------+----------------+----------------+----------------+----------------+ diff --git a/docs/index.rst b/docs/index.rst index 0fe7640253..0c556edcc3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,7 @@ Welcome to DeepPavlov's documentation! Knowledge Base Question answering Relation Extraction SuperGLUE Submission + Generative Question Answering .. toctree::
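
The dataset format documented above can be consumed with plain Python. A minimal sketch, assuming a local copy of the dataset JSON (the file name below mirrors the one referenced in the nq_fid config and is only illustrative):

.. code:: python

    import json

    # Each train/valid/test item is a pair:
    #   [[question, [contexts], [titles]], [target, [answers]]]
    with open("natural_questions_dataset.json") as f:
        data = json.load(f)

    for (question, contexts, titles), (target, answers) in data["train"][:3]:
        print(question)
        print(f"  {len(contexts)} support passages, target: {target!r}, gold answers: {answers}")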