Skip to content

Commit

Permalink
Add test cases for aggregate arithmetic function "sum/min/max" having…
Browse files Browse the repository at this point in the history
… decimal argumets (#88)

* Add test cases for aggregate arithmetic function "sum/min/max/avg" having decimal argumets

* Fix Decimal comparision to load decimal value at load time instead of check time
* Added custom tag "!decimal" and "!decimallist" and use them to load decimal value as decimal

* Make input type and result type same as substrait which is as follows
Sum:
  Input DECIMAL<P, S> ==> Result: DECIMAL<38, S>
Min:
  Input DECIMAL<P, S> ==> Result: DECIMAL<P, S>
Max:
  Input DECIMAL<P, S> ==> Result: DECIMAL<P, S>
* Remove average test cases for now. Their return type doesn't seem to be conform to Substrait
* I will raise avg in a separate PR
  • Loading branch information
anshuldata authored Aug 22, 2024
1 parent 91056da commit 90241dc
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 1 deletion.
23 changes: 22 additions & 1 deletion bft/core/yaml_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from abc import ABC, abstractmethod
from decimal import Decimal
from typing import BinaryIO, Generic, Iterable, List, TypeVar

import yaml
Expand Down Expand Up @@ -90,7 +91,27 @@ class BaseYamlParser(ABC, Generic[T]):
def get_visitor(self) -> BaseYamlVisitor[T]:
pass

def get_loader(self):
loader = yaml.SafeLoader
"""Add tag "!decimal" to the loader """
loader.add_constructor("!decimal", self.decimal_constructor)
loader.add_constructor("!decimallist", self.list_of_decimal_constructor)
return loader

def decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
return self.get_decimal_value(loader, node)

def get_decimal_value(self, loader: yaml.SafeLoader, node: yaml.ScalarNode):
value = loader.construct_scalar(node)
if isinstance(value, str) and value.lower() == 'null':
return None
return Decimal(value)

def list_of_decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
return [self.get_decimal_value(loader, item) for item in node.value]

def parse(self, f: BinaryIO) -> List[T]:
objs = yaml.load_all(f, SafeLoader)
loader = self.get_loader()
objs = yaml.load_all(f, loader)
visitor = self.get_visitor()
return [visitor.visit(obj) for obj in objs]
23 changes: 23 additions & 0 deletions bft/core/yaml_parser_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from decimal import Decimal
from typing import NamedTuple

from bft.core.yaml_parser import BaseYamlParser


class TestDecimalResult(NamedTuple):
cases: Decimal | list[Decimal]

class TestCaseVisitor():
def visit(self, testcase):
return TestDecimalResult(testcase)
class DecimalTestCaseParser(BaseYamlParser[TestDecimalResult]):
def get_visitor(self) -> TestCaseVisitor:
return TestCaseVisitor()

def test_yaml_parser_decimal_tag():
parser = DecimalTestCaseParser()
# parser returns list of parsed values
assert parser.parse(b"!decimal 1") == [TestDecimalResult(Decimal('1'))]
assert parser.parse(b"!decimal 1.78766") == [TestDecimalResult(Decimal('1.78766'))]
assert parser.parse(b"!decimal null") == [TestDecimalResult(None)]
assert parser.parse(b"!decimallist [1.2, null, 7.547]") == [TestDecimalResult([Decimal('1.2'), None, Decimal('7.547')])]
91 changes: 91 additions & 0 deletions cases/arithmetic_decimal/max_decimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml
function: max
cases:
- group:
id: basic
description: Basic examples without any special cases
args:
- value: !decimallist [20, -3, 1, -10, 0, 5]
type: decimal<2, 0>
result:
value: !decimal 20
type: decimal<2, 0>
- group: basic
args:
- value: !decimallist [-32768, 32767, 20000, -30000]
type: decimal<5, 0>
result:
value: !decimal 32767
type: decimal<5, 0>
- group: basic
args:
- value: !decimallist [-214748648, 214748647, 21470048, 4000000]
type: decimal<9, 0>
result:
value: !decimal 214748647
type: decimal<9, 0>
- group: basic
args:
- value: !decimallist [2000000000, -3217908979, 629000000, -100000000, 0, 987654321]
type: decimal<10, 0>
result:
value: !decimal 2000000000
type: decimal<10, 0>
- group: basic
args:
- value: !decimallist [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<2, 1>
result:
value: !decimal 5.0
type: decimal<2, 1>
- group: basic
args:
- value: !decimallist [99999999999999999999999999999999999999, 0, -99999999999999999999999999999999999998, 111111111, -76]
type: decimal<38, 0>
result:
value: !decimal 99999999999999999999999999999999999999
type: decimal<38, 0>
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: !decimallist [Null, Null, Null]
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist []
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist [2000000000, Null, 629000000, -100000000, Null, 987654321]
type: decimal<10, 0>
result:
value: !decimal 2000000000
type: decimal<10, 0>
- group: null_handling
args:
- value: !decimallist [Null, Null]
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist []
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, Null, 11111111111111111111111111111111111111, Null]
type: decimal<38, 0>
result:
value: !decimal 99999999999999999999999999999999999999
type: decimal<38, 0>
77 changes: 77 additions & 0 deletions cases/arithmetic_decimal/min_decimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml
function: min
cases:
- group:
id: basic
description: Basic examples without any special cases
args:
- value: !decimallist [20, -3, 1, -10, 0, 5]
type: decimal<2, 0>
result:
value: !decimal -10
type: decimal<2, 0>
- group: basic
args:
- value: !decimallist [-32768, 32767, 20000, -30000]
type: decimal<5, 0>
result:
value: !decimal -32768
type: decimal<5, 0>
- group: basic
args:
- value: !decimallist [-214748648, 214748647, 21470048, 4000000]
type: decimal<9, 0>
result:
value: !decimal -214748648
type: decimal<9, 0>
- group: basic
args:
- value: !decimallist [2000000000, -3217908979, 629000000, -100000000, 0, 987654321]
type: decimal<10, 0>
result:
value: !decimal -3217908979
type: decimal<10, 0>
- group: basic
args:
- value: !decimallist [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<2, 1>
result:
value: !decimal -7.5
type: decimal<2, 1>
- group: basic
args:
- value: !decimallist [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, -99999999999999999999999999999999999997, 0, 1111]
type: decimal<38, 0>
result:
value: !decimal -99999999999999999999999999999999999998
type: decimal<38, 0>
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: !decimallist [Null, Null, Null]
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist []
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist [2000000000, Null, 629000000, -100000000, Null, 987654321]
type: decimal<10, 0>
result:
value: !decimal -100000000
type: decimal<10, 0>
- group: null_handling
args:
- value: !decimallist [-99999999999999999999999999999999999998, Null, 99999999999999999999999999999999999999, Null]
type: decimal<38, 0>
result:
value: !decimal -99999999999999999999999999999999999998
type: decimal<38, 0>
66 changes: 66 additions & 0 deletions cases/arithmetic_decimal/sum_decimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml
function: sum
cases:
- group:
id: basic
description: Basic examples without any special cases
args:
- value: !decimallist [0, -1, 2, 20]
type: decimal<2, 0>
result:
value: !decimal 21
type: decimal<38, 0>
- group: basic
args:
- value: !decimallist [2000000, -3217908, 629000, -100000, 0, 987654]
type: decimal<7, 0>
result:
value: !decimal 298746
type: decimal<38, 0>
- group: basic
args:
- value: !decimallist [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<2, 1>
result:
value: !decimal -2.5
type: decimal<38, 2>
- group: basic
args:
- value: !decimallist [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875]
type: decimal<23, 22>
result:
value: !decimal 16.5000021457672119140625
type: decimal<38, 22>
- group:
id: overflow
description: Examples demonstrating overflow behavior
args:
- value: !decimallist [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999]
type: decimal<38, 0>
options:
overflow: ERROR
result:
special: error
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: !decimallist [Null, Null, Null]
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: !decimallist []
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: !decimallist [200000, Null, 629000, -10000, 0, 987621]
type: decimal<6, 0>
result:
value: !decimal 1806621
type: decimal<38, 0>
12 changes: 12 additions & 0 deletions dialects/duckdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,18 @@ aggregate_functions:
aggregate: true
supported_kernels:
- any
- name: arithmetic_decimal.min
aggregate: true
supported_kernels:
- dec
- name: arithmetic_decimal.max
aggregate: true
supported_kernels:
- dec
- name: arithmetic_decimal.sum
aggregate: true
supported_kernels:
- dec
- name: boolean.bool_and
aggregate: true
supported_kernels:
Expand Down
12 changes: 12 additions & 0 deletions dialects/snowflake.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,18 @@ aggregate_functions:
aggregate: true
supported_kernels:
- fp64
- name: arithmetic_decimal.min
aggregate: true
supported_kernels:
- dec
- name: arithmetic_decimal.max
aggregate: true
supported_kernels:
- dec
- name: arithmetic_decimal.sum
aggregate: true
supported_kernels:
- dec
- name: boolean.bool_and
local_name: booland_agg
aggregate: true
Expand Down

0 comments on commit 90241dc

Please sign in to comment.