Skip to content

Commit

Permalink
Remove context manager (#228)
Browse files Browse the repository at this point in the history
* fix imports

* add interface for data cleaning pre processing during loading

* add interface for data cleaning pre processing during loading
  • Loading branch information
davebulaval authored Jun 7, 2024
1 parent edd8dfc commit c28f01c
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 38 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -344,4 +344,5 @@
- Remove fixed dependencies version.
- Fix app errors.
- Add data validation for 1) multiple consecutive whitespace and 2) newline.
- Fixes some errors in tests.
- Fixes some errors in tests.
- Add an argument to the `DatasetContainer` interface to use a pre-processing data cleaning function before validation.
70 changes: 57 additions & 13 deletions deepparse/dataset_container/dataset_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,35 @@ class DatasetContainer(Dataset, ABC):
- no address is a ``None`` value,
- no address is empty,
- no address is composed of only whitespace,
- no address includes consecutive whitespace (e.g. "An address"),
- no tags list is empty, if data is a list of tuple (``[('an address', ['a_tag', 'another_tag']), ...]``), and
- if the addresses (whitespace-split) are the same length as their respective tags list.
While for a predict container (unknown prediction tag), it validates the following:
- no address is a ``None`` value,
- no address is empty, and
- no address is composed of only whitespace.
- no address is empty,
- no address is composed of only whitespace, and
- no address includes consecutive whitespace (e.g. "An address").
Args:
is_training_container (bool): Either or not, the dataset container is a training container. This will determine
the dataset validation test we apply to the dataset. That is, a predict dataset doesn't include tags.
The default value is ``True``.
data_cleaning_pre_processing_fn (Callable): Function to apply as data clea ning pre-processing step after
loading the data, but before applying the validation steps. The default value is ``None``.
"""

@abstractmethod
def __init__(self, is_training_container: bool = True) -> None:
def __init__(
self, is_training_container: bool = True, data_cleaning_pre_processing_fn: Union[None, Callable] = None
) -> None:
"""
The method to init the class. It needs to be defined by the child's class.
"""
self.data = None
self.is_training_container = is_training_container
self.data_cleaning_pre_processing_fn = data_cleaning_pre_processing_fn

def __len__(self) -> int:
return len(self.data)
Expand Down Expand Up @@ -180,26 +187,37 @@ class PickleDatasetContainer(DatasetContainer):
- no address is a ``None`` value,
- no address is empty,
- no address is composed of only whitespace,
- no address includes consecutive whitespace (e.g. "An address"),
- no tags list is empty, if data is a list of tuple (``[('an address', ['a_tag', 'another_tag']), ...]``), and
- if the addresses (whitespace-split) are the same length as their respective tags list.
While for a predict container (unknown prediction tag), the validation tests applied on the dataset are the
following:
- no address is a ``None`` value,
- no address is empty, and
- no address is composed of only whitespace.
- no address is empty,
- no address is composed of only whitespace, and
- no address includes consecutive whitespace (e.g. "An address").
Args:
data_path (str): The path to the pickle dataset file.
is_training_container (bool): Either or not, the dataset container is a training container. This will determine
the dataset validation test we apply to the dataset. That is, a predict dataset doesn't include tags.
The default value is ``True``.
data_cleaning_pre_processing_fn (Callable): Function to apply as data clea ning pre-processing step after
loading the data, but before applying the validation steps. The default value is ``None``.
"""

def __init__(self, data_path: str, is_training_container: bool = True) -> None:
super().__init__(is_training_container=is_training_container)
def __init__(
self,
data_path: str,
is_training_container: bool = True,
data_cleaning_pre_processing_fn: Union[None, Callable] = None,
) -> None:
super().__init__(
is_training_container=is_training_container, data_cleaning_pre_processing_fn=data_cleaning_pre_processing_fn
)
with open(data_path, "rb") as f:
self.data = load(f)

Expand All @@ -209,6 +227,8 @@ def __init__(self, data_path: str, is_training_container: bool = True) -> None:
"The data is a list of tuples, but the dataset container is a predict container. "
"Predict container should contain only a list of addresses."
)
if self.data_cleaning_pre_processing_fn is not None:
self.data = self.data_cleaning_pre_processing_fn(self.data)

self.validate_dataset()

