Skip to content

Commit

Permalink
Add test cases for aggregate arithmetic function "sum/min/max/avg" ha…
Browse files Browse the repository at this point in the history
…ving decimal argumets
  • Loading branch information
anshuldata committed Aug 2, 2024
1 parent 0c946db commit 96dbadb
Show file tree
Hide file tree
Showing 11 changed files with 396 additions and 2 deletions.
17 changes: 17 additions & 0 deletions bft/cases/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from decimal import Decimal
from typing import BinaryIO, Iterable, List

from bft.core.yaml_parser import BaseYamlParser, BaseYamlVisitor
Expand Down Expand Up @@ -38,8 +39,24 @@ def __normalize_yaml_literal(self, value, data_type):
return math.nan
else:
raise ValueError(f"Unrecognized float string literal {value}")
if data_type.startswith("dec"):
ret_val = self._normalize_decimal_type(value)
return ret_val
return value

def _normalize_decimal_type(self, val):
if val is None:
return val
if not isinstance(val, list):
return Decimal(str(val))
converted_list = []
for v in val:
if v is not None:
converted_list.append(Decimal(v))
else:
converted_list.append(v)
return converted_list

def visit_literal(self, lit):
value = self._get_or_die(lit, "value")
data_type = self._get_or_die(lit, "type")
Expand Down
8 changes: 7 additions & 1 deletion bft/testers/duckdb/runner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import datetime
import math
from decimal import Decimal, ROUND_DOWN
from typing import Dict, NamedTuple

import duckdb

from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import type_to_dialect_type
from bft.utils.utils import type_to_dialect_type, compareDecimalResult

