diff --git a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
index e1d916ff7d..e5d32cb426 100644
--- a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
+++ b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
@@ -154,7 +154,7 @@ from flair.nn.multitask import make_multitask_model_and_corpus
 
 # 1. get the corpus
 ner_corpus = NER_MULTI_WIKINER()
-nel_corpus = ZELDA(column_format={0: "text", 2: "ner"}) # need to set the label type to be the same as the ner one
+nel_corpus = ZELDA(column_format={0: "text", 2: "nel"}) # need to set the label type to be the same as the ner one
 
 # --- Embeddings that are shared by both models --- #
 shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)
@@ -171,12 +171,13 @@ ner_model = SequenceTagger(
 )
 
 
-nel_label_dict = nel_corpus.make_label_dictionary("ner", add_unk=True)
+nel_label_dict = nel_corpus.make_label_dictionary("nel", add_unk=True)
 
 nel_model = SpanClassifier(
     embeddings=shared_embeddings,
     label_dictionary=nel_label_dict,
-    label_type="ner",
+    label_type="nel",
+    span_label_type="ner",
     decoder=PrototypicalDecoder(
         num_prototypes=len(nel_label_dict),
         embeddings_size=shared_embeddings.embedding_length * 2, # we use "first_last" encoding for spans
diff --git a/flair/data.py b/flair/data.py
index 77fff1f200..bc35c83c5c 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -1363,7 +1363,7 @@ def __init__(
             test, train = randomly_split_into_two_datasets(train, test_size, random_seed)
             log.warning(
                 "No test split found. Using %.0f%% (i.e. %d samples) of the train split as test data",
-                test_portion,
+                test_portion * 100,
                 test_size,
             )
 
@@ -1375,7 +1375,7 @@ def __init__(
             dev, train = randomly_split_into_two_datasets(train, dev_size, random_seed)
             log.warning(
                 "No dev split found. Using %.0f%% (i.e. %d samples) of the train split as dev data",
-                dev_portion,
+                dev_portion * 100,
                 dev_size,
             )
 
diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py
index 1d716e7904..9f516a703c 100644
--- a/flair/models/entity_linker_model.py
+++ b/flair/models/entity_linker_model.py
@@ -94,6 +94,7 @@ def __init__(
         label_dictionary: Dictionary,
         pooling_operation: str = "first_last",
         label_type: str = "nel",
+        span_label_type: Optional[str] = None,
         candidates: Optional[CandidateGenerator] = None,
         **classifierargs,
     ) -> None:
@@ -107,6 +108,7 @@ def __init__(
                 text representation we take the average of the embeddings of the token in the mention. `first_last`
                 concatenates the embedding of the first and the embedding of the last token.
             label_type: name of the label you use.
+            span_label_type: name of the label you use for inputs of predictions.
             candidates: If provided, use a :class:`CandidateGenerator` for prediction candidates.
             **classifierargs: The arguments propagated to :meth:`flair.nn.DefaultClassifier.__init__`
         """
@@ -121,6 +123,7 @@ def __init__(
 
         self.pooling_operation = pooling_operation
         self._label_type = label_type
+        self._span_label_type = span_label_type
 
         cases: Dict[str, Callable[[Span, List[str]], torch.Tensor]] = {
             "average": self.emb_mean,
@@ -153,9 +156,16 @@ def emb_mean(self, span, embedding_names):
         return torch.mean(torch.stack([token.get_embedding(embedding_names) for token in span], 0), 0)
 
     def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]:
+        if self._span_label_type is not None:
+            spans = sentence.get_spans(self._span_label_type)
+            # only use span label type if there are predictions, otherwise search for output label type (training labels)
+            if spans:
+                return spans
         return sentence.get_spans(self.label_type)
 
     def _filter_data_point(self, data_point: Sentence) -> bool:
+        if self._span_label_type is not None and bool(data_point.get_labels(self._span_label_type)):
+            return True
         return bool(data_point.get_labels(self.label_type))
 
     def _get_embedding_for_data_point(self, prediction_data_point: Span) -> torch.Tensor:
@@ -170,6 +180,7 @@ def _get_state_dict(self):
             "pooling_operation": self.pooling_operation,
             "loss_weights": self.weight_dict,
             "candidates": self.candidates,
+            "span_label_type": self._span_label_type,
         }
         return model_state
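As a rough illustration of what the new `span_label_type` parameter enables (a sketch, not part of the patch): a separately trained NER tagger writes `"ner"` spans onto a sentence, and the `SpanClassifier` then reads exactly those spans as prediction inputs while writing its own `"nel"` labels. The pre-trained model name `flair/ner-english` and the toy label dictionary below are illustrative assumptions.

```python
from flair.data import Dictionary, Sentence
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger, SpanClassifier

# toy label dictionary standing in for one built from a real NEL corpus
nel_label_dict = Dictionary(add_unk=True)
nel_label_dict.add_item("Star_Trek_character")

linker = SpanClassifier(
    embeddings=TransformerWordEmbeddings("distilbert-base-uncased"),
    label_dictionary=nel_label_dict,
    label_type="nel",       # label type this model writes
    span_label_type="ner",  # label type whose spans are used as prediction inputs
)

tagger = SequenceTagger.load("flair/ner-english")  # assumed pre-trained NER model

sentence = Sentence("Kirk and Spock met on the Enterprise .")
tagger.predict(sentence)  # adds "ner" spans to the sentence
linker.predict(sentence)  # classifies those spans, writing "nel" labels
```

At training time, where the NEL data carries no `"ner"` spans, the classifier falls back to spans of its own `label_type`, which is what the added check in `_get_data_points_from_sentence` implements.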