Skip to content

Commit

Permalink
fix dtype benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 24, 2024
1 parent be23f29 commit 7e6b6ed
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ select = [
# print statements
"T201",
# pandas-vet
"PD"
"PD",
# numpy 2.0
"NPY201"

]
ignore = [
# pydocstyle
Expand Down
11 changes: 9 additions & 2 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,8 @@ def _create_config(self, data, columns_created_by_constraints):
)

if sdtype == 'id':
is_numeric = pd.api.types.is_numeric_dtype(data[column].dtype)
column_dtype = data[column].dtype
is_numeric = pd.api.types.is_numeric_dtype(column_dtype)
if column_metadata.get('regex_format', False):
transformers[column] = self.create_regex_generator(
column, sdtype, column_metadata, is_numeric
Expand All @@ -571,7 +572,13 @@ def _create_config(self, data, columns_created_by_constraints):
else:
bothify_format = 'sdv-id-??????'
if is_numeric:
bothify_format = '#########'
column_dtype = str(column_dtype).lower()
if 'int8' in column_dtype:
bothify_format = '##'
elif 'int16' in column_dtype:
bothify_format = '####'
else:
bothify_format = '#########'

cardinality_rule = None
if column in self._keys:
Expand Down
8 changes: 4 additions & 4 deletions tests/benchmark/numpy_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@
}),
'np.string': pd.DataFrame({
'np.string': pd.Series([
np.string_('string1'),
np.string_('string2'),
np.string_('string3'),
np.bytes_('string1'),
np.bytes_('string2'),
np.bytes_('string3'),
])
}),
'np.unicode': pd.DataFrame({
'np.unicode': pd.Series(
[np.unicode_('unicode1'), np.unicode_('unicode2'), np.unicode_('unicode3')],
[np.str_('unicode1'), np.str_('unicode2'), np.str_('unicode3')],
dtype='string',
)
}),
Expand Down

0 comments on commit 7e6b6ed

Please sign in to comment.