type_map = {
"i8": "TINYINT",
Expand Down Expand Up @@ -139,6 +140,11 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
elif case.result.type.startswith("fp") and case.result.value and result:
if math.isclose(result, case.result.value, rel_tol=1e-7):
return SqlCaseResult.success()
elif case.result.type.startswith("dec") and case.result.value and result:
if compareDecimalResult(case.result.value, Decimal(str(result))):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
else:
if result == case.result.value:
return SqlCaseResult.success()
Expand Down
9 changes: 8 additions & 1 deletion bft/testers/snowflake/runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import datetime
import math
import os
from decimal import Decimal

import yaml
from typing import Dict, NamedTuple

Expand All @@ -10,7 +12,7 @@
from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import type_to_dialect_type
from bft.utils.utils import type_to_dialect_type, compareDecimalResult

type_map = {
"fp64": "FLOAT",
Expand Down Expand Up @@ -167,6 +169,11 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
elif case.result.type.startswith("fp") and case.result.value and result:
if math.isclose(result, case.result.value, rel_tol=1e-7):
return SqlCaseResult.success()
elif case.result.type.startswith("dec") and case.result.value and result:
if compareDecimalResult(case.result.value, Decimal(str(result))):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
else:
if result == case.result.value:
return SqlCaseResult.success()
Expand Down
20 changes: 20 additions & 0 deletions bft/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from decimal import Decimal, ROUND_DOWN, getcontext, localcontext
from typing import Dict


Expand All @@ -24,3 +25,22 @@ def type_to_dialect_type(type: str, type_map: Dict[str, str])->str:
return type_val
# transform parameterized type name to have dialect type
return type.replace(type_to_check, type_val).replace("<", "(").replace(">", ")")

def compareDecimalResult(expected_result: Decimal, actual_result: Decimal)->bool:
'''
Compares non-null decimal type based on scale of expected_result
:param expected_result: expected result. Its scale is considered to be the scale to compare
:param actual_result:
:return: bool
'''
# make scale of actual_result to same as expected_result for comparison
scale = abs(expected_result.as_tuple().exponent)
rounding_format = Decimal(f"1.{'0' * scale}")
try:
# set thread precison to 38 since database decimal support max 38
with localcontext(prec=38) as ctx:
rounded_result = actual_result.quantize(rounding_format, rounding=ROUND_DOWN)
except Exception as e:
print(f"Exception while rounding: {e}")
return False
return rounded_result == expected_result
12 changes: 12 additions & 0 deletions bft/utils/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from decimal import Decimal

from bft.utils.utils import compareDecimalResult


def test_compare_decimal_result():
assert compareDecimalResult(Decimal('1'), Decimal('1'))
assert compareDecimalResult(Decimal('99999999999999999999999999999999999999'), Decimal('99999999999999999999999999999999999999'))
assert compareDecimalResult(Decimal('1.75'), Decimal('1.75678'))
assert compareDecimalResult(Decimal('1.757'), Decimal('1.75678')) == False
assert compareDecimalResult(Decimal('2.33'), Decimal('2.330000000078644688'))
assert compareDecimalResult(Decimal('4.12500053644180'), Decimal('4.1250005364418029785156'))
66 changes: 66 additions & 0 deletions cases/arithmetic_decimal/avg_decimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml
function: avg
cases:
- group:
id: basic
description: Basic examples without any special cases
args:
- value: [0, -1, 2, 20]
type: decimal<38, 0>
result:
value: 5.25
type: decimal<38, 2>
- group: basic
args:
- value: [2000000, -3217908, 629000, -100000, 0, 987654]
type: decimal<38, 0>
result:
value: 49791
type: decimal<38, 5>
- group: basic
args:
- value: [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<38, 2>
result:
value: -0.5
type: decimal<38, 2>
- group: basic
args:
- value: [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875]
type: decimal<38, 22>
result:
value: 4.12500053644180
type: decimal<38, 22>
- group:
id: overflow
description: Examples demonstrating overflow behavior
args:
- value: [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999]
type: decimal<38, 0>
options:
overflow: ERROR
result:
special: error
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: [Null, Null, Null]
type: decimal<38, 0>
result:
value: Null
type: decimal<38, 0>
- group: null_handling
args:
- value: []
type: decimal<38, 0>
result:
value: Null
type: decimal<38, 0>
- group: null_handling
args:
- value: [200000, Null, 629000, -10000, 0, 987621]
type: decimal<38, 0>
result:
value: 361324.2
type: decimal<38, 2>
91 changes: 91 additions & 0 deletions cases/arithmetic_decimal/max_decimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml
function: max
cases:
- group:
id: basic
description: Basic examples without any special cases
args:
- value: [20, -3, 1, -10, 0, 5]
type: decimal<38, 0>
result:
value: 20
type: decimal<38, 0>
- group: basic
args:
- value: [-32768, 32767, 20000, -30000]
type: decimal<38, 0>
result:
value: 32767
type: decimal<38, 0>
- group: basic
args:
- value: [-214748648, 214748647, 21470048, 4000000]
type: decimal<38, 0>
result:
value: 214748647
type: decimal<38, 0>
- group: basic
args:
- value: [2000000000, -3217908979, 629000000, -100000000, 0, 987654321]
type: decimal<38, 0>
result:
value: 2000000000
type: decimal<38, 0>
- group: basic
args:
- value: [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<38, 2>
result:
value: 5.0
type: decimal<38, 2>
- group: basic
args:
- value: [99999999999999999999999999999999999999, 0, -99999999999999999999999999999999999998, 111111111, -76]
type: decimal<38, 0>
result:
value: 99999999999999999999999999999999999999
type: decimal<38, 0>
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: [Null, Null, Null]
type: decimal<38, 0>
result:
value: Null
type: decimal<38, 0>
- group: null_handling
args:
- value: []
type: decimal<38, 0>
result:
value: Null
type: decimal<38, 0>
- group: null_handling
args:
- value: [2000000000, Null, 629000000, -100000000, Null, 987654321]
type: decimal<38, 0>
result:
value: 2000000000
type: decimal<38, 0>
- group: null_handling
args:
- value: [Null, Null]
type: decimal<38, 0>
result:
value: Null
type: decimal<38, 0>
- group: null_handling
args:
- value: []
type: decimal<38, 0>
result:
value: Null
type: decimal<38, 0>
- group: null_handling
args:
- value: [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, Null, 11111111111111111111111111111111111111, Null]
type: decimal<38, 0>
result:
value: 99999999999999999999999999999999999999
type: decimal<38, 0>
77 changes: 77 additions & 0 deletions cases/arithmetic_decimal/min_decimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml
function: min
cases:
- group:
id: basic
description: Basic examples without any special cases
args:
- value: [20, -3, 1, -10, 0, 5]
type: decimal<38, 0>
result:
value: -10
type: decimal<38, 0>
- group: basic
args:
- value: [-32768, 32767, 20000, -30000]
type: decimal<38, 0>
result:
value: -32768
type: decimal<38, 0>
- group: basic
args:
- value: [-214748648, 214748647, 21470048, 4000000]
type: decimal<38, 0>
result:
value: -214748648
type: decimal<38, 0>
- group: basic
args:
- value: [2000000000, -3217908979, 629000000, -100000000, 0, 987654321]
type: decimal<38, 0>
result:
value: -3217908979
type: decimal<38, 0>
- group: basic
args:
- value: [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<38, 2>
result:
value: -7.5
type: decimal<38, 2>
- group: basic
args:
- value: [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, -99999999999999999999999999999999999997, 0, 1111]
type: decimal<38, 0>
result:
value: -99999999999999999999999999999999999998
type: decimal<38, 0>
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: [Null, Null, Null]
type: decimal<38, 0>
result:
value: Null
type: decimal<38, 0>
- group: null_handling
args:
- value: []
type: decimal<38, 0>
result:
value: Null
type: decimal<38, 0>
- group: null_handling
args:
- value: [2000000000, Null, 629000000, -100000000, Null, 987654321]
type: decimal<38, 0>
result:
value: -100000000
type: decimal<38, 0>
- group: null_handling
args:
- value: [-99999999999999999999999999999999999998, Null, 99999999999999999999999999999999999999, Null]
type: decimal<38, 0>
result:
value: -99999999999999999999999999999999999998
type: decimal<38, 0>
Loading

0 comments on commit 96dbadb

Please sign in to comment.