From c28f01c1108efec3c442a6122269ac47395db3c4 Mon Sep 17 00:00:00 2001
From: David Beauchemin
Date: Fri, 7 Jun 2024 07:56:09 -0400
Subject: [PATCH] Remove context manager (#228)

* fix imports

* add interface for data cleaning pre processing during loading

* add interface for data cleaning pre processing during loading
---
 CHANGELOG.md                                   |  3 +-
 .../dataset_container/dataset_container.py     | 70 +++++++++++++++----
 deepparse/parser/address_parser.py             | 38 ++++------
 3 files changed, 73 insertions(+), 38 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 13912f20..d0392211 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -344,4 +344,5 @@
 - Remove fixed dependencies version.
 - Fix app errors.
 - Add data validation for 1) multiple consecutive whitespace and 2) newline.
-- Fixes some errors in tests.
\ No newline at end of file
+- Fixes some errors in tests.
+- Add an argument to the `DatasetContainer` interface to use a pre-processing data cleaning function before validation.
\ No newline at end of file
diff --git a/deepparse/dataset_container/dataset_container.py b/deepparse/dataset_container/dataset_container.py
index e8ccb6e1..cd4db5ea 100644
--- a/deepparse/dataset_container/dataset_container.py
+++ b/deepparse/dataset_container/dataset_container.py
@@ -30,28 +30,35 @@ class DatasetContainer(Dataset, ABC):
     - no address is a ``None`` value,
     - no address is empty,
     - no address is composed of only whitespace,
+    - no address includes consecutive whitespace (e.g. "An  address"),
     - no tags list is empty, if data is a list of tuple (``[('an address', ['a_tag', 'another_tag']), ...]``), and
     - if the addresses (whitespace-split) are the same length as their respective tags list.
 
     While for a predict container (unknown prediction tag), it validates the following:
 
     - no address is a ``None`` value,
-    - no address is empty, and
-    - no address is composed of only whitespace.
+    - no address is empty,
+    - no address is composed of only whitespace, and
+    - no address includes consecutive whitespace (e.g. "An  address").
 
     Args:
         is_training_container (bool): Either or not, the dataset container is a training container. This will
            determine the dataset validation test we apply to the dataset. That is, a predict dataset doesn't
            include tags. The default value is ``True``.
+        data_cleaning_pre_processing_fn (Callable): Function to apply as data cleaning pre-processing step after
+           loading the data, but before applying the validation steps. The default value is ``None``.
     """
 
     @abstractmethod
-    def __init__(self, is_training_container: bool = True) -> None:
+    def __init__(
+        self, is_training_container: bool = True, data_cleaning_pre_processing_fn: Union[None, Callable] = None
+    ) -> None:
         """
         The method to init the class. It needs to be defined by the child's class.
         """
         self.data = None
         self.is_training_container = is_training_container
+        self.data_cleaning_pre_processing_fn = data_cleaning_pre_processing_fn
 
     def __len__(self) -> int:
         return len(self.data)
@@ -180,6 +187,7 @@ class PickleDatasetContainer(DatasetContainer):
     - no address is a ``None`` value,
     - no address is empty,
     - no address is composed of only whitespace,
+    - no address includes consecutive whitespace (e.g. "An  address"),
     - no tags list is empty, if data is a list of tuple (``[('an address', ['a_tag', 'another_tag']), ...]``), and
     - if the addresses (whitespace-split) are the same length as their respective tags list.
@@ -187,19 +195,29 @@ class PickleDatasetContainer(DatasetContainer):
     following:
 
     - no address is a ``None`` value,
-    - no address is empty, and
-    - no address is composed of only whitespace.
+    - no address is empty,
+    - no address is composed of only whitespace, and
+    - no address includes consecutive whitespace (e.g. "An  address").
 
     Args:
         data_path (str): The path to the pickle dataset file.
         is_training_container (bool): Either or not, the dataset container is a training container. This will
            determine the dataset validation test we apply to the dataset. That is, a predict dataset doesn't
            include tags. The default value is ``True``.
+        data_cleaning_pre_processing_fn (Callable): Function to apply as data cleaning pre-processing step after
+           loading the data, but before applying the validation steps. The default value is ``None``.
     """
 
-    def __init__(self, data_path: str, is_training_container: bool = True) -> None:
-        super().__init__(is_training_container=is_training_container)
+    def __init__(
+        self,
+        data_path: str,
+        is_training_container: bool = True,
+        data_cleaning_pre_processing_fn: Union[None, Callable] = None,
+    ) -> None:
+        super().__init__(
+            is_training_container=is_training_container, data_cleaning_pre_processing_fn=data_cleaning_pre_processing_fn
+        )
 
         with open(data_path, "rb") as f:
             self.data = load(f)
@@ -209,6 +227,8 @@ def __init__(self, data_path: str, is_training_container: bool = True) -> None:
                 "The data is a list of tuples, but the dataset container is a predict container. "
                 "Predict container should contain only a list of addresses."
             )
+        if self.data_cleaning_pre_processing_fn is not None:
+            self.data = self.data_cleaning_pre_processing_fn(self.data)
 
         self.validate_dataset()
 
@@ -229,6 +249,7 @@ class CSVDatasetContainer(DatasetContainer):
     - no address is a ``None`` value,
     - no address is empty,
     - no address is composed of only whitespace,
+    - no address includes consecutive whitespace (e.g. "An  address"),
     - no tags list is empty, if data is a list of tuple (``[('an address', ['a_tag', 'another_tag']), ...]``), and
     - if the addresses (whitespace-split) are the same length as their respective tags list.
 
@@ -236,8 +257,9 @@ class CSVDatasetContainer(DatasetContainer):
     following:
 
     - no address is a ``None`` value,
-    - no address is empty, and
-    - no address is composed of only whitespace.
+    - no address is empty,
+    - no address is composed of only whitespace, and
+    - no address includes consecutive whitespace (e.g. "An  address").
 
 
     Args:
@@ -257,7 +279,9 @@ class CSVDatasetContainer(DatasetContainer):
            That is, it removes the ``[],`` characters and splits the sequence at each comma (``","``).
         csv_reader_kwargs (dict, optional): Keyword arguments to pass to pandas ``read_csv`` use internally. By
            default, the ``data_path`` is passed along with our default ``sep`` value ( ``"\\t"``) and the ``"utf-8"`` encoding
-           format. However, this can be overridded by using this argument again.
+           format. However, this can be overridden by using this argument again.
+        data_cleaning_pre_processing_fn (Callable): Function to apply as data cleaning pre-processing step after
+           loading the data, but before applying the validation steps. The default value is ``None``.
""" def __init__( @@ -268,8 +292,11 @@ def __init__( separator: str = "\t", tag_seperator_reformat_fn: Union[None, Callable] = None, csv_reader_kwargs: Union[None, Dict] = None, + data_cleaning_pre_processing_fn: Union[None, Callable] = None, ) -> None: - super().__init__(is_training_container=is_training_container) + super().__init__( + is_training_container=is_training_container, data_cleaning_pre_processing_fn=data_cleaning_pre_processing_fn + ) if is_training_container: if isinstance(column_names, str): raise ValueError( @@ -306,6 +333,10 @@ def __init__( else: data = [data_point[0] for data_point in pd.read_csv(**csv_reader_kwargs)[column_names].to_numpy()] self.data = data + + if self.data_cleaning_pre_processing_fn is not None: + self.data = self.data_cleaning_pre_processing_fn(self.data) + self.validate_dataset() @@ -320,9 +351,22 @@ class ListDatasetContainer(DatasetContainer): is_training_container (bool): Either or not, the dataset container is a training container. This will determine the dataset validation test we apply to the dataset. That is, a predict dataset doesn't include tags. The default value is ``True``. + data_cleaning_pre_processing_fn (Callable): Function to apply as data clea ning pre-processing step after + loading the data, but before applying the validation steps. The default value is ``None``. """ - def __init__(self, data: List, is_training_container: bool = True) -> None: - super().__init__(is_training_container=is_training_container) + def __init__( + self, + data: List, + is_training_container: bool = True, + data_cleaning_pre_processing_fn: Union[None, Callable] = None, + ) -> None: + super().__init__( + is_training_container=is_training_container, data_cleaning_pre_processing_fn=data_cleaning_pre_processing_fn + ) self.data = data + + if self.data_cleaning_pre_processing_fn is not None: + self.data = self.data_cleaning_pre_processing_fn(self.data) + self.validate_dataset() diff --git a/deepparse/parser/address_parser.py b/deepparse/parser/address_parser.py index da30f2b2..2b9bf789 100644 --- a/deepparse/parser/address_parser.py +++ b/deepparse/parser/address_parser.py @@ -4,7 +4,6 @@ # It must be due to the complex try, except else case. # pylint: disable=inconsistent-return-statements -import contextlib import os import re import warnings @@ -19,11 +18,7 @@ from torch.optim import SGD from torch.utils.data import DataLoader, Subset -from ..download_tools import CACHE_PATH -from ..pre_processing.pre_processor_list import PreProcessorList -from ..validations import valid_poutyne_version from . import formatted_parsed_address -from .capturing import Capturing from .formatted_parsed_address import FormattedParsedAddress from .tools import ( get_address_parser_in_directory, @@ -39,13 +34,15 @@ from .. 
 from .. import validate_data_to_parse
 from ..converter import TagsConverter, DataProcessorFactory, DataPadder
 from ..dataset_container import DatasetContainer
+from ..download_tools import CACHE_PATH
 from ..embeddings_models import EmbeddingsModelFactory
 from ..errors import FastTextModelError
 from ..metrics import nll_loss, accuracy
 from ..network import ModelFactory
 from ..pre_processing import coma_cleaning, lower_cleaning, hyphen_cleaning
 from ..pre_processing import trailing_whitespace_cleaning, double_whitespaces_cleaning
-
+from ..pre_processing.pre_processor_list import PreProcessorList
+from ..validations import valid_poutyne_version
 from ..vectorizer import VectorizerFactory
 from ..weights_tools import handle_weights_upload
 
@@ -791,14 +788,12 @@ def retrain(
             verbose = self.verbose
 
         try:
-            with_capturing_context = False
             if not valid_poutyne_version(min_major=1, min_minor=8):
-                print(
+                raise ImportError(
                     "You are using an older version of Poutyne that does not support proper error management."
-                    " Due to that, we cannot show retrain progress. To fix that, update Poutyne to "
+                    " Due to that, we cannot show retrain progress. To fix that, please update Poutyne to "
                     "the newest version."
                 )
-                with_capturing_context = True
             train_res = self._retrain(
                 experiment=exp,
                 train_generator=train_generator,
@@ -807,7 +802,6 @@ def retrain(
                 seed=seed,
                 callbacks=callbacks,
                 disable_tensorboard=disable_tensorboard,
-                capturing_context=with_capturing_context,
                 verbose=verbose,
             )
         except RuntimeError as error:
@@ -1197,22 +1191,18 @@ def _retrain(
         seed: int,
         callbacks: List,
         disable_tensorboard: bool,
-        capturing_context: bool,
         verbose: Union[None, bool],
     ) -> List[Dict]:
         # pylint: disable=too-many-arguments
-        # If Poutyne 1.7 and before, we capture poutyne print since it prints some exception.
-        # Otherwise, we use a null context manager.
-        with Capturing() if capturing_context else contextlib.nullcontext():
-            train_res = experiment.train(
-                train_generator,
-                valid_generator=valid_generator,
-                epochs=epochs,
-                seed=seed,
-                callbacks=callbacks,
-                disable_tensorboard=disable_tensorboard,
-                verbose=verbose,
-            )
+        train_res = experiment.train(
+            train_generator,
+            valid_generator=valid_generator,
+            epochs=epochs,
+            seed=seed,
+            callbacks=callbacks,
+            disable_tensorboard=disable_tensorboard,
+            verbose=verbose,
+        )
         return train_res
 
     def _freeze_model_params(self, layers_to_freeze: Union[str]) -> None:
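
For context, here is a minimal usage sketch of the new `data_cleaning_pre_processing_fn` argument introduced by this patch. The `clean_addresses` helper and the sample addresses are illustrative only and are not part of the patch; as the patch shows, the container simply calls the given function on the loaded data before running its validation steps.

# Illustrative sketch (not part of the patch): using the new
# `data_cleaning_pre_processing_fn` argument of a dataset container.
from deepparse.dataset_container import ListDatasetContainer


def clean_addresses(addresses):
    # The function receives the whole loaded dataset (here, a list of address
    # strings for a predict container) and must return it in the same shape.
    # Collapsing consecutive whitespace lets the addresses pass the container's
    # validation, which rejects consecutive whitespace.
    return [" ".join(address.split()) for address in addresses]


raw_addresses = [
    "350  rue des Lilas Ouest Québec Québec G1L 1B6 ",  # consecutive and trailing whitespace
    "777 rue des Lilas Ouest Québec Québec G1L 1B6",
]

container = ListDatasetContainer(
    raw_addresses,
    is_training_container=False,  # predict container: addresses only, no tags
    data_cleaning_pre_processing_fn=clean_addresses,
)

print(container.data[0])  # "350 rue des Lilas Ouest Québec Québec G1L 1B6"

Because the function receives the whole dataset rather than one address at a time, it can also filter or deduplicate entries, as long as it returns the data in the shape the container expects.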