Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test cases for aggregate arithmetic function "sum/min/max" having decimal argumets #88

Merged
merged 3 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion bft/core/yaml_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from abc import ABC, abstractmethod
from decimal import Decimal
from typing import BinaryIO, Generic, Iterable, List, TypeVar

import yaml
Expand Down Expand Up @@ -90,7 +91,27 @@ class BaseYamlParser(ABC, Generic[T]):
def get_visitor(self) -> BaseYamlVisitor[T]:
pass

def get_loader(self):
loader = yaml.SafeLoader
"""Add tag "!decimal" to the loader """
loader.add_constructor("!decimal", self.decimal_constructor)
loader.add_constructor("!decimallist", self.list_of_decimal_constructor)
return loader

def decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
return self.get_decimal_value(loader, node)

def get_decimal_value(self, loader: yaml.SafeLoader, node: yaml.ScalarNode):
value = loader.construct_scalar(node)
if isinstance(value, str) and value.lower() == 'null':
return None
return Decimal(value)

def list_of_decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
return [self.get_decimal_value(loader, item) for item in node.value]

def parse(self, f: BinaryIO) -> List[T]:
objs = yaml.load_all(f, SafeLoader)
loader = self.get_loader()
objs = yaml.load_all(f, loader)
visitor = self.get_visitor()
return [visitor.visit(obj) for obj in objs]
23 changes: 23 additions & 0 deletions bft/core/yaml_parser_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from decimal import Decimal
from typing import NamedTuple

from bft.core.yaml_parser import BaseYamlParser


class TestDecimalResult(NamedTuple):
cases: Decimal | list[Decimal]

class TestCaseVisitor():
def visit(self, testcase):
return TestDecimalResult(testcase)
class DecimalTestCaseParser(BaseYamlParser[TestDecimalResult]):
def get_visitor(self) -> TestCaseVisitor:
return TestCaseVisitor()

