Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deeppavlov fixes #1481

Open
wants to merge 11 commits into
base: dev
Choose a base branch
from
2 changes: 1 addition & 1 deletion deeppavlov/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def flatten_str_batch(batch: Union[str, Iterable]) -> Union[list, chain]:
['a', 'b', 'c', 'd']

"""
if isinstance(batch, str):
if isinstance(batch, str) or isinstance(batch, str) or isinstance(batch, int) or isinstance(batch, float):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Twice str is mentioned.
Consider replacing isinstance with type.

return [batch]
else:
return chain(*[flatten_str_batch(sample) for sample in batch])
Expand Down
5 changes: 4 additions & 1 deletion deeppavlov/core/trainers/nn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from pathlib import Path
from typing import List, Tuple, Union, Optional, Iterable

from tqdm import tqdm

from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.registry import register
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator
Expand Down Expand Up @@ -279,7 +281,8 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None:
while True:
impatient = False
self._send_event(event_name='before_train')
for x, y_true in iterator.gen_batches(self.batch_size, data_type='train'):
log.info('The model training started')
for x, y_true in tqdm(iterator.gen_batches(self.batch_size, data_type='train')):
self.last_result = self._chainer.train_on_batch(x, y_true)
if self.last_result is None:
self.last_result = {}
Expand Down
42 changes: 25 additions & 17 deletions deeppavlov/dataset_readers/basic_classification_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class BasicClassificationDatasetReader(DatasetReader):
@overrides
def read(self, data_path: str, url: str = None,
format: str = "csv", class_sep: str = None,
float_labels: bool = False,
*args, **kwargs) -> dict:
"""
Read dataset from data_path directory.
Expand All @@ -48,6 +49,8 @@ def read(self, data_path: str, url: str = None,
format: extension of files. Set of Values: ``"csv", "json"``
class_sep: string separator of labels in column with labels
sep (str): delimeter for ``"csv"`` files. Default: None -> only one class per sample
float_labels (boolean): if True and class_sep is not None, we treat all classes as float
quotechar (str): what char we consider as quote in the dataset
header (int): row number to use as the column names
names (array): list of column names to use
orient (str): indication of expected JSON string format
Expand Down Expand Up @@ -80,7 +83,7 @@ def read(self, data_path: str, url: str = None,
file = Path(data_path).joinpath(file_name)
if file.exists():
if format == 'csv':
keys = ('sep', 'header', 'names')
keys = ('sep', 'header', 'names', 'quotechar')
options = {k: kwargs[k] for k in keys if k in kwargs}
df = pd.read_csv(file, **options)
elif format == 'json':
Expand All @@ -92,22 +95,27 @@ def read(self, data_path: str, url: str = None,

x = kwargs.get("x", "text")
y = kwargs.get('y', 'labels')
if isinstance(x, list):
if class_sep is None:
# each sample is a tuple ("text", "label")
data[data_type] = [([row[x_] for x_ in x], str(row[y]))
for _, row in df.iterrows()]
else:
# each sample is a tuple ("text", ["label", "label", ...])
data[data_type] = [([row[x_] for x_ in x], str(row[y]).split(class_sep))
for _, row in df.iterrows()]
else:
if class_sep is None:
# each sample is a tuple ("text", "label")
data[data_type] = [(row[x], str(row[y])) for _, row in df.iterrows()]
else:
# each sample is a tuple ("text", ["label", "label", ...])
data[data_type] = [(row[x], str(row[y]).split(class_sep)) for _, row in df.iterrows()]
data[data_type] = []
i = 0
prev_n_classes = 0 # to capture samples with different n_classes
for _, row in df.iterrows():
if isinstance(x, list):
sample = [row[x_] for x_ in x]
else:
sample = row[x]
label = str(row[y])
if class_sep:
label = str(row[y]).split(class_sep)
if prev_n_classes == 0:
prev_n_classes = len(label)
assert len(label) == prev_n_classes, f"Wrong class number at {i} row"
if float_labels:
label = [float(k) for k in label]
if sample == sample and label == label: # not NAN
data[data_type].append((sample, label))
else:
log.warning(f'Skipping NAN received in file {file} at {i} row')
i += 1
else:
log.warning("Cannot find {} file".format(file))

Expand Down