Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CTAS #125

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "duckdb/common/helper.hpp"

#include "duckdb/main/relation.hpp"
#include "duckdb/main/relation/create_table_relation.hpp"
#include "duckdb/main/relation/table_relation.hpp"
#include "duckdb/main/relation/table_function_relation.hpp"
#include "duckdb/main/relation/value_relation.hpp"
Expand Down Expand Up @@ -160,7 +161,7 @@
interval_t interval {};
interval.months = 0;
interval.days = literal.interval_day_to_second().days();
interval.micros = literal.interval_day_to_second().microseconds();

Check warning on line 164 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 164 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]
return Value::INTERVAL(interval);
}
default:
Expand Down Expand Up @@ -513,7 +514,7 @@

if (sop.aggregate().groupings_size() > 0) {
for (auto &sgrp : sop.aggregate().groupings()) {
for (auto &sgrpexpr : sgrp.grouping_expressions()) {

Check warning on line 517 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 517 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]
groups.push_back(TransformExpr(sgrpexpr));
expressions.push_back(TransformExpr(sgrpexpr));
}
Expand Down Expand Up @@ -613,7 +614,7 @@
scan = rel->Alias(name);
} else if (sget.has_virtual_table()) {
// We need to handle a virtual table as a LogicalExpressionGet
auto literal_values = sget.virtual_table().values();

Check warning on line 617 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 617 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
Expand Down Expand Up @@ -698,6 +699,29 @@
return make_shared_ptr<SetOpRelation>(std::move(lhs), std::move(rhs), type);
}

//! Transforms a Substrait WriteRel into a DuckDB relation.
//! Only the CTAS (create-table-as-select) operation is handled here: the write's input is
//! transformed and wrapped in a create-table relation for the named target.
//! Throws NotImplementedException for any other write operation and InvalidInputException
//! when the named target carries no names.
shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) {
	auto &write_rel = sop.write();
	if (write_rel.op() != substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS) {
		throw NotImplementedException("Unsupported write operation " + to_string(write_rel.op()));
	}

	auto &target = write_rel.named_table();
	const auto name_count = target.names_size();
	if (name_count == 0) {
		throw InvalidInputException("Named object must have at least one name");
	}

	// The last entry names the table; when more than one entry is present the first one is
	// used as the schema, otherwise the schema name stays empty.
	// NOTE(review): with three or more name parts the first entry may be a catalog rather
	// than a schema — confirm against the producers of these plans.
	string table_name = target.names(name_count - 1);
	string schema_name = name_count > 1 ? target.names(0) : string();

	auto source = TransformOp(write_rel.input());
	return source->CreateRel(schema_name, table_name);
}

shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
switch (sop.rel_type_case()) {
Expand All @@ -719,6 +743,8 @@
return TransformSortOp(sop, names);
case substrait::Rel::RelTypeCase::kSet:
return TransformSetOp(sop, names);
case substrait::Rel::RelTypeCase::kWrite:
return TransformWriteOp(sop);
default:
throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case()));
}
Expand Down Expand Up @@ -778,6 +804,19 @@
}
}

if (sop.input().rel_type_case() == substrait::Rel::RelTypeCase::kWrite) {
auto write = sop.input().write();
switch (write.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
const auto create_table = static_cast<CreateTableRelation *>(child.get());
auto proj = make_shared_ptr<ProjectionRelation>(create_table->child, std::move(expressions), aliases);
return proj->CreateRel(create_table->schema_name, create_table->table_name);
}
default:
return child;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

returning child here seems weird. what's the reasoning?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be useful to document this branch as well with the answer. I'm guessing we're treating this as a no-op for all non-write operations.

Copy link
Contributor Author

@scgkiran scgkiran Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RootOp is always a ProjectionRelation for read operations (existing code); the ProjectionRelation is added on line 820. For a write op we should return the child relation without adding a projection. Only CTAS needs to handle column names here.
For any unsupported write op we throw an exception above in TransformOp, so here we should just return the child.

I can add the comment below in the next PR if it is helpful.
// return child relation for other supported write operations (INSERT, DELETE)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we wouldn't throw here. It sounds like you're saying that the code above is going to throw, so we just return something random here. And what's wrong with adding a projection relation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's merge this change and open an issue to follow-up on better clarity around why we have these weird specialized cases outside of individual transform ops.

}
}