def test_yaml_parser_decimal_tag():
parser = DecimalTestCaseParser()
# parser returns list of parsed values
assert parser.parse(b"!decimal 1") == [TestDecimalResult(Decimal('1'))]
assert parser.parse(b"!decimal 1.78766") == [TestDecimalResult(Decimal('1.78766'))]
assert parser.parse(b"!decimal null") == [TestDecimalResult(None)]
assert parser.parse(b"!decimallist [1.2, null, 7.547]") == [TestDecimalResult([Decimal('1.2'), None, Decimal('7.547')])]
91 changes: 91 additions & 0 deletions cases/arithmetic_decimal/max_decimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml
function: max
cases:
- group:
id: basic
description: Basic examples without any special cases
args:
- value: !decimallist [20, -3, 1, -10, 0, 5]
type: decimal<2, 0>
result:
value: !decimal 20
type: decimal<2, 0>
- group: basic
args:
- value: !decimallist [-32768, 32767, 20000, -30000]
type: decimal<5, 0>
result:
value: !decimal 32767
type: decimal<5, 0>
- group: basic
args:
- value: !decimallist [-214748648, 214748647, 21470048, 4000000]
type: decimal<9, 0>
result:
value: !decimal 214748647
type: decimal<9, 0>
- group: basic
args:
- value: !decimallist [2000000000, -3217908979, 629000000, -100000000, 0, 987654321]
type: decimal<10, 0>
result:
value: !decimal 2000000000
type: decimal<10, 0>
- group: basic
args:
- value: !decimallist [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<2, 1>
result:
value: !decimal 5.0
type: decimal<2, 1>
- group: basic
args:
- value: !decimallist [99999999999999999999999999999999999999, 0, -99999999999999999999999999999999999998, 111111111, -76]
type: decimal<38, 0>
result:
value: !decimal 99999999999999999999999999999999999999
type: decimal<38, 0>
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: !decimallist [Null, Null, Null]
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist []
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist [2000000000, Null, 629000000, -100000000, Null, 987654321]
type: decimal<10, 0>
result:
value: !decimal 2000000000
type: decimal<10, 0>
- group: null_handling
args:
- value: !decimallist [Null, Null]
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist []
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, Null, 11111111111111111111111111111111111111, Null]
type: decimal<38, 0>
result:
value: !decimal 99999999999999999999999999999999999999
type: decimal<38, 0>
77 changes: 77 additions & 0 deletions cases/arithmetic_decimal/min_decimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml
function: min
cases:
- group:
id: basic
description: Basic examples without any special cases
args:
- value: !decimallist [20, -3, 1, -10, 0, 5]
type: decimal<2, 0>
result:
value: !decimal -10
type: decimal<2, 0>
- group: basic
args:
- value: !decimallist [-32768, 32767, 20000, -30000]
type: decimal<5, 0>
result:
value: !decimal -32768
type: decimal<5, 0>
- group: basic
args:
- value: !decimallist [-214748648, 214748647, 21470048, 4000000]
type: decimal<9, 0>
result:
value: !decimal -214748648
type: decimal<9, 0>
- group: basic
args:
- value: !decimallist [2000000000, -3217908979, 629000000, -100000000, 0, 987654321]
type: decimal<10, 0>
result:
value: !decimal -3217908979
type: decimal<10, 0>
- group: basic
args:
- value: !decimallist [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<2, 1>
result:
value: !decimal -7.5
type: decimal<2, 1>
- group: basic
args:
- value: !decimallist [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, -99999999999999999999999999999999999997, 0, 1111]
type: decimal<38, 0>
result:
value: !decimal -99999999999999999999999999999999999998
type: decimal<38, 0>
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: !decimallist [Null, Null, Null]
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist []
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<1, 0>
- group: null_handling
args:
- value: !decimallist [2000000000, Null, 629000000, -100000000, Null, 987654321]
type: decimal<10, 0>
result:
value: !decimal -100000000
type: decimal<10, 0>
- group: null_handling
args:
- value: !decimallist [-99999999999999999999999999999999999998, Null, 99999999999999999999999999999999999999, Null]
type: decimal<38, 0>
result:
value: !decimal -99999999999999999999999999999999999998
type: decimal<38, 0>
66 changes: 66 additions & 0 deletions cases/arithmetic_decimal/sum_decimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml
function: sum
cases:
- group:
id: basic
description: Basic examples without any special cases
args:
- value: !decimallist [0, -1, 2, 20]
type: decimal<2, 0>
result:
value: !decimal 21
type: decimal<38, 0>
- group: basic
args:
- value: !decimallist [2000000, -3217908, 629000, -100000, 0, 987654]
type: decimal<7, 0>
result:
value: !decimal 298746
type: decimal<38, 0>
- group: basic
args:
- value: !decimallist [2.5, 0, 5.0, -2.5, -7.5]
type: decimal<2, 1>
result:
value: !decimal -2.5
type: decimal<38, 2>
- group: basic
args:
- value: !decimallist [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875]
type: decimal<23, 22>
result:
value: !decimal 16.5000021457672119140625
type: decimal<38, 22>
- group:
id: overflow
description: Examples demonstrating overflow behavior
args:
- value: !decimallist [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999]
type: decimal<38, 0>
options:
overflow: ERROR
result:
special: error
- group:
id: null_handling
description: Examples with null as unput or output
args:
- value: !decimallist [Null, Null, Null]
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: !decimallist []
type: decimal<1, 0>
result:
value: !decimal Null
type: decimal<38, 0>
- group: null_handling
args:
- value: !decimallist [200000, Null, 629000, -10000, 0, 987621]
type: decimal<6, 0>
result:
value: !decimal 1806621
type: decimal<38, 0>
12 changes: 12 additions & 0 deletions dialects/duckdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,18 @@ aggregate_functions:
aggregate: true
supported_kernels:
- any
- name: arithmetic_decimal.min
aggregate: true
supported_kernels:
- dec
- name: arithmetic_decimal.max
aggregate: true
supported_kernels:
- dec
- name: arithmetic_decimal.sum
aggregate: true
supported_kernels:
- dec
- name: boolean.bool_and
aggregate: true
supported_kernels:
Expand Down
12 changes: 12 additions & 0 deletions dialects/snowflake.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,18 @@ aggregate_functions:
aggregate: true
supported_kernels:
- fp64
- name: arithmetic_decimal.min
aggregate: true
supported_kernels:
- dec
- name: arithmetic_decimal.max
aggregate: true
supported_kernels:
- dec
- name: arithmetic_decimal.sum
aggregate: true
supported_kernels:
- dec
- name: boolean.bool_and
local_name: booland_agg
aggregate: true
Expand Down
Loading