From 7e6b6eda307a221c0c5bae7e46c6e7b84df18b27 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 24 Oct 2024 16:22:51 -0400 Subject: [PATCH] fix dtype benchmarking --- pyproject.toml | 5 ++++- sdv/data_processing/data_processor.py | 11 +++++++++-- tests/benchmark/numpy_dtypes.py | 8 ++++---- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d5a39b7d..a48c936fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -207,7 +207,10 @@ select = [ # print statements "T201", # pandas-vet - "PD" + "PD", + # numpy 2.0 + "NPY201" + ] ignore = [ # pydocstyle diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 10d1fb045..4143fa522 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -561,7 +561,8 @@ def _create_config(self, data, columns_created_by_constraints): ) if sdtype == 'id': - is_numeric = pd.api.types.is_numeric_dtype(data[column].dtype) + column_dtype = data[column].dtype + is_numeric = pd.api.types.is_numeric_dtype(column_dtype) if column_metadata.get('regex_format', False): transformers[column] = self.create_regex_generator( column, sdtype, column_metadata, is_numeric @@ -571,7 +572,13 @@ def _create_config(self, data, columns_created_by_constraints): else: bothify_format = 'sdv-id-??????' if is_numeric: - bothify_format = '#########' + column_dtype = str(column_dtype).lower() + if 'int8' in column_dtype: + bothify_format = '##' + elif 'int16' in column_dtype: + bothify_format = '####' + else: + bothify_format = '#########' cardinality_rule = None if column in self._keys: diff --git a/tests/benchmark/numpy_dtypes.py b/tests/benchmark/numpy_dtypes.py index fb46008d8..83290cbe2 100644 --- a/tests/benchmark/numpy_dtypes.py +++ b/tests/benchmark/numpy_dtypes.py @@ -61,14 +61,14 @@ }), 'np.string': pd.DataFrame({ 'np.string': pd.Series([ - np.string_('string1'), - np.string_('string2'), - np.string_('string3'), + np.bytes_('string1'), + np.bytes_('string2'), + np.bytes_('string3'), ]) }), 'np.unicode': pd.DataFrame({ 'np.unicode': pd.Series( - [np.unicode_('unicode1'), np.unicode_('unicode2'), np.unicode_('unicode3')], + [np.str_('unicode1'), np.str_('unicode2'), np.str_('unicode3')], dtype='string', ) }),