Skip to content

Commit

Permalink
SNOW-1235476 Improve Local Testing's implementation of to_decimal (#1423)
Browse files Browse the repository at this point in the history

* Add tests

* Address comments
  • Loading branch information
sfc-gh-stan authored Apr 26, 2024
1 parent 7ae8db5 commit d758658
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 37 deletions.
68 changes: 31 additions & 37 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,13 +432,13 @@ def mock_to_decimal(
"""
[x] For NULL input, the result is NULL.
[ ] For fixed-point numbers:
[x] For fixed-point numbers:
Numbers with different scales are converted by either adding zeros to the right (if the scale needs to be increased) or by reducing the number of fractional digits by rounding (if the scale needs to be decreased).
Note that casts of fixed-point numbers to fixed-point numbers that increase scale might fail.
[ ] For floating-point numbers:
[x] For floating-point numbers:
Numbers are converted if they are within the representable range, given the scale.
Expand All @@ -448,53 +448,47 @@ def mock_to_decimal(
For floating-point input, omitting the mantissa or exponent is allowed and is interpreted as 0. Thus, E is parsed as 0.
[ ] Strings are converted as decimal, integer, fractional, or floating-point numbers.
[x] Strings are converted as decimal, integer, fractional, or floating-point numbers.
[x] For fractional input, the precision is deduced as the number of digits after the point.
For VARIANT input:
[ ] If the variant contains a fixed-point or a floating-point numeric value, an appropriate numeric conversion is performed.
[x] If the variant contains a fixed-point or a floating-point numeric value, an appropriate numeric conversion is performed.
[ ] If the variant contains a string, a string conversion is performed.
[x] If the variant contains a string, a string conversion is performed.
[ ] If the variant contains a Boolean value, the result is 0 or 1 (for false and true, correspondingly).
[x] If the variant contains a Boolean value, the result is 0 or 1 (for false and true, correspondingly).
[ ] If the variant contains JSON null value, the output is NULL.
[x] If the variant contains JSON null value, the output is NULL.
"""
res = []

for data in e:
if data is None:
res.append(data)
continue
try:
try:
float(data)
except ValueError:
raise SnowparkSQLException(f"Numeric value '{data}' is not recognized.")

integer_part = round(float(data))
integer_part_str = str(integer_part)
len_integer_part = (
len(integer_part_str) - 1
if integer_part_str[0] == "-"
else len(integer_part_str)
def cast_as_float_convert_to_decimal(x: Union[Decimal, float, str, bool]):
x = float(x)
if x in (math.inf, -math.inf, math.nan):
raise ValueError(
"Values of infinity and NaN cannot be converted to decimal"
)
if len_integer_part > precision:
raise SnowparkSQLException(f"Numeric value '{data}' is out of range")
remaining_decimal_len = min(precision - len(str(integer_part)), scale)
res.append(Decimal(str(round(float(data), remaining_decimal_len))))
except BaseException:
if try_cast:
res.append(None)
else:
raise

return ColumnEmulator(
data=res,
sf_type=ColumnType(DecimalType(precision, scale), nullable=e.sf_type.nullable),
integer_part_len = 1 if abs(x) < 1 else math.ceil(math.log10(abs(x)))
if integer_part_len > precision:
raise SnowparkSQLException(f"Numeric value '{x}' is out of range")
remaining_decimal_len = min(precision - integer_part_len, scale)
return Decimal(str(round(x, remaining_decimal_len)))

if isinstance(e.sf_type.datatype, (_NumericType, BooleanType, NullType)):
res = e.apply(
lambda x: try_convert(cast_as_float_convert_to_decimal, try_cast, x)
)
elif isinstance(e.sf_type.datatype, (StringType, VariantType)):
res = e.replace({"E": 0}).apply(
lambda x: try_convert(cast_as_float_convert_to_decimal, try_cast, x)
)
else:
raise TypeError(f"Invalid input type to TO_DECIMAL {e.sf_type.datatype}")
res.sf_type = ColumnType(
DecimalType(precision, scale), nullable=e.sf_type.nullable or res.hasnans
)
return res


@patch("to_time")
Expand Down
79 changes: 79 additions & 0 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import decimal
import json
import math
import re
from itertools import chain

Expand Down Expand Up @@ -135,6 +136,7 @@
to_binary,
to_char,
to_date,
to_decimal,
to_double,
to_json,
to_object,
Expand All @@ -154,6 +156,7 @@
BooleanType,
DateType,
DecimalType,
DoubleType,
FloatType,
MapType,
StringType,
Expand Down Expand Up @@ -1890,3 +1893,79 @@ def test_to_double(session, local_testing_mode):
[Row(1.2, -2.34, -2.34)],
sort=False,
)


@pytest.mark.localtest
def test_to_decimal(session, local_testing_mode):
    """Exercise ``to_decimal`` on supported inputs at several scales, then on unsupported inputs.

    A single-row DataFrame holding a decimal, a double, a numeric string, a
    boolean and a NULL is converted with precision 38 at scales 0, 2 and 6,
    and each result is compared against the expected row. Afterwards a date
    column and a negative-infinity float are converted, and the test expects
    an error whose type depends on ``local_testing_mode``.
    """
    # One row with one value per supported input kind.
    supported_schema = StructType(
        [
            StructField("decimal_col", DecimalType(26, 12)),
            StructField("float_col", DoubleType()),
            StructField("str_col", StringType()),
            StructField("bool_col1", BooleanType()),
            StructField("bool_col2", BooleanType()),
        ]
    )
    df = session.create_dataframe(
        [[decimal.Decimal("12.34"), 12.345678, "3.14E-6", True, None]],
        schema=supported_schema,
    )

    # (scale, expected decimal_col, expected float_col, expected str_col).
    # The boolean column is always expected to be 1 and the NULL column None,
    # so only the first three expectations vary with the scale.
    scale_cases = (
        (0, "12", "12", "0"),
        (2, "12.34", "12.35", "0"),
        (6, "12.34", "12.345678", "0.000003"),
    )
    for scale, want_decimal, want_float, want_str in scale_cases:
        converted = df.select([to_decimal(c, 38, scale) for c in df.columns])
        expected_row = Row(
            decimal.Decimal(want_decimal),
            decimal.Decimal(want_float),
            decimal.Decimal(want_str),
            decimal.Decimal("1"),
            None,
        )
        Utils.check_answer(converted, [expected_row])

    # Unsupported input: an infinite float value and a date-typed column.
    unsupported_schema = StructType(
        [StructField("float_col", FloatType()), StructField("date_col", DateType())]
    )
    df = session.create_dataframe(
        [[-math.inf, datetime.date.today()]],
        schema=unsupported_schema,
    )

    def error_for(local_error):
        # The expected exception type differs between local testing mode and
        # a non-local session, where SnowparkSQLException is expected.
        if local_testing_mode:
            return local_error
        return SnowparkSQLException

    # Input type is not supported.
    with pytest.raises(error_for(TypeError)):
        df.select([to_decimal(df.date_col, 38, 0)]).collect()

    # Input value is not supported.
    with pytest.raises(error_for(ValueError)):
        df.select([to_decimal(df.float_col, 38, 0)]).collect()

0 comments on commit d758658

Please sign in to comment.