return make_shared_ptr<ProjectionRelation>(child, std::move(expressions), aliases);
}

Expand Down
1 change: 1 addition & 0 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class SubstraitToDuckDB {
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformWriteOp(const substrait::Rel &sop);

//! Transform Substrait Expressions to DuckDB Expressions
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr,
Expand Down
2 changes: 2 additions & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DuckDBToSubstrait {
//! In case of struct types we might do a DFS to get all names
static vector<string> DepthFirstNames(const LogicalType &type);
static void DepthFirstNamesRecurse(vector<string> &names, const LogicalType &type);
static substrait::Expression_Literal ToExpressionLiteral(const substrait::Expression &expr);

//! Transforms Relation Root
substrait::RelRoot *TransformRootOp(LogicalOperator &dop);
Expand All @@ -65,6 +66,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformDistinct(LogicalOperator &dop);
substrait::Rel *TransformExcept(LogicalOperator &dop);
substrait::Rel *TransformIntersect(LogicalOperator &dop);
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
Expand Down
69 changes: 60 additions & 9 deletions src/substrait_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,22 +259,73 @@ static void ToJsonFunction(ClientContext &context, TableFunctionInput &data_p, D
VerifyBlobRoundtrip(query_plan, context, data, serialized);
}

static unique_ptr<TableRef> SubstraitBind(ClientContext &context, TableFunctionBindInput &input, bool is_json) {
static unique_ptr<TableRef> SubstraitBindReplace(ClientContext &context, TableFunctionBindInput &input, bool is_json) {
if (input.inputs[0].IsNull()) {
throw BinderException("from_substrait cannot be called with a NULL parameter");
}
string serialized = input.inputs[0].GetValueUnsafe<string>();
shared_ptr<ClientContext> c_ptr(&context, do_nothing);
auto plan = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json);
if (!plan.get()->IsReadOnly()) {
return nullptr;
}
return plan->GetTableRef();
}

static unique_ptr<TableRef> FromSubstraitBind(ClientContext &context, TableFunctionBindInput &input) {
return SubstraitBind(context, input, false);
static unique_ptr<TableRef> FromSubstraitBindReplace(ClientContext &context, TableFunctionBindInput &input) {
return SubstraitBindReplace(context, input, false);
}

//! Bind-replace entry point for JSON-encoded plans: forwards to SubstraitBindReplace with
//! the JSON flag set.
static unique_ptr<TableRef> FromSubstraitBindReplaceJSON(ClientContext &context, TableFunctionBindInput &input) {
	const bool is_json = true;
	return SubstraitBindReplace(context, input, is_json);
}

//! Bind data carried between the bind and scan callbacks of the from_substrait table functions.
struct FromSubstraitFunctionData : TableFunctionData {
	FromSubstraitFunctionData() = default;

	//! Relation produced from the Substrait plan during bind
	shared_ptr<Relation> plan;
	//! Result of executing `plan`; unset until the scan callback runs for the first time
	unique_ptr<QueryResult> res;
	//! Connection to the same database instance, opened during bind
	unique_ptr<Connection> conn;
};

//! Shared bind implementation for the from_substrait table functions.
//! Converts the serialized plan passed as the first argument into a relation, stores it in
//! the bind data and reports the relation's columns as the function's output schema.
//! \param is_json forwarded to SubstraitPlanToDuckDBRel to select the plan encoding
//! Throws BinderException when the plan argument is NULL.
static unique_ptr<FunctionData> SubstraitBind(ClientContext &context, TableFunctionBindInput &input,
                                              vector<LogicalType> &return_types, vector<string> &names, bool is_json) {
	auto bind_data = make_uniq<FromSubstraitFunctionData>();
	bind_data->conn = make_uniq<Connection>(*context.db);

	const auto &plan_arg = input.inputs[0];
	if (plan_arg.IsNull()) {
		throw BinderException("from_substrait cannot be called with a NULL parameter");
	}
	string serialized = plan_arg.GetValueUnsafe<string>();

	// Wrap the caller's context without taking ownership of it
	// (presumably do_nothing is a no-op deleter — confirm at its definition).
	shared_ptr<ClientContext> borrowed_context(&context, do_nothing);
	bind_data->plan = SubstraitPlanToDuckDBRel(borrowed_context, serialized, is_json);

	// The output schema mirrors the columns of the converted relation.
	const auto &columns = bind_data->plan->Columns();
	for (const auto &col : columns) {
		return_types.emplace_back(col.Type());
		names.emplace_back(col.Name());
	}
	return std::move(bind_data);
}

//! Bind callback for binary-encoded plans: forwards to SubstraitBind with the JSON flag cleared.
static unique_ptr<FunctionData> FromSubstraitBind(ClientContext &context, TableFunctionBindInput &input,
                                                  vector<LogicalType> &return_types, vector<string> &names) {
	const bool is_json = false;
	return SubstraitBind(context, input, return_types, names, is_json);
}

static unique_ptr<TableRef> FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input) {
return SubstraitBind(context, input, true);
static unique_ptr<FunctionData> FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {
return SubstraitBind(context, input, return_types, names, true);
}

//! Scan callback of the from_substrait table functions.
//! On the first call the bound plan is executed and its result is cached in the bind data;
//! every call then moves the next fetched chunk into `output`. When the result is exhausted
//! `output` is left untouched.
static void FromSubFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
	auto &data = data_p.bind_data->CastNoConst<FromSubstraitFunctionData>();
	if (!data.res) {
		// Run the plan on the connection kept in the bind data instead of a temporary one:
		// the plan is re-pointed at that connection's context, so the context must stay
		// alive for as long as the plan and its result are held by the bind data.
		if (!data.conn) {
			// Bind normally opens the connection; create it here only if it is missing.
			data.conn = make_uniq<Connection>(*context.db);
		}
		data.plan->context = make_shared_ptr<ClientContextWrapper>(data.conn->context);
		data.res = data.plan->Execute();
	}
	auto result_chunk = data.res->Fetch();
	if (!result_chunk) {
		// No more data to emit.
		return;
	}
	output.Move(*result_chunk);
}

void InitializeGetSubstrait(const Connection &con) {
Expand Down Expand Up @@ -304,8 +355,8 @@ void InitializeFromSubstrait(const Connection &con) {

// create the from_substrait table function that allows us to get a query
// result from a substrait plan
TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, nullptr, nullptr);
from_sub_func.bind_replace = FromSubstraitBind;
TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, FromSubFunction, FromSubstraitBind);
from_sub_func.bind_replace = FromSubstraitBindReplace;
CreateTableFunctionInfo from_sub_info(from_sub_func);
catalog.CreateTableFunction(*con.context, from_sub_info);
}
Expand All @@ -314,8 +365,8 @@ void InitializeFromSubstraitJSON(const Connection &con) {
auto &catalog = Catalog::GetSystemCatalog(*con.context);
// create the from_substrait table function that allows us to get a query
// result from a substrait plan
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, nullptr, nullptr);
from_sub_func_json.bind_replace = FromSubstraitBindJSON;
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, FromSubFunction, FromSubstraitBindJSON);
from_sub_func_json.bind_replace = FromSubstraitBindReplaceJSON;
CreateTableFunctionInfo from_sub_info_json(from_sub_func_json);
catalog.CreateTableFunction(*con.context, from_sub_info_json);
}
Expand Down
45 changes: 43 additions & 2 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
} else {
auto interval_day = make_uniq<substrait::Expression_Literal_IntervalDayToSecond>();
interval_day->set_days(dval.GetValue<interval_t>().days);
interval_day->set_microseconds(static_cast<int32_t>(dval.GetValue<interval_t>().micros));

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]
sval.set_allocated_interval_day_to_second(interval_day.release());
}
}
Expand Down Expand Up @@ -426,7 +426,7 @@
args_types.emplace_back(DuckToSubstraitType(dcomp.lower->return_type));
args_types.emplace_back(DuckToSubstraitType(dcomp.upper->return_type));
scalar_fun->set_function_reference(RegisterFunction("between", args_types));

