diff --git a/flair/class_utils.py b/flair/class_utils.py
index 9aa95cd1ee..ab0b11bb69 100644
--- a/flair/class_utils.py
+++ b/flair/class_utils.py
@@ -1,11 +1,15 @@
 import importlib
 import inspect
 from types import ModuleType
-from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload
+from typing import Any, Iterable, List, Optional, Protocol, Type, TypeVar, Union, overload
 
 T = TypeVar("T")
 
 
+class StringLike(Protocol):
+    def __str__(self) -> str: ...
+
+
 def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]:
     for subclass in cls.__subclasses__():
         yield from get_non_abstract_subclasses(subclass)
diff --git a/flair/training_utils.py b/flair/training_utils.py
index 2c3ce9d5f4..aacac495c7 100644
--- a/flair/training_utils.py
+++ b/flair/training_utils.py
@@ -1,21 +1,26 @@
 import logging
+import pathlib
 import random
 from collections import defaultdict
 from enum import Enum
 from functools import reduce
 from math import inf
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Union
 
+from numpy import ndarray
 from scipy.stats import pearsonr, spearmanr
+from scipy.stats._stats_py import PearsonRResult, SignificanceResult
 from sklearn.metrics import mean_absolute_error, mean_squared_error
 from torch.optim import Optimizer
 from torch.utils.data import Dataset
 
 import flair
-from flair.data import DT, Dictionary, Sentence, _iter_dataset
+from flair.class_utils import StringLike
+from flair.data import DT, Dictionary, Sentence, Token, _iter_dataset
 
-log = logging.getLogger("flair")
+MinMax = Literal["min", "max"]
+logger = logging.getLogger("flair")
 
 
 class Result:
@@ -23,15 +28,15 @@ def __init__(
         self,
         main_score: float,
         detailed_results: str,
-        classification_report: dict = {},
-        scores: dict = {},
+        classification_report: Optional[Dict] = None,
+        scores: Optional[Dict] = None,
     ) -> None:
-        assert "loss" in scores, "No loss provided."
+        assert scores is not None and "loss" in scores, "No loss provided."
 
         self.main_score: float = main_score
         self.scores = scores
         self.detailed_results: str = detailed_results
-        self.classification_report = classification_report
+        self.classification_report = classification_report if classification_report is not None else {}
 
     @property
     def loss(self):
@@ -42,40 +47,36 @@ def __str__(self) -> str:
 
 
 class MetricRegression:
-    def __init__(self, name) -> None:
+    def __init__(self, name: str) -> None:
         self.name = name
 
         self.true: List[float] = []
         self.pred: List[float] = []
 
-    def mean_squared_error(self):
+    def mean_squared_error(self) -> Union[float, ndarray]:
         return mean_squared_error(self.true, self.pred)
 
     def mean_absolute_error(self):
         return mean_absolute_error(self.true, self.pred)
 
-    def pearsonr(self):
+    def pearsonr(self) -> PearsonRResult:
         return pearsonr(self.true, self.pred)[0]
 
-    def spearmanr(self):
+    def spearmanr(self) -> SignificanceResult:
         return spearmanr(self.true, self.pred)[0]
 
-    # dummy return to fulfill trainer.train() needs
-    def micro_avg_f_score(self):
-        return self.mean_squared_error()
-
-    def to_tsv(self):
+    def to_tsv(self) -> str:
         return f"{self.mean_squared_error()}\t{self.mean_absolute_error()}\t{self.pearsonr()}\t{self.spearmanr()}"
 
     @staticmethod
-    def tsv_header(prefix=None):
+    def tsv_header(prefix: StringLike = None) -> str:
         if prefix:
             return f"{prefix}_MEAN_SQUARED_ERROR\t{prefix}_MEAN_ABSOLUTE_ERROR\t{prefix}_PEARSON\t{prefix}_SPEARMAN"
 
         return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN"
 
     @staticmethod
-    def to_empty_tsv():
+    def to_empty_tsv() -> str:
         return "\t_\t_\t_\t_"
 
     def __str__(self) -> str:
@@ -99,13 +100,13 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
         self.weights_dict: Dict[str, Dict[int, List[float]]] = defaultdict(lambda: defaultdict(list))
         self.number_of_weights = number_of_weights
 
-    def extract_weights(self, state_dict, iteration):
+    def extract_weights(self, state_dict: Dict, iteration: int) -> None:
         for key in state_dict:
             vec = state_dict[key]
-            # print(vec)
             try:
                 weights_to_watch = min(self.number_of_weights, reduce(lambda x, y: x * y, list(vec.size())))
-            except Exception:
+            except Exception as e:
+                logger.debug(e)
                 continue
 
             if key not in self.weights_dict:
@@ -193,15 +194,15 @@ class AnnealOnPlateau:
     def __init__(
         self,
         optimizer,
-        mode="min",
-        aux_mode="min",
-        factor=0.1,
-        patience=10,
-        initial_extra_patience=0,
-        verbose=False,
-        cooldown=0,
-        min_lr=0,
-        eps=1e-8,
+        mode: MinMax = "min",
+        aux_mode: MinMax = "min",
+        factor: float = 0.1,
+        patience: int = 10,
+        initial_extra_patience: int = 0,
+        verbose: bool = False,
+        cooldown: int = 0,
+        min_lr: float = 0.0,
+        eps: float = 1e-8,
     ) -> None:
         if factor >= 1.0:
             raise ValueError("Factor should be < 1.0.")
@@ -212,6 +213,7 @@ def __init__(
             raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
 
         self.optimizer = optimizer
+        self.min_lrs: List[float]
         if isinstance(min_lr, (list, tuple)):
             if len(min_lr) != len(optimizer.param_groups):
                 raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
@@ -229,7 +231,7 @@ def __init__(
         self.best = None
         self.best_aux = None
         self.num_bad_epochs = None
-        self.mode_worse = None  # the worse value for the chosen mode
+        self.mode_worse: Optional[float] = None  # the worse value for the chosen mode
         self.eps = eps
         self.last_epoch = 0
         self._init_is_better(mode=mode)
@@ -256,7 +258,7 @@ def step(self, metric, auxiliary_metric=None) -> bool:
             if self.mode == "max" and current > self.best:
                 is_better = True
 
-            if current == self.best and auxiliary_metric:
+            if current == self.best and auxiliary_metric is not None:
                 current_aux = float(auxiliary_metric)
                 if self.aux_mode == "min" and current_aux < self.best_aux:
                     is_better = True
@@ -287,20 +289,20 @@ def step(self, metric, auxiliary_metric=None) -> bool:
 
         return reduce_learning_rate
 
-    def _reduce_lr(self, epoch):
+    def _reduce_lr(self, epoch: int) -> None:
         for i, param_group in enumerate(self.optimizer.param_groups):
             old_lr = float(param_group["lr"])
             new_lr = max(old_lr * self.factor, self.min_lrs[i])
             if old_lr - new_lr > self.eps:
                 param_group["lr"] = new_lr
                 if self.verbose:
-                    log.info(f" - reducing learning rate of group {epoch} to {new_lr}")
+                    logger.info(f" - reducing learning rate of group {epoch} to {new_lr}")
 
     @property
     def in_cooldown(self):
         return self.cooldown_counter > 0
 
-    def _init_is_better(self, mode):
+    def _init_is_better(self, mode: MinMax) -> None:
         if mode not in {"min", "max"}:
             raise ValueError("mode " + mode + " is unknown!")
 
@@ -311,10 +313,10 @@ def _init_is_better(self, mode):
 
         self.mode = mode
 
-    def state_dict(self):
+    def state_dict(self) -> Dict:
         return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: Dict) -> None:
         self.__dict__.update(state_dict)
         self._init_is_better(mode=self.mode)
 
@@ -348,11 +350,11 @@ def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionar
     return [[1 if label in labels else 0 for label in label_dict.get_items()] for labels in label_list]
 
 
-def log_line(log):
+def log_line(log: logging.Logger) -> None:
     log.info("-" * 100, stacklevel=3)
 
 
-def add_file_handler(log, output_file):
+def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging.FileHandler:
     init_output_file(output_file.parents[0], output_file.name)
     fh = logging.FileHandler(output_file, mode="w", encoding="utf-8")
     fh.setLevel(logging.INFO)
@@ -363,12 +365,21 @@ def add_file_handler(log, output_file):
 
 
 def store_embeddings(
-    data_points: Union[List[DT], Dataset], storage_mode: str, dynamic_embeddings: Optional[List[str]] = None
-):
+    data_points: Union[List[DT], Dataset],
+    storage_mode: str,
+    dynamic_embeddings: Optional[List[str]] = None,
+) -> None:
+    """Stores embeddings of data points in memory or on disk.
+
+    Args:
+        data_points: a DataSet or list of DataPoints for which embeddings should be stored
+        storage_mode: store in either CPU or GPU memory, or delete them if set to 'none'
+        dynamic_embeddings: these are always deleted. If not passed, they are identified automatically.
+    """
     if isinstance(data_points, Dataset):
         data_points = list(_iter_dataset(data_points))
 
-    # if memory mode option 'none' delete everything
+    # if storage mode option 'none' delete everything
     if storage_mode == "none":
         dynamic_embeddings = None
@@ -387,7 +398,7 @@ def store_embeddings(
             data_point.to("cpu", pin_memory=pin_memory)
 
 
-def identify_dynamic_embeddings(data_points: List[DT]):
+def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]:
     dynamic_embeddings = []
     all_embeddings = []
     for data_point in data_points:
@@ -407,3 +418,130 @@ def identify_dynamic_embeddings(data_points: List[DT]):
     if not all_embeddings:
         return None
     return list(set(dynamic_embeddings))
+
+
+class TokenEntity(NamedTuple):
+    """Entity represented by token indices."""
+
+    start_token_idx: int
+    end_token_idx: int
+    label: str
+    value: str = ""  # text value of the entity
+    score: float = 1.0
+
+
+class CharEntity(NamedTuple):
+    """Entity represented by character indices."""
+
+    start_char_idx: int
+    end_char_idx: int
+    label: str
+    value: str
+    score: float = 1.0
+
+
+def create_labeled_sentence_from_tokens(
+    tokens: Union[List[Token]], token_entities: List[TokenEntity], type_name: str = "ner"
+) -> Sentence:
+    """Creates a new Sentence object from a list of tokens or strings and applies entity labels.
+
+    Tokens are recreated with the same text, but not attached to the previous sentence.
+
+    Args:
+        tokens: a list of Token objects or strings - only the text is used, not any labels
+        token_entities: a list of TokenEntity objects representing entity annotations
+        type_name: the type of entity label to apply
+
+    Returns:
+        A labeled Sentence object
+    """
+    tokens = [Token(token.text) for token in tokens]  # create new tokens that do not already belong to a sentence
+    sentence = Sentence(tokens, use_tokenizer=True)
+    for entity in token_entities:
+        sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score)
+    return sentence
+
+
+def create_flair_sentence(
+    text: str,
+    entities: List[CharEntity],
+    token_limit: int = 512,
+    use_context: bool = True,
+    overlap: int = 0,  # TODO: implement overlap
+) -> List[Sentence]:
+    """Constructs a Flair Sentence from text and a list of entity annotations.
+
+    The function explicitly tokenizes the text and labels separately, ensuring entity labels are
+    not partially split across tokens.
+
+    Args:
+        text (str): The full text to be tokenized and labeled.
+        entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the
+            format (start_char_index, end_char_index, entity_class, entity_text).
+        token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking
+        use_context: whether to add context to the sentence
+        overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context
+
+    Returns:
+        A list of labeled Sentence objects representing the chunks of the original text
+    """
+    chunks = []
+
+    tokens: List[Token] = []
+    current_index = 0
+    token_entities: List[TokenEntity] = []
+    end_token_idx = 0
+
+    for entity in entities:
+
+        if entity.start_char_idx > current_index:  # add non-entity text
+            non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens
+            while end_token_idx + len(non_entity_tokens) > token_limit:
+                num_tokens = token_limit - len(tokens)
+                tokens.extend(non_entity_tokens[:num_tokens])
+                non_entity_tokens = non_entity_tokens[num_tokens:]
+                # skip any fully negative samples, they cause fine_tune to fail with
+                # `torch.cat(): expected a non-empty list of Tensors`
+                if len(token_entities) > 0:
+                    chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
+                tokens, token_entities = [], []
+                end_token_idx = 0
+            tokens.extend(non_entity_tokens)
+
+        # add new entity tokens
+        start_token_idx = len(tokens)
+        entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx])
+        if len(entity_sentence) > token_limit:
+            logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}")
+        end_token_idx = start_token_idx + len(entity_sentence)
+
+        if end_token_idx >= token_limit:  # create chunk from existing and add this entity to next chunk
+            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
+
+            tokens, token_entities = [], []
+            start_token_idx, end_token_idx = 0, len(entity_sentence)
+
+        token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value)
+        token_entities.append(token_entity)
+        tokens.extend(entity_sentence)
+
+        current_index = entity.end_char_idx
+
+    # add any remaining tokens to a new chunk
+    if current_index < len(text):
+        remaining_sentence = Sentence(text[current_index:])
+        if end_token_idx + len(remaining_sentence) > token_limit:
+            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
+            tokens, token_entities = [], []
+        tokens.extend(remaining_sentence)
+
+    if tokens:
+        chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
+
+    for chunk in chunks:
+        if len(chunk) > token_limit:
+            logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}")
+
+    if use_context:
+        Sentence.set_context_for_sentences(chunks)
+
+    return chunks
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 61d45acf8c..7f8798c35c 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -12,4 +12,4 @@ types-Deprecated>=1.2.9.2
 types-requests>=2.28.11.17
 types-tabulate>=0.9.0.2
 pyab3p
-transformers!=4.40.1,!=4.40.0
\ No newline at end of file
+transformers!=4.40.1,!=4.40.0
diff --git a/tests/resources/text_sequences/resume1.txt b/tests/resources/text_sequences/resume1.txt
new file mode 100644
index 0000000000..6be7107559
--- /dev/null
+++ b/tests/resources/text_sequences/resume1.txt
@@ -0,0 +1,85 @@
+Nate Diaz
+_____________________________________________________________________
+         5151 Morada Lane Stockton, CA | (555) 445-6487| lkji8765@ziprecruiter.com
+
+Summary
+____________________________________________________________________________
+
+Reliable, service-focused nursing professional with excellent patient-care and charting skills gained through five years of experience as a CNA. Compassionate and technically skilled in attending to patients in diverse healthcare settings. BLS and CPR certified (current).
+
+Highlights
+____________________________________________________________________________
+
+· Patient Care & Safety
+· Diagnostic Testing
+· Customer loyalty
+· Medical Terminology
+· Privacy / HIPAA Regulations
+Experience
+____________________________________________________________________________
+ABC STAFFING COMPANY
+Certified Nursing Assistant 2012- Present
+Stockton, CA
+
+· Provide high-quality patient care as an in-demand per-diem CNA within surgical, acute-care, rehabilitation, home-healthcare and nursing-home settings.
+· Preserve patient dignity and minimize discomfort while carrying out duties such as bedpan changes, diapering, emptying drainage bags and bathing.
+· Commended for chart accuracy, effective team collaboration, patient relations and consistent delivery of empathetic care.
+
+DEF REHABILITATION COMPANY 2010-2012
+Certified Nursing Assistant Stockton, CA
+
+· Displayed strong clinical skills in assessing vital signs, performing lab draws and glucose checks, and providing pre- and post-operative care.
+· Adhered to safety guidelines; completed hospital’s three-hour Patient Safety Training Program.
+· Ensured the accurate, timely flow of information by maintaining thorough patient records and updating healthcare team on patients’ status.
+· Complied with HIPAA standards in all patient documentation and interactions.
+
+Education
+
+____________________________________________________________________________
+
+Nurse’s Aid Program- Completed 60 hours of clinical training
+Arizona State University
+Tempe, AZ
+
+Nate Diaz
+_____________________________________________________________________
+         5151 Morada Lane Stockton, CA | (555) 445-6487| lkji8765@ziprecruiter.com
+
+Summary
+____________________________________________________________________________
+
+Reliable, service-focused nursing professional with excellent patient-care and charting skills gained through five years of experience as a CNA. Compassionate and technically skilled in attending to patients in diverse healthcare settings. BLS and CPR certified (current).
+
+Highlights
+____________________________________________________________________________
+
+· Patient Care & Safety
+· Diagnostic Testing
+· Customer loyalty
+· Medical Terminology
+· Privacy / HIPAA Regulations
+Experience
+____________________________________________________________________________
+ABC STAFFING COMPANY
+Certified Nursing Assistant 2012- Present
+Stockton, CA
+
+· Provide high-quality patient care as an in-demand per-diem CNA within surgical, acute-care, rehabilitation, home-healthcare and nursing-home settings.
+· Preserve patient dignity and minimize discomfort while carrying out duties such as bedpan changes, diapering, emptying drainage bags and bathing.
+· Commended for chart accuracy, effective team collaboration, patient relations and consistent delivery of empathetic care.
+
+DEF REHABILITATION COMPANY 2010-2012
+Certified Nursing Assistant Stockton, CA
+
+· Displayed strong clinical skills in assessing vital signs, performing lab draws and glucose checks, and providing pre- and post-operative care.
+· Adhered to safety guidelines; completed hospital’s three-hour Patient Safety Training Program.
+· Ensured the accurate, timely flow of information by maintaining thorough patient records and updating healthcare team on patients’ status.
+· Complied with HIPAA standards in all patient documentation and interactions.
+
+Education
+
+____________________________________________________________________________
+
+Nurse’s Aid Program- Completed 60 hours of clinical training
+Arizona State University
+Tempe, AZ
\ No newline at end of file
diff --git a/tests/test_sentence_chunking.py b/tests/test_sentence_chunking.py
new file mode 100644
index 0000000000..dbdfdb6976
--- /dev/null
+++ b/tests/test_sentence_chunking.py
@@ -0,0 +1,191 @@
+from typing import Dict, List
+
+import pytest
+
+from flair.data import Sentence
+from flair.training_utils import CharEntity, create_flair_sentence
+
+
+@pytest.fixture(params=["resume1.txt"])
+def resume(request, resources_path) -> str:
+    filepath = resources_path / "text_sequences" / request.param
+    with open(filepath, encoding="utf8") as file:
+        text_content = file.read()
+    return text_content
+
+
+@pytest.fixture
+def parsed_resume_dict(resume) -> dict:
+    return {
+        "raw_text": resume,
+        "entities": [
+            CharEntity(20, 40, "dummy_label1", "Dummy Text 1"),
+            CharEntity(250, 300, "dummy_label2", "Dummy Text 2"),
+            CharEntity(700, 810, "dummy_label3", "Dummy Text 3"),
+            CharEntity(3900, 4000, "dummy_label4", "Dummy Text 4"),
+        ],
+    }
+
+
+@pytest.fixture
+def small_token_limit_resume() -> Dict:
+    return {
+        "raw_text": "Professional Clown June 2020 - August 2021 Entertaining students of all ages. Blah Blah Blah "
+        "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah "
+        "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah "
+        "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah "
+        "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Gained "
+        "proficiency in juggling and scaring children.",
+        "entities": [
+            CharEntity(0, 18, "EXPERIENCE.TITLE", ""),
+            CharEntity(19, 29, "DATE.START_DATE", ""),
+            CharEntity(31, 42, "DATE.END_DATE", ""),
+            CharEntity(450, 510, "EXPERIENCE.DESCRIPTION", ""),
+        ],
+    }
+
+
+@pytest.fixture
+def small_token_limit_response() -> List[Sentence]:
+    """Recreates expected response Sentences."""
+    chunk0 = Sentence("Professional Clown June 2020 - August 2021 Entertaining students of")
+    chunk0[0:2].add_label("Professional Clown", "EXPERIENCE.TITLE")
+    chunk0[2:4].add_label("June 2020", "DATE.START_DATE")
+    chunk0[5:7].add_label("August 2021", "DATE.END_DATE")
+
+    chunk1 = Sentence("Blah Blah Blah Blah Blah Blah Blah Bl")
+
+    chunk2 = Sentence("ah Blah Gained proficiency in juggling and scaring children .")
+    chunk2[0:10].add_label("ah Blah Gained proficiency in juggling and scaring children .", "EXPERIENCE.DESCRIPTION")
+
+    return [chunk0, chunk1, chunk2]
+
+
+class TestChunking:
+    def test_empty_string(self):
+        sentences = create_flair_sentence("", [])
+        assert len(sentences) == 0
+
+    def check_split_entities(self, entity_labels, chunks, max_token_limit):
+        """Ensure that no entities are split over chunks (except entities longer than the token limit)."""
+        chunk_intervals = []
+        start_index = 0
+        for chunk in chunks:
+            end_index = start_index + len(chunk.text)
+            chunk_intervals.append((start_index, end_index))
+            start_index = end_index
+
+        for entity in entity_labels:
+            entity_start, entity_end = entity.start_char_idx, entity.end_char_idx
+            entity_length = entity_end - entity_start
+
+            # Skip the check if the entity itself is longer than the maximum token limit
+            if entity_length > max_token_limit:
+                continue
+
+            assert any(
+                start <= entity_start and entity_end <= end for start, end in chunk_intervals
+            ), f"Entity {entity} is not within a single chunk interval"
+
+    @pytest.mark.parametrize(
+        "test_text, expected_text",
+        [
+            ("test text", "test text"),
+            ("a", "a"),
+            ("this ", "this"),
+        ],
+    )
+    def test_short_text(self, test_text, expected_text):
+        """Short texts that should fit nicely into a single chunk."""
+        chunks = create_flair_sentence(test_text, [])
+        assert chunks[0].text == expected_text
+
+    def test_create_flair_sentence(self, parsed_resume_dict):
+        chunks = create_flair_sentence(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"])
+        assert len(chunks) == 2
+
+        max_token_limit = 512  # default
+        assert all(len(c) <= max_token_limit for c in chunks)
+
+        self.check_split_entities(parsed_resume_dict["entities"], chunks, max_token_limit)
+
+    def test_small_token_limit(self, small_token_limit_resume, small_token_limit_response):
+        max_token_limit = 10  # test a small max token limit
+        chunks = create_flair_sentence(
+            small_token_limit_resume["raw_text"], small_token_limit_resume["entities"], token_limit=max_token_limit
+        )
+
+        for response, expected in zip(chunks, small_token_limit_response):
+            assert response.to_tagged_string() == expected.to_tagged_string()
+
+        assert all(len(c) <= max_token_limit for c in chunks)
+
+        self.check_split_entities(small_token_limit_resume["entities"], chunks, max_token_limit)
+
+    @pytest.mark.parametrize(
+        "test_text, entities, expected_chunks",
+        [
+            (
+                "Led a team of five engineers. It's important to note the project's success. We've implemented state-of-the-art technologies. Co-ordinated efforts with cross-functional teams.",
+                [
+                    CharEntity(0, 25, "RESPONSIBILITY", "Led a team of five engineers"),
+                    CharEntity(27, 72, "ACHIEVEMENT", "It's important to note the project's success"),
+                    CharEntity(74, 117, "ACHIEVEMENT", "We've implemented state-of-the-art technologies"),
+                    CharEntity(119, 168, "RESPONSIBILITY", "Co-ordinated efforts with cross-functional teams"),
+                ],
+                [
+                    "Led a team of five engine er s. It 's important to note the project 's succe ss",
+                    ". We 've implemented state-of-the-art techno lo gies . Co-ordinated efforts with cross-functional teams .",
+                ],
+            ),
+        ],
+    )
+    def test_contractions_and_hyphens(self, test_text, entities, expected_chunks):
+        max_token_limit = 20
+        chunks = create_flair_sentence(test_text, entities, max_token_limit)
+        for i, chunk in enumerate(expected_chunks):
+            assert chunks[i].text == chunk
+        self.check_split_entities(entities, chunks, max_token_limit)
+
+    @pytest.mark.parametrize(
+        "test_text, entities",
+        [
+            (
+                "This is a long text. " * 100,
+                [CharEntity(0, 1000, "dummy_label1", "Dummy Text 1")],
+            )
+        ],
+    )
+    def test_long_text(self, test_text, entities):
+        """Test for handling long texts that should be split into multiple chunks."""
+        max_token_limit = 512
+        chunks = create_flair_sentence(test_text, entities, max_token_limit)
+        assert len(chunks) > 1
+        assert all(len(c) <= max_token_limit for c in chunks)
+        self.check_split_entities(entities, chunks, max_token_limit)
+
+    @pytest.mark.parametrize(
+        "test_text, entities, expected_chunks",
+        [
+            (
+                "Hello! Is your company hiring? I am available for employment. Contact me at 5:00 p.m.",
+                [
+                    CharEntity(0, 6, "LABEL", "Hello!"),
+                    CharEntity(7, 31, "LABEL", "Is your company hiring?"),
+                    CharEntity(32, 65, "LABEL", "I am available for employment."),
+                    CharEntity(66, 86, "LABEL", "Contact me at 5:00 p.m."),
+                ],
+                [
+                    "Hello ! Is your company hiring ? I",
+                    "am available for employment . Con t",
+                    "act me at 5:00 p.m .",
+                ],
+            )
+        ],
+    )
+    def test_text_with_punctuation(self, test_text, entities, expected_chunks):
+        max_token_limit = 10
+        chunks = create_flair_sentence(test_text, entities, max_token_limit)
+        for i, chunk in enumerate(expected_chunks):
+            assert chunks[i].text == chunk
+        self.check_split_entities(entities, chunks, max_token_limit)
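
For orientation, a minimal usage sketch of the chunking helpers introduced by this patch. The example text, character offsets, and label names ("NAME", "TITLE") are invented for illustration and are not taken from the patch or its test fixtures; only CharEntity, create_flair_sentence, and to_tagged_string come from the changes above.

from flair.training_utils import CharEntity, create_flair_sentence

# Hypothetical text with two annotated spans; offsets are 0-based and end-exclusive,
# matching the text[start_char_idx:end_char_idx] slicing used by create_flair_sentence.
text = "Jane Doe worked as a Certified Nursing Assistant in Stockton, CA."
entities = [
    CharEntity(0, 8, "NAME", "Jane Doe"),
    CharEntity(21, 48, "TITLE", "Certified Nursing Assistant"),
]

# Split the text into labeled Sentence chunks of at most 20 tokens each.
chunks = create_flair_sentence(text, entities, token_limit=20)
for chunk in chunks:
    print(chunk.to_tagged_string())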