diff --git a/duckdb b/duckdb index eb9f63a..d6e7cdb 160000 --- a/duckdb +++ b/duckdb @@ -1 +1 @@ -Subproject commit eb9f63a1cd506afe1fad8125d6f123868e979b5d +Subproject commit d6e7cdbc6062604fc0e34712cb9d6099f9588257 diff --git a/duckdb-r b/duckdb-r index 59f1673..3e6b47d 160000 --- a/duckdb-r +++ b/duckdb-r @@ -1 +1 @@ -Subproject commit 59f1673ba4dcdf327782f6b7dfcf44f7d3ec7666 +Subproject commit 3e6b47da6c96128616e446833ff6991ed22aa810 diff --git a/scripts/generate_custom_functions.py b/scripts/generate_custom_functions.py index 98aabb6..81d4548 100644 --- a/scripts/generate_custom_functions.py +++ b/scripts/generate_custom_functions.py @@ -3,13 +3,10 @@ import yaml import regex -def parse_yaml(file_path): - with open(file_path, 'r') as file: - yaml_data = yaml.safe_load(file) - functions = [] - for function_data in yaml_data.get('scalar_functions', []): +def parse_function_data(functions,yaml_data,function_type): + for function_data in yaml_data.get(function_type, []): function = { 'name': function_data['name'], 'impls_args': [] @@ -24,7 +21,14 @@ def parse_yaml(file_path): function['impls_args'].append(args) functions.append(function) + return functions +def parse_yaml(file_path): + with open(file_path, 'r') as file: + yaml_data = yaml.safe_load(file) + functions = [] + functions = parse_function_data(functions,yaml_data,'scalar_functions') + functions = parse_function_data(functions,yaml_data,'aggregate_functions') return functions def get_custom_functions(): @@ -39,8 +43,9 @@ def get_custom_functions(): type_str = "{" for args in impls_args: type_value = regex.sub(r'<[^>]*>', '', args["value"]) - type_set.add(type_value) - type_str += f"\"{type_value}\"," + if (len(type_value) != 0): + type_set.add(type_value) + type_str += f"\"{type_value}\"," type_str = type_str[:-1] type_str += "}" function_name = function["name"] diff --git a/src/custom_extensions.cpp b/src/custom_extensions.cpp index 6841293..b314e36 100644 --- a/src/custom_extensions.cpp +++ 
b/src/custom_extensions.cpp @@ -1,5 +1,6 @@ #include "custom_extensions/custom_extensions.hpp" #include "duckdb/common/types.hpp" +#include "duckdb/common/string_util.hpp" namespace duckdb { @@ -16,10 +17,84 @@ string TransformTypes(const ::substrait::Type &type) { return str_type; } +vector GetAllTypes() { + return {{"bool"}, + {"i8"}, + {"i16"}, + {"i32"}, + {"i64"}, + {"fp32"}, + {"fp64"}, + {"string"}, + {"binary"}, + {"timestamp"}, + {"date"}, + {"time"}, + {"interval_year"}, + {"interval_day"}, + {"timestamp_tz"}, + {"uuid"}, + {"varchar"}, + {"fixed_binary"}, + {"decimal"}, + {"precision_timestamp"}, + {"precision_timestamp_tz"}}; +} + +// Recursively enumerates every combination of candidate types in all_types (one candidate list per argument) and registers one entry per combination; file_path is copied rather than moved because the same reference is reused for every combination +void SubstraitCustomFunctions::InsertAllFunctions(const vector> &all_types, vector &indices, + int depth, string &name, string &file_path) { + if (depth == indices.size()) { + vector types; + for (idx_t i = 0; i < indices.size(); i++) { + auto type = all_types[i][indices[i]]; + type = StringUtil::Replace(type, "boolean", "bool"); + types.push_back(type); + } + if (types.empty()) { + any_arg_functions[{name, types}] = {{name, types}, file_path}; + } else { + bool many_arg = false; + string type = types[0]; + for (auto &t : types) { + if (!t.empty() && t[t.size() - 1] == '?') { + // If all types are equal and they end with ? 
we have a many_argument function + many_arg = type == t; + } + } + if (many_arg) { + many_arg_functions[{name, types}] = {{name, types}, file_path}; + } else { + custom_functions[{name, types}] = {{name, types}, file_path}; + } + } + + return; + } + for (int i = 0; i < all_types[depth].size(); ++i) { + indices[depth] = i; + InsertAllFunctions(all_types, indices, depth + 1, name, file_path); + } +} + void SubstraitCustomFunctions::InsertCustomFunction(string name_p, vector types_p, string file_path) { - auto name = std::move(name_p); auto types = std::move(types_p); - custom_functions[{name, types}] = {{name, types}, std::move(file_path)}; + vector> all_types; + for (auto &t : types) { + if (t == "any1" || t == "unknown") { + all_types.emplace_back(GetAllTypes()); + } else { + all_types.push_back({t}); + } + } + // Get the number of dimensions + idx_t num_arguments = all_types.size(); + + // Create a vector to hold the indices + vector idx(num_arguments, 0); + + // Call the helper function with initial depth 0 + InsertAllFunctions(all_types, idx, 0, name_p, file_path); } string SubstraitCustomFunction::GetName() { @@ -49,13 +124,28 @@ SubstraitCustomFunctions::SubstraitCustomFunctions() { Initialize(); }; +vector SubstraitCustomFunctions::GetTypes(const vector<::substrait::Type> &types) const { + vector transformed_types; + for (auto &type : types) { + transformed_types.emplace_back(TransformTypes(type)); + } + return transformed_types; +} + // FIXME: We might have to do DuckDB extensions at some point SubstraitFunctionExtensions SubstraitCustomFunctions::Get(const string &name, const vector<::substrait::Type> &types) const { vector transformed_types; if (types.empty()) { + SubstraitCustomFunction custom_function {name, {}}; + auto it = any_arg_functions.find(custom_function); + if (it != any_arg_functions.end()) { + // We found it in the zero-argument (any_arg_functions) map, return that + return it->second; + } return {{name, {}}, "native"}; } + for (auto 
&type : types) { transformed_types.emplace_back(TransformTypes(type)); if (transformed_types.back().empty()) { @@ -63,11 +153,28 @@ SubstraitFunctionExtensions SubstraitCustomFunctions::Get(const string &name, return {{name, {}}, "native"}; } } - SubstraitCustomFunction custom_function {name, {transformed_types}}; - auto it = custom_functions.find(custom_function); - if (it != custom_functions.end()) { - // We found it in our substrait custom map, return that - return it->second; + { + SubstraitCustomFunction custom_function {name, {transformed_types}}; + auto it = custom_functions.find(custom_function); + if (it != custom_functions.end()) { + // We found it in our substrait custom map, return that + return it->second; + } + } + + // check if it's a many argument fit + bool possibly_many_arg = true; + string type = transformed_types[0]; + for (auto &t : transformed_types) { + possibly_many_arg = possibly_many_arg && type == t; + } + if (possibly_many_arg) { + type += '?'; + SubstraitCustomFunction custom_many_function {name, {{type}}}; + auto many_it = many_arg_functions.find(custom_many_function); + if (many_it != many_arg_functions.end()) { + return many_it->second; + } } // TODO: check if this should also print the arg types or not // we did not find it, return it as a native substrait function diff --git a/src/custom_extensions_generated.cpp b/src/custom_extensions_generated.cpp index 6c8a3f2..e2f3d68 100644 --- a/src/custom_extensions_generated.cpp +++ b/src/custom_extensions_generated.cpp @@ -5,22 +5,34 @@ namespace duckdb { void SubstraitCustomFunctions::Initialize() { - InsertCustomFunction("extract", {"", "timestamp_tz", "string"}, "functions_datetime.yaml"); - InsertCustomFunction("extract", {"", "timestamp"}, "functions_datetime.yaml"); - InsertCustomFunction("extract", {"", "date"}, "functions_datetime.yaml"); - InsertCustomFunction("extract", {"", "time"}, "functions_datetime.yaml"); - InsertCustomFunction("extract", {"", "", "timestamp_tz", "string"}, 
"functions_datetime.yaml"); - InsertCustomFunction("extract", {"", "", "timestamp"}, "functions_datetime.yaml"); - InsertCustomFunction("extract", {"", "", "date"}, "functions_datetime.yaml"); - InsertCustomFunction("extract_boolean", {"", "timestamp"}, "functions_datetime.yaml"); - InsertCustomFunction("extract_boolean", {"", "timestamp_tz", "string"}, "functions_datetime.yaml"); - InsertCustomFunction("extract_boolean", {"", "date"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"timestamp_tz", "string"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"precision_timestamp_tz", "string"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"timestamp"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"precision_timestamp"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"date"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"time"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"timestamp_tz", "string"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"precision_timestamp_tz", "string"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"timestamp"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"precision_timestamp"}, "functions_datetime.yaml"); + InsertCustomFunction("extract", {"date"}, "functions_datetime.yaml"); + InsertCustomFunction("extract_boolean", {"timestamp"}, "functions_datetime.yaml"); + InsertCustomFunction("extract_boolean", {"timestamp_tz", "string"}, "functions_datetime.yaml"); + InsertCustomFunction("extract_boolean", {"date"}, "functions_datetime.yaml"); InsertCustomFunction("add", {"timestamp", "interval_year"}, "functions_datetime.yaml"); InsertCustomFunction("add", {"timestamp_tz", "interval_year", "string"}, "functions_datetime.yaml"); InsertCustomFunction("add", {"date", "interval_year"}, "functions_datetime.yaml"); InsertCustomFunction("add", {"timestamp", "interval_day"}, 
"functions_datetime.yaml"); InsertCustomFunction("add", {"timestamp_tz", "interval_day"}, "functions_datetime.yaml"); InsertCustomFunction("add", {"date", "interval_day"}, "functions_datetime.yaml"); + InsertCustomFunction("multiply", {"i8", "interval_day"}, "functions_datetime.yaml"); + InsertCustomFunction("multiply", {"i16", "interval_day"}, "functions_datetime.yaml"); + InsertCustomFunction("multiply", {"i32", "interval_day"}, "functions_datetime.yaml"); + InsertCustomFunction("multiply", {"i64", "interval_day"}, "functions_datetime.yaml"); + InsertCustomFunction("multiply", {"i8", "interval_year"}, "functions_datetime.yaml"); + InsertCustomFunction("multiply", {"i16", "interval_year"}, "functions_datetime.yaml"); + InsertCustomFunction("multiply", {"i32", "interval_year"}, "functions_datetime.yaml"); + InsertCustomFunction("multiply", {"i64", "interval_year"}, "functions_datetime.yaml"); InsertCustomFunction("add_intervals", {"interval_day", "interval_day"}, "functions_datetime.yaml"); InsertCustomFunction("add_intervals", {"interval_year", "interval_year"}, "functions_datetime.yaml"); InsertCustomFunction("subtract", {"timestamp", "interval_year"}, "functions_datetime.yaml"); @@ -61,15 +73,27 @@ void SubstraitCustomFunctions::Initialize() { InsertCustomFunction("strftime", {"timestamp_tz", "string", "string"}, "functions_datetime.yaml"); InsertCustomFunction("strftime", {"date", "string"}, "functions_datetime.yaml"); InsertCustomFunction("strftime", {"time", "string"}, "functions_datetime.yaml"); - InsertCustomFunction("round_temporal", {"timestamp", "", "", "i64", "timestamp"}, "functions_datetime.yaml"); - InsertCustomFunction("round_temporal", {"timestamp_tz", "", "", "i64", "string", "timestamp_tz"}, + InsertCustomFunction("round_temporal", {"timestamp", "i64", "timestamp"}, "functions_datetime.yaml"); + InsertCustomFunction("round_temporal", {"timestamp_tz", "i64", "string", "timestamp_tz"}, "functions_datetime.yaml"); - 
InsertCustomFunction("round_temporal", {"date", "", "", "i64", "date"}, "functions_datetime.yaml"); - InsertCustomFunction("round_temporal", {"time", "", "", "i64", "time"}, "functions_datetime.yaml"); - InsertCustomFunction("round_calendar", {"timestamp", "", "", "", "i64"}, "functions_datetime.yaml"); - InsertCustomFunction("round_calendar", {"timestamp_tz", "", "", "", "i64", "string"}, "functions_datetime.yaml"); - InsertCustomFunction("round_calendar", {"date", "", "", "", "i64", "date"}, "functions_datetime.yaml"); - InsertCustomFunction("round_calendar", {"time", "", "", "", "i64", "time"}, "functions_datetime.yaml"); + InsertCustomFunction("round_temporal", {"date", "i64", "date"}, "functions_datetime.yaml"); + InsertCustomFunction("round_temporal", {"time", "i64", "time"}, "functions_datetime.yaml"); + InsertCustomFunction("round_calendar", {"timestamp", "i64"}, "functions_datetime.yaml"); + InsertCustomFunction("round_calendar", {"timestamp_tz", "i64", "string"}, "functions_datetime.yaml"); + InsertCustomFunction("round_calendar", {"date", "i64", "date"}, "functions_datetime.yaml"); + InsertCustomFunction("round_calendar", {"time", "i64", "time"}, "functions_datetime.yaml"); + InsertCustomFunction("min", {"date"}, "functions_datetime.yaml"); + InsertCustomFunction("min", {"time"}, "functions_datetime.yaml"); + InsertCustomFunction("min", {"timestamp"}, "functions_datetime.yaml"); + InsertCustomFunction("min", {"timestamp_tz"}, "functions_datetime.yaml"); + InsertCustomFunction("min", {"interval_day"}, "functions_datetime.yaml"); + InsertCustomFunction("min", {"interval_year"}, "functions_datetime.yaml"); + InsertCustomFunction("max", {"date"}, "functions_datetime.yaml"); + InsertCustomFunction("max", {"time"}, "functions_datetime.yaml"); + InsertCustomFunction("max", {"timestamp"}, "functions_datetime.yaml"); + InsertCustomFunction("max", {"timestamp_tz"}, "functions_datetime.yaml"); + InsertCustomFunction("max", {"interval_day"}, 
"functions_datetime.yaml"); + InsertCustomFunction("max", {"interval_year"}, "functions_datetime.yaml"); InsertCustomFunction("not_equal", {"any1", "any1"}, "functions_comparison.yaml"); InsertCustomFunction("equal", {"any1", "any1"}, "functions_comparison.yaml"); InsertCustomFunction("is_not_distinct_from", {"any1", "any1"}, "functions_comparison.yaml"); @@ -113,6 +137,7 @@ void SubstraitCustomFunctions::Initialize() { InsertCustomFunction("buffer", {"u!geometry", "fp64"}, "functions_geometry.yaml"); InsertCustomFunction("centroid", {"u!geometry"}, "functions_geometry.yaml"); InsertCustomFunction("minimum_bounding_circle", {"u!geometry"}, "functions_geometry.yaml"); + InsertCustomFunction("approx_count_distinct", {"any"}, "functions_aggregate_approx.yaml"); InsertCustomFunction("ln", {"fp32"}, "functions_logarithmic.yaml"); InsertCustomFunction("ln", {"fp64"}, "functions_logarithmic.yaml"); InsertCustomFunction("log10", {"fp32"}, "functions_logarithmic.yaml"); @@ -128,6 +153,11 @@ void SubstraitCustomFunctions::Initialize() { InsertCustomFunction("multiply", {"unknown", "unknown"}, "unknown.yaml"); InsertCustomFunction("divide", {"unknown", "unknown"}, "unknown.yaml"); InsertCustomFunction("modulus", {"unknown", "unknown"}, "unknown.yaml"); + InsertCustomFunction("sum", {"unknown"}, "unknown.yaml"); + InsertCustomFunction("avg", {"unknown"}, "unknown.yaml"); + InsertCustomFunction("min", {"unknown"}, "unknown.yaml"); + InsertCustomFunction("max", {"unknown"}, "unknown.yaml"); + InsertCustomFunction("count", {"unknown"}, "unknown.yaml"); InsertCustomFunction("concat", {"varchar"}, "functions_string.yaml"); InsertCustomFunction("concat", {"string"}, "functions_string.yaml"); InsertCustomFunction("like", {"varchar", "varchar"}, "functions_string.yaml"); @@ -238,6 +268,7 @@ void SubstraitCustomFunctions::Initialize() { InsertCustomFunction("string_split", {"string", "string"}, "functions_string.yaml"); InsertCustomFunction("regexp_string_split", {"varchar", 
"varchar"}, "functions_string.yaml"); InsertCustomFunction("regexp_string_split", {"string", "string"}, "functions_string.yaml"); + InsertCustomFunction("string_agg", {"string", "string"}, "functions_string.yaml"); InsertCustomFunction("add", {"i8", "i8"}, "functions_arithmetic.yaml"); InsertCustomFunction("add", {"i16", "i16"}, "functions_arithmetic.yaml"); InsertCustomFunction("add", {"i32", "i32"}, "functions_arithmetic.yaml"); @@ -340,6 +371,65 @@ void SubstraitCustomFunctions::Initialize() { InsertCustomFunction("bitwise_xor", {"i16", "i16"}, "functions_arithmetic.yaml"); InsertCustomFunction("bitwise_xor", {"i32", "i32"}, "functions_arithmetic.yaml"); InsertCustomFunction("bitwise_xor", {"i64", "i64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum", {"i8"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum", {"i16"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum", {"i32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum", {"i64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum", {"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum0", {"i8"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum0", {"i16"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum0", {"i32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum0", {"i64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum0", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("sum0", {"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("avg", {"i8"}, "functions_arithmetic.yaml"); + InsertCustomFunction("avg", {"i16"}, "functions_arithmetic.yaml"); + InsertCustomFunction("avg", {"i32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("avg", {"i64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("avg", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("avg", {"fp64"}, "functions_arithmetic.yaml"); 
+ InsertCustomFunction("min", {"i8"}, "functions_arithmetic.yaml"); + InsertCustomFunction("min", {"i16"}, "functions_arithmetic.yaml"); + InsertCustomFunction("min", {"i32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("min", {"i64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("min", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("min", {"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("min", {"timestamp"}, "functions_arithmetic.yaml"); + InsertCustomFunction("min", {"timestamp_tz"}, "functions_arithmetic.yaml"); + InsertCustomFunction("max", {"i8"}, "functions_arithmetic.yaml"); + InsertCustomFunction("max", {"i16"}, "functions_arithmetic.yaml"); + InsertCustomFunction("max", {"i32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("max", {"i64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("max", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("max", {"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("max", {"timestamp"}, "functions_arithmetic.yaml"); + InsertCustomFunction("max", {"timestamp_tz"}, "functions_arithmetic.yaml"); + InsertCustomFunction("product", {"i8"}, "functions_arithmetic.yaml"); + InsertCustomFunction("product", {"i16"}, "functions_arithmetic.yaml"); + InsertCustomFunction("product", {"i32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("product", {"i64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("product", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("product", {"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("std_dev", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("std_dev", {"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("variance", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("variance", {"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("corr", {"fp32", "fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("corr", {"fp64", 
"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("mode", {"i8"}, "functions_arithmetic.yaml"); + InsertCustomFunction("mode", {"i16"}, "functions_arithmetic.yaml"); + InsertCustomFunction("mode", {"i32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("mode", {"i64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("mode", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("mode", {"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("median", {"i8"}, "functions_arithmetic.yaml"); + InsertCustomFunction("median", {"i16"}, "functions_arithmetic.yaml"); + InsertCustomFunction("median", {"i32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("median", {"i64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("median", {"fp32"}, "functions_arithmetic.yaml"); + InsertCustomFunction("median", {"fp64"}, "functions_arithmetic.yaml"); + InsertCustomFunction("quantile", {"i64", "any"}, "functions_arithmetic.yaml"); InsertCustomFunction("ceil", {"fp32"}, "functions_rounding.yaml"); InsertCustomFunction("ceil", {"fp64"}, "functions_rounding.yaml"); InsertCustomFunction("floor", {"fp32"}, "functions_rounding.yaml"); @@ -350,16 +440,26 @@ void SubstraitCustomFunctions::Initialize() { InsertCustomFunction("round", {"i64", "i32"}, "functions_rounding.yaml"); InsertCustomFunction("round", {"fp32", "i32"}, "functions_rounding.yaml"); InsertCustomFunction("round", {"fp64", "i32"}, "functions_rounding.yaml"); + InsertCustomFunction("count", {"any"}, "functions_aggregate_generic.yaml"); + InsertCustomFunction("count", {}, "functions_aggregate_generic.yaml"); + InsertCustomFunction("any_value", {"any"}, "functions_aggregate_generic.yaml"); InsertCustomFunction("or", {"boolean?"}, "functions_boolean.yaml"); InsertCustomFunction("and", {"boolean?"}, "functions_boolean.yaml"); InsertCustomFunction("and_not", {"boolean?", "boolean?"}, "functions_boolean.yaml"); InsertCustomFunction("xor", {"boolean?", "boolean?"}, 
"functions_boolean.yaml"); InsertCustomFunction("not", {"boolean?"}, "functions_boolean.yaml"); + InsertCustomFunction("bool_and", {"boolean"}, "functions_boolean.yaml"); + InsertCustomFunction("bool_or", {"boolean"}, "functions_boolean.yaml"); InsertCustomFunction("add", {"decimal", "decimal"}, "functions_arithmetic_decimal.yaml"); InsertCustomFunction("subtract", {"decimal", "decimal"}, "functions_arithmetic_decimal.yaml"); InsertCustomFunction("multiply", {"decimal", "decimal"}, "functions_arithmetic_decimal.yaml"); InsertCustomFunction("divide", {"decimal", "decimal"}, "functions_arithmetic_decimal.yaml"); InsertCustomFunction("modulus", {"decimal", "decimal"}, "functions_arithmetic_decimal.yaml"); + InsertCustomFunction("sum", {"DECIMAL"}, "functions_arithmetic_decimal.yaml"); + InsertCustomFunction("avg", {"DECIMAL"}, "functions_arithmetic_decimal.yaml"); + InsertCustomFunction("min", {"DECIMAL"}, "functions_arithmetic_decimal.yaml"); + InsertCustomFunction("max", {"DECIMAL"}, "functions_arithmetic_decimal.yaml"); + InsertCustomFunction("sum0", {"DECIMAL"}, "functions_arithmetic_decimal.yaml"); InsertCustomFunction("index_in", {"T", "List"}, "functions_set.yaml"); } diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp index 566e21d..0081a96 100644 --- a/src/from_substrait.cpp +++ b/src/from_substrait.cpp @@ -404,25 +404,26 @@ shared_ptr SubstraitToDuckDB::TransformJoinOp(const substrait::Rel &so throw InternalException("Unsupported join type"); } unique_ptr join_condition = TransformExpr(sjoin.expression()); - return make_shared(TransformOp(sjoin.left())->Alias("left"), - TransformOp(sjoin.right())->Alias("right"), std::move(join_condition), djointype); + return make_shared_ptr(TransformOp(sjoin.left())->Alias("left"), + TransformOp(sjoin.right())->Alias("right"), std::move(join_condition), + djointype); } shared_ptr SubstraitToDuckDB::TransformCrossProductOp(const substrait::Rel &sop) { auto &sub_cross = sop.cross(); - return 
make_shared(TransformOp(sub_cross.left())->Alias("left"), - TransformOp(sub_cross.right())->Alias("right")); + return make_shared_ptr(TransformOp(sub_cross.left())->Alias("left"), + TransformOp(sub_cross.right())->Alias("right")); } shared_ptr SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop) { auto &slimit = sop.fetch(); - return make_shared(TransformOp(slimit.input()), slimit.count(), slimit.offset()); + return make_shared_ptr(TransformOp(slimit.input()), slimit.count(), slimit.offset()); } shared_ptr SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &sop) { auto &sfilter = sop.filter(); - return make_shared(TransformOp(sfilter.input()), TransformExpr(sfilter.condition())); + return make_shared_ptr(TransformOp(sfilter.input()), TransformExpr(sfilter.condition())); } shared_ptr SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop) { @@ -435,8 +436,8 @@ shared_ptr SubstraitToDuckDB::TransformProjectOp(const substrait::Rel for (size_t i = 0; i < expressions.size(); i++) { mock_aliases.push_back("expr_" + to_string(i)); } - return make_shared(TransformOp(sop.project().input()), std::move(expressions), - std::move(mock_aliases)); + return make_shared_ptr(TransformOp(sop.project().input()), std::move(expressions), + std::move(mock_aliases)); } shared_ptr SubstraitToDuckDB::TransformAggregateOp(const substrait::Rel &sop) { @@ -463,8 +464,8 @@ shared_ptr SubstraitToDuckDB::TransformAggregateOp(const substrait::Re expressions.push_back(make_uniq(RemapFunctionName(function_name), std::move(children))); } - return make_shared(TransformOp(sop.aggregate().input()), std::move(expressions), - std::move(groups)); + return make_shared_ptr(TransformOp(sop.aggregate().input()), std::move(expressions), + std::move(groups)); } shared_ptr SubstraitToDuckDB::TransformReadOp(const substrait::Rel &sop) { @@ -502,7 +503,7 @@ shared_ptr SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so } if (sget.has_filter()) { - scan = 
make_shared(std::move(scan), TransformExpr(sget.filter())); + scan = make_shared_ptr(std::move(scan), TransformExpr(sget.filter())); } if (sget.has_projection()) { @@ -516,7 +517,7 @@ shared_ptr SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so expressions.push_back(make_uniq(sproj.field() + 1)); } - scan = make_shared(std::move(scan), std::move(expressions), std::move(aliases)); + scan = make_shared_ptr(std::move(scan), std::move(expressions), std::move(aliases)); } return scan; @@ -527,7 +528,7 @@ shared_ptr SubstraitToDuckDB::TransformSortOp(const substrait::Rel &so for (auto &sordf : sop.sort().sorts()) { order_nodes.push_back(TransformOrder(sordf)); } - return make_shared(TransformOp(sop.sort().input()), std::move(order_nodes)); + return make_shared_ptr(TransformOp(sop.sort().input()), std::move(order_nodes)); } static duckdb::SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) { @@ -562,7 +563,7 @@ shared_ptr SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop auto lhs = TransformOp(inputs[0]); auto rhs = TransformOp(inputs[1]); - return make_shared(std::move(lhs), std::move(rhs), type); + return make_shared_ptr(std::move(lhs), std::move(rhs), type); } shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) { @@ -599,7 +600,7 @@ shared_ptr SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot aliases.push_back(column_name); expressions.push_back(make_uniq(id++)); } - return make_shared(TransformOp(sop.input()), std::move(expressions), aliases); + return make_shared_ptr(TransformOp(sop.input()), std::move(expressions), aliases); } shared_ptr SubstraitToDuckDB::TransformPlan() { diff --git a/src/include/custom_extensions/custom_extensions.hpp b/src/include/custom_extensions/custom_extensions.hpp index 3bd1062..34369c3 100644 --- a/src/include/custom_extensions/custom_extensions.hpp +++ b/src/include/custom_extensions/custom_extensions.hpp @@ -56,15 +56,33 @@ struct HashSubstraitFunctions { } }; +struct 
HashSubstraitFunctionsName { + size_t operator()(SubstraitCustomFunction const &custom_function) const noexcept { + // Hash Name + return Hash(custom_function.name.c_str()); + } +}; + class SubstraitCustomFunctions { public: SubstraitCustomFunctions(); SubstraitFunctionExtensions Get(const string &name, const vector<::substrait::Type> &types) const; + vector GetTypes(const vector<::substrait::Type> &types) const; void Initialize(); private: + // For Regular Functions std::unordered_map custom_functions; + // For * Functions + std::unordered_map + any_arg_functions; + // For ? Functions + // When we have an argument ending with ? it means this argument can repeat many times + std::unordered_map many_arg_functions; + void InsertCustomFunction(string name_p, vector types_p, string file_path); + void InsertAllFunctions(const vector> &all_types, vector &indices, int depth, string &name_p, + string &file_path); }; } // namespace duckdb \ No newline at end of file diff --git a/src/include/to_substrait.hpp b/src/include/to_substrait.hpp index 7b5d5e1..ac4ebf8 100644 --- a/src/include/to_substrait.hpp +++ b/src/include/to_substrait.hpp @@ -16,7 +16,8 @@ namespace duckdb { class DuckDBToSubstrait { public: - explicit DuckDBToSubstrait(ClientContext &context, duckdb::LogicalOperator &dop) : context(context) { + explicit DuckDBToSubstrait(ClientContext &context, duckdb::LogicalOperator &dop, bool strict_p) + : context(context), strict(strict_p) { TransformPlan(dop); }; @@ -151,12 +152,13 @@ class DuckDBToSubstrait { //! Variable that holds information about yaml function extensions static const SubstraitCustomFunctions custom_functions; uint64_t last_function_id = 1; - uint64_t last_extension_id = 1; - + uint64_t last_uri_id = 1; //! The substrait Plan substrait::Plan plan; ClientContext &context; - - uint64_t max_string_length = 1; + //! If we are generating a query plan on strict mode we will error if + //! 
things don't go perfectly shiny + bool strict; + string errors; }; } // namespace duckdb diff --git a/src/substrait_extension.cpp b/src/substrait_extension.cpp index fae645c..cf3f2b0 100644 --- a/src/substrait_extension.cpp +++ b/src/substrait_extension.cpp @@ -17,10 +17,11 @@ namespace duckdb { struct ToSubstraitFunctionData : public TableFunctionData { - ToSubstraitFunctionData() { - } + ToSubstraitFunctionData() = default; string query; - bool enable_optimizer; + bool enable_optimizer = false; + //! We will fail the conversion on possible warnings + bool strict = false; bool finished = false; }; @@ -34,45 +35,50 @@ static void VerifyJSONRoundtrip(unique_ptr &query_plan, Connect static void VerifyBlobRoundtrip(unique_ptr &query_plan, Connection &con, ToSubstraitFunctionData &data, const string &serialized); -static bool SetOptimizationOption(const ClientConfig &config, const duckdb::named_parameter_map_t &named_params) { +static void SetOptions(ToSubstraitFunctionData &function, const ClientConfig &config, + const duckdb::named_parameter_map_t &named_params) { + bool optimizer_option_set = false; for (const auto &param : named_params) { auto loption = StringUtil::Lower(param.first); // If the user has explicitly requested to enable/disable the optimizer when // generating Substrait, then that takes precedence. if (loption == "enable_optimizer") { - return BooleanValue::Get(param.second); + function.enable_optimizer = BooleanValue::Get(param.second); + optimizer_option_set = true; + } + if (loption == "strict") { + function.strict = BooleanValue::Get(param.second); } } - - // If the user has not specified what they want, fall back to the settings - // on the connection (e.g. if the optimizer was disabled by the user at - // the connection level, it would be surprising to enable the optimizer - // when generating Substrait). 
- return config.enable_optimizer; + if (!optimizer_option_set) { + // If the user has not specified what they want, fall back to the settings + // on the connection (e.g. if the optimizer was disabled by the user at + // the connection level, it would be surprising to enable the optimizer + // when generating Substrait). + function.enable_optimizer = config.enable_optimizer; + } } static unique_ptr InitToSubstraitFunctionData(const ClientConfig &config, TableFunctionBindInput &input) { auto result = make_uniq(); result->query = input.inputs[0].ToString(); - result->enable_optimizer = SetOptimizationOption(config, input.named_parameters); - return std::move(result); + SetOptions(*result, config, input.named_parameters); + return result; } static unique_ptr ToSubstraitBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { return_types.emplace_back(LogicalType::BLOB); names.emplace_back("Plan Blob"); - auto result = InitToSubstraitFunctionData(context.config, input); - return std::move(result); + return InitToSubstraitFunctionData(context.config, input); } static unique_ptr ToJsonBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { return_types.emplace_back(LogicalType::VARCHAR); names.emplace_back("Json"); - auto result = InitToSubstraitFunctionData(context.config, input); - return std::move(result); + return InitToSubstraitFunctionData(context.config, input); } shared_ptr SubstraitPlanToDuckDBRel(Connection &conn, const string &serialized, bool json = false) { @@ -136,7 +142,7 @@ static DuckDBToSubstrait InitPlanExtractor(ClientContext &context, ToSubstraitFu DBConfig::GetConfig(*new_conn.context).options.disabled_optimizers = disabled_optimizers; query_plan = new_conn.context->ExtractPlan(data.query); - return DuckDBToSubstrait(context, *query_plan); + return DuckDBToSubstrait(context, *query_plan, data.strict); } static void ToSubFunctionInternal(ClientContext &context, 
ToSubstraitFunctionData &data, DataChunk &output, @@ -255,6 +261,7 @@ void InitializeGetSubstrait(Connection &con) { // binary from a valid SQL Query TableFunction to_sub_func("get_substrait", {LogicalType::VARCHAR}, ToSubFunction, ToSubstraitBind); to_sub_func.named_parameters["enable_optimizer"] = LogicalType::BOOLEAN; + to_sub_func.named_parameters["strict"] = LogicalType::BOOLEAN; CreateTableFunctionInfo to_sub_info(to_sub_func); catalog.CreateTableFunction(*con.context, to_sub_info); } diff --git a/src/to_substrait.cpp b/src/to_substrait.cpp index 7b5a1f4..b268e3c 100644 --- a/src/to_substrait.cpp +++ b/src/to_substrait.cpp @@ -21,9 +21,22 @@ namespace duckdb { const std::unordered_map DuckDBToSubstrait::function_names_remap = { - {"mod", "modulus"}, {"stddev", "std_dev"}, {"prefix", "starts_with"}, {"suffix", "ends_with"}, - {"substr", "substring"}, {"length", "char_length"}, {"isnan", "is_nan"}, {"isfinite", "is_finite"}, - {"isinf", "is_infinite"}, {"sum_no_overflow", "sum"}, {"count_star", "count"}, {"~~", "like"}}; + {"mod", "modulus"}, + {"stddev", "std_dev"}, + {"prefix", "starts_with"}, + {"suffix", "ends_with"}, + {"substr", "substring"}, + {"length", "char_length"}, + {"isnan", "is_nan"}, + {"isfinite", "is_finite"}, + {"isinf", "is_infinite"}, + {"sum_no_overflow", "sum"}, + {"count_star", "count"}, + {"~~", "like"}, + {"*", "multiply"}, + {"-", "subtract"}, + {"+", "add"}, + {"/", "divide"}}; const case_insensitive_set_t DuckDBToSubstrait::valid_extract_subfields = { "year", "month", "day", "decade", "century", "millenium", @@ -517,29 +530,50 @@ uint64_t DuckDBToSubstrait::RegisterFunction(const string &name, vector<::substr if (name.empty()) { throw InternalException("Missing function name"); } - // FIXME: For now I'm ignoring DuckDB functions that are either not mapped to native substrait or custom substrait - // extensions auto function = custom_functions.Get(name, args_types); - idx_t uri_reference = 0; + auto substrait_extensions = 
plan.mutable_extension_uris(); if (!function.IsNative()) { auto extensionURI = function.GetExtensionURI(); auto it = extension_uri_map.find(extensionURI); if (it == extension_uri_map.end()) { // We have to add this extension + extension_uri_map[extensionURI] = last_uri_id; auto allocated_string = new string(); *allocated_string = extensionURI; - auto uri = plan.add_extension_uris(); + auto uri = new ::substrait::extensions::SimpleExtensionURI(); uri->set_allocated_uri(allocated_string); - uri->set_extension_uri_anchor(last_extension_id++); + uri->set_extension_uri_anchor(last_uri_id); + substrait_extensions->AddAllocated(uri); + last_uri_id++; } - uri_reference = extension_uri_map[extensionURI]; } if (functions_map.find(function.function.GetName()) == functions_map.end()) { auto function_id = last_function_id++; auto sfun = plan.add_extensions()->mutable_extension_function(); sfun->set_function_anchor(function_id); sfun->set_name(function.function.GetName()); - sfun->set_extension_uri_reference(uri_reference); + if (!function.IsNative()) { + // We only define URI if not native + sfun->set_extension_uri_reference(extension_uri_map[function.GetExtensionURI()]); + } else { + // Function was not found in the yaml files + sfun->set_extension_uri_reference(0); + if (strict) { + // Produce warning message + std::ostringstream error; + // Casting Error Message + error << "Could not find function \"" << function.function.GetName() << "\" with argument types: ("; + auto types = custom_functions.GetTypes(args_types); + for (idx_t i = 0; i < types.size(); i++) { + error << "\'" << types[i] << "\'"; + if (i != types.size() - 1) { + error << ", "; + } + } + error << ")" << std::endl; + errors += error.str(); + } + } functions_map[function.function.GetName()] = function_id; } return functions_map[function.function.GetName()]; @@ -1329,6 +1363,10 @@ substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) { void DuckDBToSubstrait::TransformPlan(LogicalOperator 
&dop) { plan.add_relations()->set_allocated_root(TransformRootOp(dop)); + if (strict && !errors.empty()) { + throw InvalidInputException("Strict Mode is set to true, and the following warnings/errors happened. \n" + + errors); + } auto version = plan.mutable_version(); version->set_major_number(0); version->set_minor_number(39); diff --git a/test/python/test_validator.py b/test/python/test_validator.py index 4b14f72..52ac4db 100644 --- a/test/python/test_validator.py +++ b/test/python/test_validator.py @@ -13,8 +13,6 @@ def run_substrait_validator(con, query): c.override_diagnostic_level(1, "warning", "info") # validator limitation: did not attempt to resolve YAML c.override_diagnostic_level(2001, "warning", "info") - # Function Anchor to YAML file, no clue what is that - c.override_diagnostic_level(3001, "error", "info") # too few field names c.override_diagnostic_level(4003, "error", "info") # Validator being of a different version than substrait @@ -33,7 +31,7 @@ def run_tpch_validator(require, query_number): run_substrait_validator(con,query) -@pytest.mark.parametrize('query_number', [1,3,5,6,7,8,9,10,11,12,13,14,15,19]) +@pytest.mark.parametrize('query_number', [1,3,5,6,7,8,9,10,12,14,19]) def test_substrait_tpch_validator(require,query_number): run_tpch_validator(require,query_number) @@ -68,3 +66,16 @@ def test_substrait_tpch_validator_21(require): @pytest.mark.skip(reason="DuckDB Compilation: INTERNAL Error: INTERNAL Error: DELIM_JOIN") def test_substrait_tpch_validator_22(require): run_tpch_validator(require,22) + +@pytest.mark.skip(reason="Could not find function \"first\" with argument types: ('decimal')") +def test_substrait_tpch_validator_11(require): + run_tpch_validator(require,11) + +@pytest.mark.skip(reason="Could not find function \"!~~\" with argument types: ('string', 'string')") +def test_substrait_tpch_validator_13(require): + run_tpch_validator(require,13) + +@pytest.mark.skip(reason="Could not find function \"first\" with argument types: 
('decimal')") +def test_substrait_tpch_validator_15(require): + run_tpch_validator(require,15) + diff --git a/test/sql/test_custom_function.test b/test/sql/test_custom_function.test index 7019cd3..8d1519b 100644 --- a/test/sql/test_custom_function.test +++ b/test/sql/test_custom_function.test @@ -52,12 +52,12 @@ CALL get_substrait_json('select trim(a, ''<'') from t_2') query I CALL get_substrait_json('select sum(a) from t') ---- -:.*https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml.* +:.*https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml.* query I CALL get_substrait_json('select sum(a) from t') ---- -:.*"sum".* +:.*"sum:fp64".* # Test mix diff --git a/test/sql/test_substrait.test b/test/sql/test_substrait.test index ac58d21..6ad0528 100644 --- a/test/sql/test_substrait.test +++ b/test/sql/test_substrait.test @@ -19,7 +19,7 @@ insert into crossfit values ('Push Ups', 3), ('Pull Ups', 5) , ('Push Jerk', 7), query I CALL get_substrait('select count(exercise) as exercise from crossfit where dificulty_level <=5') ---- -\x12\x09\x1A\x07\x10\x01\x1A\x03lte\x12\x11\x1A\x0F\x10\x02\x1A\x0Bis_not_null\x12\x09\x1A\x07\x10\x03\x1A\x03and\x12\x0B\x1A\x09\x10\x04\x1A\x05count\x1A\xC4\x01\x12\xC1\x01\x0A\xB4\x01:\xB1\x01\x12\xA4\x01\x22\xA1\x01\x12\x90\x01\x0A\x8D\x01\x12+\x0A\x08exercise\x0A\x0Fdificulty_level\x12\x0E\x0A\x04b\x02\x10\x01\x0A\x04*\x02\x10\x01\x18\x02\x1AJ\x1AH\x08\x03\x1A\x04\x0A\x02\x10\x01\x22\x22\x1A \x1A\x1E\x08\x01\x1A\x04*\x02\x10\x01\x22\x0C\x1A\x0A\x12\x08\x0A\x04\x12\x02\x08\x01\x22\x00\x22\x06\x1A\x04\x0A\x02(\x05\x22\x1A\x1A\x18\x1A\x16\x08\x02\x1A\x04*\x02\x10\x01\x22\x0C\x1A\x0A\x12\x08\x0A\x04\x12\x02\x08\x01\x22\x00\x22\x06\x0A\x02\x0A\x00\x10\x01:\x0A\x0A\x08crossfit\x1A\x00\x22\x0A\x0A\x08\x08\x04*\x04:\x02\x10\x01\x1A\x08\x12\x06\x0A\x02\x12\x00\x22\x00\x12\x08exercise2\x0A\x10\x27*\x06DuckDB 
+\x0AC\x08\x01\x12?https://github.com/substrait-io/substrait/blob/main/extensions/\x0AY\x08\x02\x12Uhttps://github.com/substrait-io/substrait/blob/main/extensions/functions_boolean.yaml\x0Ac\x08\x03\x12_https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml\x12\x13\x1A\x11\x08\x01\x10\x01\x1A\x0Blte:i32_i32\x12\x17\x1A\x15\x08\x01\x10\x02\x1A\x0Fis_not_null:i32\x12\x11\x1A\x0F\x08\x02\x10\x03\x1A\x09and:bool?\x12\x0D\x1A\x0B\x08\x03\x10\x04\x1A\x05count\x1A\xC4\x01\x12\xC1\x01\x0A\xB4\x01:\xB1\x01\x12\xA4\x01\x22\xA1\x01\x12\x90\x01\x0A\x8D\x01\x12+\x0A\x08exercise\x0A\x0Fdificulty_level\x12\x0E\x0A\x04b\x02\x10\x01\x0A\x04*\x02\x10\x01\x18\x02\x1AJ\x1AH\x08\x03\x1A\x04\x0A\x02\x10\x01\x22\x22\x1A \x1A\x1E\x08\x01\x1A\x04*\x02\x10\x01\x22\x0C\x1A\x0A\x12\x08\x0A\x04\x12\x02\x08\x01\x22\x00\x22\x06\x1A\x04\x0A\x02(\x05\x22\x1A\x1A\x18\x1A\x16\x08\x02\x1A\x04*\x02\x10\x01\x22\x0C\x1A\x0A\x12\x08\x0A\x04\x12\x02\x08\x01\x22\x00\x22\x06\x0A\x02\x0A\x00\x10\x01:\x0A\x0A\x08crossfit\x1A\x00\x22\x0A\x0A\x08\x08\x04*\x04:\x02\x10\x01\x1A\x08\x12\x06\x0A\x02\x12\x00\x22\x00\x12\x08exercise2\x0A\x10\x27*\x06DuckDB query I CALL from_substrait('\x12\x09\x1A\x07\x10\x01\x1A\x03lte\x12\x11\x1A\x0F\x10\x02\x1A\x0Bis_not_null\x12\x09\x1A\x07\x10\x03\x1A\x03and\x12\x0B\x1A\x09\x10\x04\x1A\x05count\x1A\xC1\x01\x12\xBE\x01\x0A\xB1\x01:\xAE\x01\x12\xA1\x01\x22\x9E\x01\x12\x8D\x01\x0A\x8A\x01\x12,\x0A\x08exercise\x0A\x0Fdificulty_level\x12\x0F\x0A\x07\xB2\x01\x04\x08\x0D\x18\x01\x0A\x02*\x00\x18\x02\x1AF\x1AD\x08\x03\x1A\x04\x0A\x02\x10\x01\x22 
\x1A\x1E\x1A\x1C\x08\x01\x1A\x02*\x00\x22\x0C\x1A\x0A\x12\x08\x0A\x04\x12\x02\x08\x01\x22\x00\x22\x06\x1A\x04\x0A\x02(\x05\x22\x18\x1A\x16\x1A\x14\x08\x02\x1A\x02*\x00\x22\x0C\x1A\x0A\x12\x08\x0A\x04\x12\x02\x08\x01\x22\x00\x22\x06\x0A\x02\x0A\x00\x10\x01:\x0A\x0A\x08crossfit\x1A\x00\x22\x0A\x0A\x08\x08\x04*\x04:\x02\x10\x01\x1A\x08\x12\x06\x0A\x02\x12\x00\x22\x00\x12\x08exercise2\x0A\x10\x18*\x06DuckDB'::BLOB) diff --git a/test/sql/test_substrait_parquet.test b/test/sql/test_substrait_parquet.test index 64090e1..0e3606a 100644 --- a/test/sql/test_substrait_parquet.test +++ b/test/sql/test_substrait_parquet.test @@ -20,7 +20,7 @@ CREATE TABLE lineitem_parquet AS SELECT * FROM parquet_scan('data/parquet-testin query I CALL get_substrait('SELECT sum(l_extendedprice * l_discount) as revenue FROM lineitem_parquet') ---- -\x12\x07\x1A\x05\x10\x01\x1A\x01*\x12\x09\x1A\x07\x10\x02\x1A\x03sum\x1A\xB3\x03\x12\xB0\x03\x0A\xA4\x03:\xA1\x03\x12\x94\x03\x22\x91\x03\x12\xD8\x02\x0A\xD5\x02\x12\xB0\x02\x0A\x0Al_orderkey\x0A\x09l_partkey\x0A\x09l_suppkey\x0A\x0Cl_linenumber\x0A\x0Al_quantity\x0A\x0Fl_extendedprice\x0A\x0Al_discount\x0A\x05l_tax\x0A\x0Cl_returnflag\x0A\x0Cl_linestatus\x0A\x0Al_shipdate\x0A\x0Cl_commitdate\x0A\x0Dl_receiptdate\x0A\x0El_shipinstruct\x0A\x0Al_shipmode\x0A\x09l_comment\x12b\x0A\x04:\x02\x10\x01\x0A\x04:\x02\x10\x01\x0A\x04:\x02\x10\x01\x0A\x04*\x02\x10\x01\x0A\x04*\x02\x10\x01\x0A\x04Z\x02\x10\x01\x0A\x04Z\x02\x10\x01\x0A\x04Z\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x18\x02\x22\x0C\x0A\x08\x0A\x02\x08\x05\x0A\x02\x08\x06\x10\x01:\x12\x0A\x10lineitem_parquet\x1A\x00\x222\x0A0\x08\x02*\x04Z\x02\x10\x01:&\x1A$\x1A\x22\x08\x01\x1A\x04Z\x02\x10\x01\x22\x0A\x1A\x08\x12\x06\x0A\x02\x12\x00\x22\x00\x22\x0C\x1A\x0A\x12\x08\x0A\x04\x12\x02\x08\x01\x22\x00\x1A\x08\x12\x06\x0A\x02\x12\x00\x22\x00\x12\x07revenue2\x0A\x10\x27*\x0
6DuckDB +\x0A\x5C\x08\x01\x12Xhttps://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml\x12\x1A\x1A\x18\x08\x01\x10\x01\x1A\x12multiply:fp64_fp64\x12\x10\x1A\x0E\x08\x01\x10\x02\x1A\x08sum:fp64\x1A\xB3\x03\x12\xB0\x03\x0A\xA4\x03:\xA1\x03\x12\x94\x03\x22\x91\x03\x12\xD8\x02\x0A\xD5\x02\x12\xB0\x02\x0A\x0Al_orderkey\x0A\x09l_partkey\x0A\x09l_suppkey\x0A\x0Cl_linenumber\x0A\x0Al_quantity\x0A\x0Fl_extendedprice\x0A\x0Al_discount\x0A\x05l_tax\x0A\x0Cl_returnflag\x0A\x0Cl_linestatus\x0A\x0Al_shipdate\x0A\x0Cl_commitdate\x0A\x0Dl_receiptdate\x0A\x0El_shipinstruct\x0A\x0Al_shipmode\x0A\x09l_comment\x12b\x0A\x04:\x02\x10\x01\x0A\x04:\x02\x10\x01\x0A\x04:\x02\x10\x01\x0A\x04*\x02\x10\x01\x0A\x04*\x02\x10\x01\x0A\x04Z\x02\x10\x01\x0A\x04Z\x02\x10\x01\x0A\x04Z\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x0A\x04b\x02\x10\x01\x18\x02\x22\x0C\x0A\x08\x0A\x02\x08\x05\x0A\x02\x08\x06\x10\x01:\x12\x0A\x10lineitem_parquet\x1A\x00\x222\x0A0\x08\x02*\x04Z\x02\x10\x01:&\x1A$\x1A\x22\x08\x01\x1A\x04Z\x02\x10\x01\x22\x0A\x1A\x08\x12\x06\x0A\x02\x12\x00\x22\x00\x22\x0C\x1A\x0A\x12\x08\x0A\x04\x12\x02\x08\x01\x22\x00\x1A\x08\x12\x06\x0A\x02\x12\x00\x22\x00\x12\x07revenue2\x0A\x10\x27*\x06DuckDB statement ok DROP TABLE lineitem_parquet; diff --git a/test/sql/test_substrait_tpch.test b/test/sql/test_substrait_tpch.test index 081311c..b16888a 100644 --- a/test/sql/test_substrait_tpch.test +++ b/test/sql/test_substrait_tpch.test @@ -62,6 +62,12 @@ CALL get_substrait('SELECT c_custkey, c_name, sum(l_extendedprice * (1 - l_disco statement ok CALL get_substrait('SELECT ps_partkey, sum(ps_supplycost * ps_availqty) AS value FROM partsupp, supplier, nation WHERE ps_suppkey = s_suppkey AND s_nationkey = n_nationkey AND n_name = ''GERMANY'' GROUP BY ps_partkey HAVING sum(ps_supplycost * ps_availqty) > ( SELECT sum(ps_supplycost * ps_availqty) * 0.0001000000 
FROM partsupp, supplier, nation WHERE ps_suppkey = s_suppkey AND s_nationkey = n_nationkey AND n_name = ''GERMANY'') ORDER BY value DESC;') +#Q 11 (Test with strict mode) +statement error +CALL get_substrait('SELECT ps_partkey, sum(ps_supplycost * ps_availqty) AS value FROM partsupp, supplier, nation WHERE ps_suppkey = s_suppkey AND s_nationkey = n_nationkey AND n_name = ''GERMANY'' GROUP BY ps_partkey HAVING sum(ps_supplycost * ps_availqty) > ( SELECT sum(ps_supplycost * ps_availqty) * 0.0001000000 FROM partsupp, supplier, nation WHERE ps_suppkey = s_suppkey AND s_nationkey = n_nationkey AND n_name = ''GERMANY'') ORDER BY value DESC;', strict = true) +---- +Could not find function "first" with argument types: ('decimal') + #Q 12 statement ok CALL get_substrait('SELECT l_shipmode, sum( CASE WHEN o_orderpriority = ''1-URGENT'' OR o_orderpriority = ''2-HIGH'' THEN 1 ELSE 0 END) AS high_line_count, sum( CASE WHEN o_orderpriority <> ''1-URGENT'' AND o_orderpriority <> ''2-HIGH'' THEN 1 ELSE 0 END) AS low_line_count FROM orders, lineitem WHERE o_orderkey = l_orderkey AND l_shipmode IN (''MAIL'', ''SHIP'') AND l_commitdate < l_receiptdate AND l_shipdate < l_commitdate AND l_receiptdate >= CAST(''1994-01-01'' AS date) AND l_receiptdate < CAST(''1995-01-01'' AS date) GROUP BY l_shipmode ORDER BY l_shipmode;') @@ -70,6 +76,12 @@ CALL get_substrait('SELECT l_shipmode, sum( CASE WHEN o_orderpriority = ''1-URGE statement ok CALL get_substrait('SELECT c_count, count(*) AS custdist FROM ( SELECT c_custkey, count(o_orderkey) FROM customer LEFT OUTER JOIN orders ON c_custkey = o_custkey AND o_comment NOT LIKE ''%special%requests%'' GROUP BY c_custkey) AS c_orders (c_custkey, c_count) GROUP BY c_count ORDER BY custdist DESC, c_count DESC;') +#Q 13 (Test with strict mode) +statement error +CALL get_substrait('SELECT c_count, count(*) AS custdist FROM ( SELECT c_custkey, count(o_orderkey) FROM customer LEFT OUTER JOIN orders ON c_custkey = o_custkey AND o_comment NOT LIKE 
''%special%requests%'' GROUP BY c_custkey) AS c_orders (c_custkey, c_count) GROUP BY c_count ORDER BY custdist DESC, c_count DESC;', strict = true) +---- +Could not find function "!~~" with argument types: ('string', 'string') + #Q 14 statement ok CALL get_substrait('SELECT 100.00 * sum( CASE WHEN p_type LIKE ''PROMO%'' THEN l_extendedprice * (1 - l_discount) ELSE 0 END) / sum(l_extendedprice * (1 - l_discount)) AS promo_revenue FROM lineitem, part WHERE l_partkey = p_partkey AND l_shipdate >= date ''1995-09-01'' AND l_shipdate < CAST(''1995-10-01'' AS date);') @@ -78,6 +90,12 @@ CALL get_substrait('SELECT 100.00 * sum( CASE WHEN p_type LIKE ''PROMO%'' THEN l statement ok CALL get_substrait('SELECT s_suppkey, s_name, s_address, s_phone, total_revenue FROM supplier, ( SELECT l_suppkey AS supplier_no, sum(l_extendedprice * (1 - l_discount)) AS total_revenue FROM lineitem WHERE l_shipdate >= CAST(''1996-01-01'' AS date) AND l_shipdate < CAST(''1996-04-01'' AS date) GROUP BY supplier_no) revenue0 WHERE s_suppkey = supplier_no AND total_revenue = ( SELECT max(total_revenue) FROM ( SELECT l_suppkey AS supplier_no, sum(l_extendedprice * (1 - l_discount)) AS total_revenue FROM lineitem WHERE l_shipdate >= CAST(''1996-01-01'' AS date) AND l_shipdate < CAST(''1996-04-01'' AS date) GROUP BY supplier_no) revenue1) ORDER BY s_suppkey;') +#Q 15 (Test with strict mode) +statement error +CALL get_substrait('SELECT s_suppkey, s_name, s_address, s_phone, total_revenue FROM supplier, ( SELECT l_suppkey AS supplier_no, sum(l_extendedprice * (1 - l_discount)) AS total_revenue FROM lineitem WHERE l_shipdate >= CAST(''1996-01-01'' AS date) AND l_shipdate < CAST(''1996-04-01'' AS date) GROUP BY supplier_no) revenue0 WHERE s_suppkey = supplier_no AND total_revenue = ( SELECT max(total_revenue) FROM ( SELECT l_suppkey AS supplier_no, sum(l_extendedprice * (1 - l_discount)) AS total_revenue FROM lineitem WHERE l_shipdate >= CAST(''1996-01-01'' AS date) AND l_shipdate < CAST(''1996-04-01'' AS 
date) GROUP BY supplier_no) revenue1) ORDER BY s_suppkey;', strict = true) +---- +Could not find function "first" with argument types: ('decimal') + #Q 16 (Missing Chunk Get) #statement ok #CALL get_substrait('SELECT p_brand, p_type, p_size, count(DISTINCT ps_suppkey) AS supplier_cnt FROM partsupp, part WHERE p_partkey = ps_partkey AND p_brand <> ''Brand#45'' AND p_type NOT LIKE ''MEDIUM POLISHED%'' AND p_size IN (49, 14, 23, 45, 19, 3, 36, 9) AND ps_suppkey NOT IN ( SELECT s_suppkey FROM supplier WHERE s_comment LIKE ''%Customer%Complaints%'') GROUP BY p_brand, p_type, p_size ORDER BY supplier_cnt DESC, p_brand, p_type, p_size;')