auto sarg = scalar_fun->add_arguments();
TransformExpr(*dcomp.input, *sarg->mutable_value(), 0);
sarg = scalar_fun->add_arguments();
Expand Down Expand Up @@ -1012,7 +1012,7 @@
// TODO push projection or push substrait to allow expressions here
throw NotImplementedException("No expressions in groupings yet");
}
TransformExpr(*dgrp, *sgrp->add_grouping_expressions());

Check warning on line 1015 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 1015 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 1015 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 1015 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]
}
for (auto &dmeas : daggr.expressions) {
auto smeas = saggr->add_measures()->mutable_measure();
Expand Down Expand Up @@ -1280,7 +1280,7 @@
auto virtual_table = sget->mutable_virtual_table();

// Add a dummy value to emit one row
auto dummy_value = virtual_table->add_values();

Check warning on line 1283 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]

Check warning on line 1283 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]
dummy_value->add_fields()->set_i32(42);
return get_rel;
}
Expand Down Expand Up @@ -1381,7 +1381,8 @@
set_op->set_op(substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_INTERSECTION_PRIMARY);
break;
default:
throw NotImplementedException("Found unexpected child type in Distinct operator");
throw NotImplementedException("Found unexpected child type in Distinct operator " +
LogicalOperatorToString(set_operation_p->type));
}
auto &set_operation = set_operation_p->Cast<LogicalSetOperation>();