Expand All @@ -229,15 +249,17 @@ class CSVDatasetContainer(DatasetContainer):
- no address is a ``None`` value,
- no address is empty,
- no address is composed of only whitespace,
- no address includes consecutive whitespace (e.g. "An address"),
- no tags list is empty, if data is a list of tuple (``[('an address', ['a_tag', 'another_tag']), ...]``), and
- if the addresses (whitespace-split) are the same length as their respective tags list.
While for a predict container (unknown prediction tag), the validation tests applied on the dataset are the
following:
- no address is a ``None`` value,
- no address is empty, and
- no address is composed of only whitespace.
- no address is empty,
- no address is composed of only whitespace, and
- no address includes consecutive whitespace (e.g. "An address").
Args:
Expand All @@ -257,7 +279,9 @@ class CSVDatasetContainer(DatasetContainer):
That is, it removes the ``[],`` characters and splits the sequence at each comma (``","``).
csv_reader_kwargs (dict, optional): Keyword arguments to pass to pandas ``read_csv`` use internally. By default,
the ``data_path`` is passed along with our default ``sep`` value ( ``"\\t"``) and the ``"utf-8"`` encoding
format. However, this can be overridded by using this argument again.
format. However, this can be overridden by using this argument again.
data_cleaning_pre_processing_fn (Callable): Function to apply as data clea ning pre-processing step after
loading the data, but before applying the validation steps. The default value is ``None``.
"""

def __init__(
Expand All @@ -268,8 +292,11 @@ def __init__(
separator: str = "\t",
tag_seperator_reformat_fn: Union[None, Callable] = None,
csv_reader_kwargs: Union[None, Dict] = None,
data_cleaning_pre_processing_fn: Union[None, Callable] = None,
) -> None:
super().__init__(is_training_container=is_training_container)
super().__init__(
is_training_container=is_training_container, data_cleaning_pre_processing_fn=data_cleaning_pre_processing_fn
)
if is_training_container:
if isinstance(column_names, str):
raise ValueError(
Expand Down Expand Up @@ -306,6 +333,10 @@ def __init__(
else:
data = [data_point[0] for data_point in pd.read_csv(**csv_reader_kwargs)[column_names].to_numpy()]
self.data = data

if self.data_cleaning_pre_processing_fn is not None:
self.data = self.data_cleaning_pre_processing_fn(self.data)

self.validate_dataset()


Expand All @@ -320,9 +351,22 @@ class ListDatasetContainer(DatasetContainer):
is_training_container (bool): Either or not, the dataset container is a training container. This will determine
the dataset validation test we apply to the dataset. That is, a predict dataset doesn't include tags.
The default value is ``True``.
data_cleaning_pre_processing_fn (Callable): Function to apply as data clea ning pre-processing step after
loading the data, but before applying the validation steps. The default value is ``None``.
"""

def __init__(self, data: List, is_training_container: bool = True) -> None:
super().__init__(is_training_container=is_training_container)
def __init__(
self,
data: List,
is_training_container: bool = True,
data_cleaning_pre_processing_fn: Union[None, Callable] = None,
) -> None:
super().__init__(
is_training_container=is_training_container, data_cleaning_pre_processing_fn=data_cleaning_pre_processing_fn
)
self.data = data

if self.data_cleaning_pre_processing_fn is not None:
self.data = self.data_cleaning_pre_processing_fn(self.data)

self.validate_dataset()
38 changes: 14 additions & 24 deletions deepparse/parser/address_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# It must be due to the complex try, except else case.
# pylint: disable=inconsistent-return-statements

import contextlib
import os
import re
import warnings
Expand All @@ -19,11 +18,7 @@
from torch.optim import SGD
from torch.utils.data import DataLoader, Subset

from ..download_tools import CACHE_PATH
from ..pre_processing.pre_processor_list import PreProcessorList
from ..validations import valid_poutyne_version
from . import formatted_parsed_address
from .capturing import Capturing
from .formatted_parsed_address import FormattedParsedAddress
from .tools import (
get_address_parser_in_directory,
Expand All @@ -39,13 +34,15 @@
from .. import validate_data_to_parse
from ..converter import TagsConverter, DataProcessorFactory, DataPadder
from ..dataset_container import DatasetContainer
from ..download_tools import CACHE_PATH
from ..embeddings_models import EmbeddingsModelFactory
from ..errors import FastTextModelError
from ..metrics import nll_loss, accuracy
from ..network import ModelFactory
from ..pre_processing import coma_cleaning, lower_cleaning, hyphen_cleaning
from ..pre_processing import trailing_whitespace_cleaning, double_whitespaces_cleaning

from ..pre_processing.pre_processor_list import PreProcessorList
from ..validations import valid_poutyne_version
from ..vectorizer import VectorizerFactory
from ..weights_tools import handle_weights_upload

Expand Down Expand Up @@ -791,14 +788,12 @@ def retrain(
verbose = self.verbose

try:
with_capturing_context = False
if not valid_poutyne_version(min_major=1, min_minor=8):
print(
raise ImportError(
"You are using an older version of Poutyne that does not support proper error management."
" Due to that, we cannot show retrain progress. To fix that, update Poutyne to "
" Due to that, we cannot show retrain progress. To fix that, please update Poutyne to "
"the newest version."
)
with_capturing_context = True
train_res = self._retrain(
experiment=exp,
train_generator=train_generator,
Expand All @@ -807,7 +802,6 @@ def retrain(
seed=seed,
callbacks=callbacks,
disable_tensorboard=disable_tensorboard,
capturing_context=with_capturing_context,
verbose=verbose,
)
except RuntimeError as error:
Expand Down Expand Up @@ -1197,22 +1191,18 @@ def _retrain(
seed: int,
callbacks: List,
disable_tensorboard: bool,
capturing_context: bool,
verbose: Union[None, bool],
) -> List[Dict]:
# pylint: disable=too-many-arguments
# If Poutyne 1.7 and before, we capture poutyne print since it prints some exception.
# Otherwise, we use a null context manager.
with Capturing() if capturing_context else contextlib.nullcontext():
train_res = experiment.train(
train_generator,
valid_generator=valid_generator,
epochs=epochs,
seed=seed,
callbacks=callbacks,
disable_tensorboard=disable_tensorboard,
verbose=verbose,
)
train_res = experiment.train(
train_generator,
valid_generator=valid_generator,
epochs=epochs,
seed=seed,
callbacks=callbacks,
disable_tensorboard=disable_tensorboard,
verbose=verbose,
)
return train_res

def _freeze_model_params(self, layers_to_freeze: Union[str]) -> None:
Expand Down

0 comments on commit c28f01c

Please sign in to comment.