diff --git a/bft/core/yaml_parser.py b/bft/core/yaml_parser.py index 91ff4ad..322db92 100644 --- a/bft/core/yaml_parser.py +++ b/bft/core/yaml_parser.py @@ -1,5 +1,6 @@ import math from abc import ABC, abstractmethod +from decimal import Decimal from typing import BinaryIO, Generic, Iterable, List, TypeVar import yaml @@ -90,7 +91,27 @@ class BaseYamlParser(ABC, Generic[T]): def get_visitor(self) -> BaseYamlVisitor[T]: pass + def get_loader(self): + loader = yaml.SafeLoader + """Add tag "!decimal" to the loader """ + loader.add_constructor("!decimal", self.decimal_constructor) + loader.add_constructor("!decimallist", self.list_of_decimal_constructor) + return loader + + def decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode): + return self.get_decimal_value(loader, node) + + def get_decimal_value(self, loader: yaml.SafeLoader, node: yaml.ScalarNode): + value = loader.construct_scalar(node) + if isinstance(value, str) and value.lower() == 'null': + return None + return Decimal(value) + + def list_of_decimal_constructor(self, loader: yaml.SafeLoader, node: yaml.nodes.MappingNode): + return [self.get_decimal_value(loader, item) for item in node.value] + def parse(self, f: BinaryIO) -> List[T]: - objs = yaml.load_all(f, SafeLoader) + loader = self.get_loader() + objs = yaml.load_all(f, loader) visitor = self.get_visitor() return [visitor.visit(obj) for obj in objs] diff --git a/bft/core/yaml_parser_test.py b/bft/core/yaml_parser_test.py new file mode 100644 index 0000000..52262d4 --- /dev/null +++ b/bft/core/yaml_parser_test.py @@ -0,0 +1,23 @@ +from decimal import Decimal +from typing import NamedTuple + +from bft.core.yaml_parser import BaseYamlParser + + +class TestDecimalResult(NamedTuple): + cases: Decimal | list[Decimal] + +class TestCaseVisitor(): + def visit(self, testcase): + return TestDecimalResult(testcase) +class DecimalTestCaseParser(BaseYamlParser[TestDecimalResult]): + def get_visitor(self) -> TestCaseVisitor: + return TestCaseVisitor() + +def test_yaml_parser_decimal_tag(): + parser = DecimalTestCaseParser() + # parser returns list of parsed values + assert parser.parse(b"!decimal 1") == [TestDecimalResult(Decimal('1'))] + assert parser.parse(b"!decimal 1.78766") == [TestDecimalResult(Decimal('1.78766'))] + assert parser.parse(b"!decimal null") == [TestDecimalResult(None)] + assert parser.parse(b"!decimallist [1.2, null, 7.547]") == [TestDecimalResult([Decimal('1.2'), None, Decimal('7.547')])] diff --git a/cases/arithmetic_decimal/max_decimal.yaml b/cases/arithmetic_decimal/max_decimal.yaml new file mode 100644 index 0000000..5a34ca3 --- /dev/null +++ b/cases/arithmetic_decimal/max_decimal.yaml @@ -0,0 +1,91 @@ +base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml +function: max +cases: + - group: + id: basic + description: Basic examples without any special cases + args: + - value: !decimallist [20, -3, 1, -10, 0, 5] + type: decimal<2, 0> + result: + value: !decimal 20 + type: decimal<2, 0> + - group: basic + args: + - value: !decimallist [-32768, 32767, 20000, -30000] + type: decimal<5, 0> + result: + value: !decimal 32767 + type: decimal<5, 0> + - group: basic + args: + - value: !decimallist [-214748648, 214748647, 21470048, 4000000] + type: decimal<9, 0> + result: + value: !decimal 214748647 + type: decimal<9, 0> + - group: basic + args: + - value: !decimallist [2000000000, -3217908979, 629000000, -100000000, 0, 987654321] + type: decimal<10, 0> + result: + value: !decimal 2000000000 + type: decimal<10, 0> + - group: basic + args: + - value: !decimallist [2.5, 0, 5.0, -2.5, -7.5] + type: decimal<2, 1> + result: + value: !decimal 5.0 + type: decimal<2, 1> + - group: basic + args: + - value: !decimallist [99999999999999999999999999999999999999, 0, -99999999999999999999999999999999999998, 111111111, -76] + type: decimal<38, 0> + result: + value: !decimal 99999999999999999999999999999999999999 + type: decimal<38, 0> + - group: + id: null_handling + description: Examples with null as unput or output + args: + - value: !decimallist [Null, Null, Null] + type: decimal<1, 0> + result: + value: !decimal Null + type: decimal<1, 0> + - group: null_handling + args: + - value: !decimallist [] + type: decimal<1, 0> + result: + value: !decimal Null + type: decimal<1, 0> + - group: null_handling + args: + - value: !decimallist [2000000000, Null, 629000000, -100000000, Null, 987654321] + type: decimal<10, 0> + result: + value: !decimal 2000000000 + type: decimal<10, 0> + - group: null_handling + args: + - value: !decimallist [Null, Null] + type: decimal<1, 0> + result: + value: !decimal Null + type: decimal<1, 0> + - group: null_handling + args: + - value: !decimallist [] + type: decimal<1, 0> + result: + value: !decimal Null + type: decimal<1, 0> + - group: null_handling + args: + - value: !decimallist [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, Null, 11111111111111111111111111111111111111, Null] + type: decimal<38, 0> + result: + value: !decimal 99999999999999999999999999999999999999 + type: decimal<38, 0> diff --git a/cases/arithmetic_decimal/min_decimal.yaml b/cases/arithmetic_decimal/min_decimal.yaml new file mode 100644 index 0000000..eca2f59 --- /dev/null +++ b/cases/arithmetic_decimal/min_decimal.yaml @@ -0,0 +1,77 @@ +base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml +function: min +cases: + - group: + id: basic + description: Basic examples without any special cases + args: + - value: !decimallist [20, -3, 1, -10, 0, 5] + type: decimal<2, 0> + result: + value: !decimal -10 + type: decimal<2, 0> + - group: basic + args: + - value: !decimallist [-32768, 32767, 20000, -30000] + type: decimal<5, 0> + result: + value: !decimal -32768 + type: decimal<5, 0> + - group: basic + args: + - value: !decimallist [-214748648, 214748647, 21470048, 4000000] + type: decimal<9, 0> + result: + value: !decimal -214748648 + type: decimal<9, 0> + - group: basic + args: + - value: !decimallist [2000000000, -3217908979, 629000000, -100000000, 0, 987654321] + type: decimal<10, 0> + result: + value: !decimal -3217908979 + type: decimal<10, 0> + - group: basic + args: + - value: !decimallist [2.5, 0, 5.0, -2.5, -7.5] + type: decimal<2, 1> + result: + value: !decimal -7.5 + type: decimal<2, 1> + - group: basic + args: + - value: !decimallist [99999999999999999999999999999999999999, -99999999999999999999999999999999999998, -99999999999999999999999999999999999997, 0, 1111] + type: decimal<38, 0> + result: + value: !decimal -99999999999999999999999999999999999998 + type: decimal<38, 0> + - group: + id: null_handling + description: Examples with null as unput or output + args: + - value: !decimallist [Null, Null, Null] + type: decimal<1, 0> + result: + value: !decimal Null + type: decimal<1, 0> + - group: null_handling + args: + - value: !decimallist [] + type: decimal<1, 0> + result: + value: !decimal Null + type: decimal<1, 0> + - group: null_handling + args: + - value: !decimallist [2000000000, Null, 629000000, -100000000, Null, 987654321] + type: decimal<10, 0> + result: + value: !decimal -100000000 + type: decimal<10, 0> + - group: null_handling + args: + - value: !decimallist [-99999999999999999999999999999999999998, Null, 99999999999999999999999999999999999999, Null] + type: decimal<38, 0> + result: + value: !decimal -99999999999999999999999999999999999998 + type: decimal<38, 0> diff --git a/cases/arithmetic_decimal/sum_decimal.yaml b/cases/arithmetic_decimal/sum_decimal.yaml new file mode 100644 index 0000000..b7414b5 --- /dev/null +++ b/cases/arithmetic_decimal/sum_decimal.yaml @@ -0,0 +1,66 @@ +base_uri: https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic_decimal.yaml +function: sum +cases: + - group: + id: basic + description: Basic examples without any special cases + args: + - value: !decimallist [0, -1, 2, 20] + type: decimal<2, 0> + result: + value: !decimal 21 + type: decimal<38, 0> + - group: basic + args: + - value: !decimallist [2000000, -3217908, 629000, -100000, 0, 987654] + type: decimal<7, 0> + result: + value: !decimal 298746 + type: decimal<38, 0> + - group: basic + args: + - value: !decimallist [2.5, 0, 5.0, -2.5, -7.5] + type: decimal<2, 1> + result: + value: !decimal -2.5 + type: decimal<38, 2> + - group: basic + args: + - value: !decimallist [2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875] + type: decimal<23, 22> + result: + value: !decimal 16.5000021457672119140625 + type: decimal<38, 22> + - group: + id: overflow + description: Examples demonstrating overflow behavior + args: + - value: !decimallist [99999999999999999999999999999999999999, 1, 1, 1, 1, 99999999999999999999999999999999999999] + type: decimal<38, 0> + options: + overflow: ERROR + result: + special: error + - group: + id: null_handling + description: Examples with null as unput or output + args: + - value: !decimallist [Null, Null, Null] + type: decimal<1, 0> + result: + value: !decimal Null + type: decimal<38, 0> + - group: null_handling + args: + - value: !decimallist [] + type: decimal<1, 0> + result: + value: !decimal Null + type: decimal<38, 0> + - group: null_handling + args: + - value: !decimallist [200000, Null, 629000, -10000, 0, 987621] + type: decimal<6, 0> + result: + value: !decimal 1806621 + type: decimal<38, 0> diff --git a/dialects/duckdb.yaml b/dialects/duckdb.yaml index 4d76968..d96f219 100644 --- a/dialects/duckdb.yaml +++ b/dialects/duckdb.yaml @@ -657,6 +657,18 @@ aggregate_functions: aggregate: true supported_kernels: - any +- name: arithmetic_decimal.min + aggregate: true + supported_kernels: + - dec +- name: arithmetic_decimal.max + aggregate: true + supported_kernels: + - dec +- name: arithmetic_decimal.sum + aggregate: true + supported_kernels: + - dec - name: boolean.bool_and aggregate: true supported_kernels: diff --git a/dialects/snowflake.yaml b/dialects/snowflake.yaml index 4141216..3b20318 100644 --- a/dialects/snowflake.yaml +++ b/dialects/snowflake.yaml @@ -350,6 +350,18 @@ aggregate_functions: aggregate: true supported_kernels: - fp64 +- name: arithmetic_decimal.min + aggregate: true + supported_kernels: + - dec +- name: arithmetic_decimal.max + aggregate: true + supported_kernels: + - dec +- name: arithmetic_decimal.sum + aggregate: true + supported_kernels: + - dec - name: boolean.bool_and local_name: booland_agg aggregate: true