Expand Down Expand Up @@ -1417,6 +1418,41 @@
return rel;
}

//! Transforms a LogicalCreateTable with exactly one child (CREATE TABLE ... AS SELECT) into a
//! Substrait WriteRel with the CTAS operation.
//! The write carries the table schema (column names and types), the target as
//! [schema, table] names and the transformed child as its input.
//! Throws NotImplementedException when the operator has no child and InternalException when
//! it has more than one. The caller owns the returned pointer.
substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
	auto &create_table = dop.Cast<LogicalCreateTable>();
	auto &create_info = create_table.info->Base();

	// Validate before allocating anything so a rejected operator cannot leak.
	if (create_table.children.size() != 1) {
		if (create_table.children.empty()) {
			throw NotImplementedException("Create table without children not implemented");
		}
		throw InternalException("Create table with more than one child is not supported");
	}

	// `rel` owns every sub-message built below, so an exception thrown while filling them in
	// (e.g. from a type or child transformation) releases everything.
	auto rel = make_uniq<substrait::Rel>();
	auto write = rel->mutable_write();
	write->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS);

	auto schema = write->mutable_table_schema();
	for (auto &name : create_info.columns.GetColumnNames()) {
		schema->add_names(name);
	}
	auto type_info = schema->mutable_struct_();
	for (auto &col_type : create_info.columns.GetColumnTypes()) {
		*type_info->add_types() = DuckToSubstraitType(col_type, nullptr, false);
	}

	auto named_table = write->mutable_named_table();
	named_table->add_names(create_info.schema);
	named_table->add_names(create_info.table);

	// This is CreateTableAsSelect: the single child is the SELECT feeding the new table.
	// Transform it last and hand the raw pointer straight to the message that takes ownership.
	write->set_allocated_input(TransformOp(*create_table.children[0]));

	return rel.release();
}

substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
switch (dop.type) {
case LogicalOperatorType::LOGICAL_FILTER:
Expand Down Expand Up @@ -1447,6 +1483,8 @@
return TransformIntersect(dop);
case LogicalOperatorType::LOGICAL_DUMMY_SCAN:
return TransformDummyScan();
case LogicalOperatorType::LOGICAL_CREATE_TABLE:
return TransformCreateTable(dop);
default:
throw NotImplementedException(LogicalOperatorToString(dop.type));
}
Expand Down Expand Up @@ -1477,6 +1515,9 @@
continue;
}
if (current_op->children.size() != 1) {
if (current_op->type == LogicalOperatorType::LOGICAL_CREATE_TABLE) {
break;
}
throw InternalException("Root node has more than 1, or 0 children (%d) up to "
"reaching a projection node. Type %d",
current_op->children.size(), current_op->type);
Expand Down
Loading
Loading