
Commit

[SNOW-990542] When converting snowpark dataframe to pandas, cast decimal columns to float type (#1201)

* When converting snowpark dataframe to pandas, cast decimal columns to float64

* addressing comments

* changelog

* use to_numeric

* Revert "use to_numeric"

This reverts commit 21441fd.

* comment

* try downcast and force float64

---------

Co-authored-by: sfc-gh-sfan <[email protected]>
sfc-gh-xhe and sfc-gh-sfan authored Jan 12, 2024
1 parent ed1d3eb commit 27ae233
Showing 4 changed files with 52 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@
- Earlier timestamp columns without a timezone would be converted to nanosecond epochs and inferred as `LongType()`, but will now correctly be maintained as timestamp values and be inferred as `TimestampType(TimestampTimeZone.NTZ)`.
- Earlier timestamp columns with a timezone would be inferred as `TimestampType(TimestampTimeZone.NTZ)` and lose timezone information, but will now be correctly inferred as `TimestampType(TimestampTimeZone.LTZ)` with timezone information retained.
- Set session parameter `PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME` to revert to the old behavior. It is recommended that you update your code soon to align with the correct behavior, as the parameter will be removed in the future.
- Fixed a bug where `DataFrame.to_pandas` produced an object dtype column in `pandas` for decimal columns whose scale is not 0. Such values are now cast to the float64 type.
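A minimal sketch of the fixed behavior, assuming an active Snowpark `session`; the column name and literal below are invented for illustration:

```python
# Hypothetical example: a NUMBER column with a non-zero scale.
pdf = session.sql("SELECT CAST(1.23 AS NUMBER(10, 2)) AS AMOUNT").to_pandas()

# Before this fix the column came back as dtype "object" holding decimal.Decimal
# values; with the fix it is cast to float64.
print(pdf["AMOUNT"].dtype)  # float64
```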

### Behavior Changes (API Compatible)

43 changes: 25 additions & 18 deletions src/snowflake/snowpark/_internal/server_connection.py
@@ -441,12 +441,12 @@ def _to_data_or_iter(
         data_or_iter = (
             map(
                 functools.partial(
-                    _fix_pandas_df_integer, results_cursor=results_cursor
+                    _fix_pandas_df_fixed_type, results_cursor=results_cursor
                 ),
                 results_cursor.fetch_pandas_batches(),
             )
             if to_iter
-            else _fix_pandas_df_integer(
+            else _fix_pandas_df_fixed_type(
                 results_cursor.fetch_pandas_all(), results_cursor
             )
         )
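The hunk above keeps the existing pattern of lazily applying the fixer to each fetched batch. A self-contained sketch of that `map(functools.partial(...), batches)` idiom, with invented stand-in names rather than the Snowpark internals:

```python
import functools
from typing import Iterator, List


def fix_batch(batch: List[int], scale: int) -> List[int]:
    # Stand-in for _fix_pandas_df_fixed_type: post-process a single batch.
    return [value * scale for value in batch]


def fetch_batches() -> Iterator[List[int]]:
    # Stand-in for results_cursor.fetch_pandas_batches().
    yield [1, 2]
    yield [3, 4]


# Nothing is processed until the iterator is consumed, mirroring the to_iter branch.
fixed = map(functools.partial(fix_batch, scale=10), fetch_batches())
print(list(fixed))  # [[10, 20], [30, 40]]
```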
@@ -677,7 +677,7 @@ def _get_client_side_session_parameter(self, name: str, default_value: Any) -> A
         )
 
 
-def _fix_pandas_df_integer(
+def _fix_pandas_df_fixed_type(
     pd_df: "pandas.DataFrame", results_cursor: SnowflakeCursor
 ) -> "pandas.DataFrame":
     """The compiler does not make any guarantees about the return types - only that they will be large enough for the result.
@@ -695,24 +695,31 @@ def _fix_pandas_df_integer(
         if (
             FIELD_ID_TO_NAME.get(column_metadata.type_code) == "FIXED"
             and column_metadata.precision is not None
-            and column_metadata.scale == 0
-            and not str(pandas_dtype).startswith("int")
         ):
-            # When scale = 0 and precision values are between 10-20, the integers fit into int64.
-            # If we rely only on pandas.to_numeric, it loses precision value on large integers, therefore
-            # we try to strictly use astype("int64") in this scenario. If the values are too large to
-            # fit in int64, an OverflowError is thrown and we rely on to_numeric to choose and appropriate
-            # floating datatype to represent the number.
-            if column_metadata.precision > 10:
-                try:
-                    pd_df[pandas_col_name] = pd_df[pandas_col_name].astype("int64")
-                except OverflowError:
+            if column_metadata.scale == 0 and not str(pandas_dtype).startswith("int"):
+                # When scale = 0 and precision values are between 10-20, the integers fit into int64.
+                # If we rely only on pandas.to_numeric, it loses precision value on large integers, therefore
+                # we try to strictly use astype("int64") in this scenario. If the values are too large to
+                # fit in int64, an OverflowError is thrown and we rely on to_numeric to choose and appropriate
+                # floating datatype to represent the number.
+                if column_metadata.precision > 10:
+                    try:
+                        pd_df[pandas_col_name] = pd_df[pandas_col_name].astype("int64")
+                    except OverflowError:
+                        pd_df[pandas_col_name] = pandas.to_numeric(
+                            pd_df[pandas_col_name], downcast="integer"
+                        )
+                else:
                     pd_df[pandas_col_name] = pandas.to_numeric(
                         pd_df[pandas_col_name], downcast="integer"
                     )
-            else:
-                pd_df[pandas_col_name] = pandas.to_numeric(
-                    pd_df[pandas_col_name], downcast="integer"
-                )
+            elif column_metadata.scale > 0 and not str(pandas_dtype).startswith(
+                "float"
+            ):
+                # For decimal columns, we want to cast it into float64 because pandas doesn't
+                # recognize decimal type.
+                pandas.to_numeric(pd_df[pandas_col_name], downcast="float")
+                if pd_df[pandas_col_name].dtype == "O":
+                    pd_df[pandas_col_name] = pd_df[pandas_col_name].astype("float64")
 
     return pd_df
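The new `scale > 0` branch leans on pandas behavior that is easy to miss: decimal columns arrive as `decimal.Decimal` objects with dtype `object`, and `pandas.to_numeric` returns a converted copy rather than changing the column in place, so it is the dtype check plus the forced `astype("float64")` that actually updates the frame. A standalone sketch of that behavior in plain pandas, not the Snowpark code itself:

```python
from decimal import Decimal

import pandas as pd

# Columns of decimal.Decimal values are stored with dtype "object".
df = pd.DataFrame({"A": [Decimal("0.09"), Decimal("1.10")]})
print(df["A"].dtype)  # object

# to_numeric returns a converted copy; the original column keeps its dtype
# unless the result is assigned back.
converted = pd.to_numeric(df["A"], downcast="float")
print(converted.dtype)  # a float dtype (float32 or float64, depending on the values)
print(df["A"].dtype)    # still object

# Forcing float64 keeps the dtype predictable for to_pandas() callers.
if df["A"].dtype == "O":
    df["A"] = df["A"].astype("float64")
print(df["A"].dtype)  # float64
```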
8 changes: 4 additions & 4 deletions src/snowflake/snowpark/mock/_connection.py
@@ -337,12 +337,12 @@ def _to_data_or_iter(
         data_or_iter = (
             map(
                 functools.partial(
-                    _fix_pandas_df_integer, results_cursor=results_cursor
+                    _fix_pandas_df_fixed_type, results_cursor=results_cursor
                 ),
                 results_cursor.fetch_pandas_batches(),
             )
             if to_iter
-            else _fix_pandas_df_integer(
+            else _fix_pandas_df_fixed_type(
                 results_cursor.fetch_pandas_all(), results_cursor
             )
         )
@@ -436,7 +436,7 @@ def execute(
         pandas_df = pandas.DataFrame()
         for col_name in res.columns:
             pandas_df[unquote_if_quoted(col_name)] = res[col_name].tolist()
-        rows = _fix_pandas_df_integer(res)
+        rows = _fix_pandas_df_fixed_type(res)
 
         # the following implementation is just to make DataFrame.to_pandas_batches API workable
         # in snowflake, large data result are split into multiple data chunks
@@ -580,7 +580,7 @@ def get_result_query_id(self, plan: SnowflakePlan, **kwargs) -> str:
         return result_set["sfqid"]
 
 
-def _fix_pandas_df_integer(table_res: TableEmulator) -> "pandas.DataFrame":
+def _fix_pandas_df_fixed_type(table_res: TableEmulator) -> "pandas.DataFrame":
     pd_df = pandas.DataFrame()
     for col_name in table_res.columns:
         col_sf_type = table_res.sf_types[col_name]
22 changes: 22 additions & 0 deletions tests/integ/test_df_to_pandas.py
Expand Up @@ -141,6 +141,28 @@ def test_to_pandas_precision_for_number_38_0(session):
assert pdf["A"].min() == -9223372036854775808


def test_to_pandas_precision_for_non_zero_scale(session):
df = session.sql(
"""
SELECT
num1,
num2,
DIV0(num1, num2) AS A,
DIV0(CAST(num1 AS INTEGER), CAST(num2 AS INTEGER)) AS B,
ROUND(B, 2) as C
FROM (VALUES
(1, 11)
) X(num1, num2);
"""
)

pdf = df.to_pandas()

assert pdf["A"].dtype == "float64"
assert pdf["B"].dtype == "float64"
assert pdf["C"].dtype == "float64"


def test_to_pandas_non_select(session):
# `with ... select ...` is also a SELECT statement
isinstance(session.sql("select 1").to_pandas(), PandasDF)
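For a quick manual check of the behavior the new test covers, something like the following should hold, assuming an active Snowpark `session` (illustrative only; the tolerance accounts for Snowflake rounding the division result to a fixed scale):

```python
# DIV0 on integers yields a NUMBER with a non-zero scale, which previously
# surfaced as an object column of decimal.Decimal values.
pdf = session.sql("SELECT DIV0(1, 11) AS Q").to_pandas()
assert pdf["Q"].dtype == "float64"
assert abs(pdf["Q"].iloc[0] - 1 / 11) < 1e-6
```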

0 comments on commit 27ae233
