Skip to content

Commit

Permalink
Merge pull request #81 from pdet/substrait_extensions
Browse files Browse the repository at this point in the history
Adding Extensions URIs for aggregate functions
  • Loading branch information
pdet authored May 6, 2024
2 parents 26bd22a + 8019275 commit 6ae4260
Show file tree
Hide file tree
Showing 15 changed files with 396 additions and 89 deletions.
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 896 files
2 changes: 1 addition & 1 deletion duckdb-r
Submodule duckdb-r updated 754 files
19 changes: 12 additions & 7 deletions scripts/generate_custom_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
import yaml
import regex

def parse_yaml(file_path):
with open(file_path, 'r') as file:
yaml_data = yaml.safe_load(file)

functions = []

for function_data in yaml_data.get('scalar_functions', []):
def parse_function_data(functions,yaml_data,function_type):
for function_data in yaml_data.get(function_type, []):
function = {
'name': function_data['name'],
'impls_args': []
Expand All @@ -24,7 +21,14 @@ def parse_yaml(file_path):
function['impls_args'].append(args)

functions.append(function)
return functions

def parse_yaml(file_path):
with open(file_path, 'r') as file:
yaml_data = yaml.safe_load(file)
functions = []
functions = parse_function_data(functions,yaml_data,'scalar_functions')
functions = parse_function_data(functions,yaml_data,'aggregate_functions')
return functions

def get_custom_functions():
Expand All @@ -39,8 +43,9 @@ def get_custom_functions():
type_str = "{"
for args in impls_args:
type_value = regex.sub(r'<[^>]*>', '', args["value"])
type_set.add(type_value)
type_str += f"\"{type_value}\","
if (len(type_value) != 0):
type_set.add(type_value)
type_str += f"\"{type_value}\","
type_str = type_str[:-1]
type_str += "}"
function_name = function["name"]
Expand Down
121 changes: 114 additions & 7 deletions src/custom_extensions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "custom_extensions/custom_extensions.hpp"
#include "duckdb/common/types.hpp"
#include "duckdb/common/string_util.hpp"

namespace duckdb {

Expand All @@ -16,10 +17,84 @@ string TransformTypes(const ::substrait::Type &type) {
return str_type;
}

vector<string> GetAllTypes() {
return {{"bool"},
{"i8"},
{"i16"},
{"i32"},
{"i64"},
{"fp32"},
{"fp64"},
{"string"},
{"binary"},
{"timestamp"},
{"date"},
{"time"},
{"interval_year"},
{"interval_day"},
{"timestamp_tz"},
{"uuid"},
{"varchar"},
{"fixed_binary"},
{"decimal"},
{"precision_timestamp"},
{"precision_timestamp_tz"}};
}

// Recurse over the whole shebang
void SubstraitCustomFunctions::InsertAllFunctions(const vector<vector<string>> &all_types, vector<idx_t> &indices,
int depth, string &name, string &file_path) {
if (depth == indices.size()) {
vector<string> types;
for (idx_t i = 0; i < indices.size(); i++) {
auto type = all_types[i][indices[i]];
type = StringUtil::Replace(type, "boolean", "bool");
types.push_back(type);
}
if (types.empty()) {
any_arg_functions[{name, types}] = {{name, types}, std::move(file_path)};
} else {
bool many_arg = false;
string type = types[0];
for (auto &t : types) {
if (!t.empty() && t[t.size() - 1] == '?') {
// If all types are equal and they end with ? we have a many_argument function
many_arg = type == t;
}
}
if (many_arg) {
many_arg_functions[{name, types}] = {{name, types}, std::move(file_path)};
} else {
custom_functions[{name, types}] = {{name, types}, std::move(file_path)};
}
}

return;
}
for (int i = 0; i < all_types[depth].size(); ++i) {
indices[depth] = i;
InsertAllFunctions(all_types, indices, depth + 1, name, file_path);
}
}

void SubstraitCustomFunctions::InsertCustomFunction(string name_p, vector<string> types_p, string file_path) {
auto name = std::move(name_p);
auto types = std::move(types_p);
custom_functions[{name, types}] = {{name, types}, std::move(file_path)};
vector<vector<string>> all_types;
for (auto &t : types) {
if (t == "any1" || t == "unknown") {
all_types.emplace_back(GetAllTypes());
} else {
all_types.push_back({t});
}
}
// Get the number of dimensions
idx_t num_arguments = all_types.size();

// Create a vector to hold the indices
vector<idx_t> idx(num_arguments, 0);

// Call the helper function with initial depth 0
InsertAllFunctions(all_types, idx, 0, name_p, file_path);
}

string SubstraitCustomFunction::GetName() {
Expand Down Expand Up @@ -49,25 +124,57 @@ SubstraitCustomFunctions::SubstraitCustomFunctions() {
Initialize();
};

vector<string> SubstraitCustomFunctions::GetTypes(const vector<::substrait::Type> &types) const {
vector<string> transformed_types;
for (auto &type : types) {
transformed_types.emplace_back(TransformTypes(type));
}
return transformed_types;
}

// FIXME: We might have to do DuckDB extensions at some point
SubstraitFunctionExtensions SubstraitCustomFunctions::Get(const string &name,
const vector<::substrait::Type> &types) const {
vector<string> transformed_types;
if (types.empty()) {
SubstraitCustomFunction custom_function {name, {}};
auto it = any_arg_functions.find(custom_function);
if (it != custom_functions.end()) {
// We found it in our substrait custom map, return that
return it->second;
}
return {{name, {}}, "native"};
}

for (auto &type : types) {
transformed_types.emplace_back(TransformTypes(type));
if (transformed_types.back().empty()) {
// If it is empty it means we did not find a yaml extension, we return the function name
return {{name, {}}, "native"};
}
}
SubstraitCustomFunction custom_function {name, {transformed_types}};
auto it = custom_functions.find(custom_function);
if (it != custom_functions.end()) {
// We found it in our substrait custom map, return that
return it->second;
{
SubstraitCustomFunction custom_function {name, {transformed_types}};
auto it = custom_functions.find(custom_function);
if (it != custom_functions.end()) {
// We found it in our substrait custom map, return that
return it->second;
}
}

// check if it's a many argument fit
bool possibly_many_arg = true;
string type = transformed_types[0];
for (auto &t : transformed_types) {
possibly_many_arg = possibly_many_arg && type == t;
}
if (possibly_many_arg) {
type += '?';
SubstraitCustomFunction custom_many_function {name, {{type}}};
auto many_it = many_arg_functions.find(custom_many_function);
if (many_it != many_arg_functions.end()) {
return many_it->second;
}
}
// TODO: check if this should also print the arg types or not
// we did not find it, return it as a native substrait function
Expand Down
Loading

0 comments on commit 6ae4260

Please sign in to comment.