Skip to content

Commit

Permalink
Support CTAS (#125)
Browse files Browse the repository at this point in the history
* Support CTAS

* Handle from_substrait for Update operations

* Add review comments
  • Loading branch information
scgkiran authored Nov 19, 2024
1 parent fa535b9 commit e84a785
Show file tree
Hide file tree
Showing 7 changed files with 554 additions and 11 deletions.
39 changes: 39 additions & 0 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "duckdb/common/helper.hpp"

#include "duckdb/main/relation.hpp"
#include "duckdb/main/relation/create_table_relation.hpp"
#include "duckdb/main/relation/table_relation.hpp"
#include "duckdb/main/relation/table_function_relation.hpp"
#include "duckdb/main/relation/value_relation.hpp"
Expand Down Expand Up @@ -698,6 +699,29 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop
return make_shared_ptr<SetOpRelation>(std::move(lhs), std::move(rhs), type);
}

shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) {
auto &swrite = sop.write();
switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
auto &nobj = swrite.named_table();
if (nobj.names_size() == 0) {
throw InvalidInputException("Named object must have at least one name");
}
auto table_idx = nobj.names_size() - 1;
auto table_name = nobj.names(table_idx);
string schema_name;
if (table_idx > 0) {
schema_name = nobj.names(0);
}

auto input = TransformOp(swrite.input());
return input->CreateRel(schema_name, table_name);
}
default:
throw NotImplementedException("Unsupported write operation " + to_string(swrite.op()));
}
}

shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
switch (sop.rel_type_case()) {
Expand All @@ -719,6 +743,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
return TransformSortOp(sop, names);
case substrait::Rel::RelTypeCase::kSet:
return TransformSetOp(sop, names);
case substrait::Rel::RelTypeCase::kWrite:
return TransformWriteOp(sop);
default:
throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case()));
}
Expand Down Expand Up @@ -778,6 +804,19 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
}
}

if (sop.input().rel_type_case() == substrait::Rel::RelTypeCase::kWrite) {
auto write = sop.input().write();
switch (write.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
const auto create_table = static_cast<CreateTableRelation *>(child.get());
auto proj = make_shared_ptr<ProjectionRelation>(create_table->child, std::move(expressions), aliases);
return proj->CreateRel(create_table->schema_name, create_table->table_name);
}
default:
return child;
}
}

return make_shared_ptr<ProjectionRelation>(child, std::move(expressions), aliases);
}

Expand Down
1 change: 1 addition & 0 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class SubstraitToDuckDB {
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformWriteOp(const substrait::Rel &sop);

//! Transform Substrait Expressions to DuckDB Expressions
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr,
Expand Down
2 changes: 2 additions & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DuckDBToSubstrait {
//! In case of struct types we might we do DFS to get all names
static vector<string> DepthFirstNames(const LogicalType &type);
static void DepthFirstNamesRecurse(vector<string> &names, const LogicalType &type);
static substrait::Expression_Literal ToExpressionLiteral(const substrait::Expression &expr);

//! Transforms Relation Root
substrait::RelRoot *TransformRootOp(LogicalOperator &dop);
Expand All @@ -65,6 +66,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformDistinct(LogicalOperator &dop);
substrait::Rel *TransformExcept(LogicalOperator &dop);
substrait::Rel *TransformIntersect(LogicalOperator &dop);
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
Expand Down
69 changes: 60 additions & 9 deletions src/substrait_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,22 +259,73 @@ static void ToJsonFunction(ClientContext &context, TableFunctionInput &data_p, D
VerifyBlobRoundtrip(query_plan, context, data, serialized);
}

static unique_ptr<TableRef> SubstraitBind(ClientContext &context, TableFunctionBindInput &input, bool is_json) {
static unique_ptr<TableRef> SubstraitBindReplace(ClientContext &context, TableFunctionBindInput &input, bool is_json) {
if (input.inputs[0].IsNull()) {
throw BinderException("from_substrait cannot be called with a NULL parameter");
}
string serialized = input.inputs[0].GetValueUnsafe<string>();
shared_ptr<ClientContext> c_ptr(&context, do_nothing);
auto plan = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json);
if (!plan.get()->IsReadOnly()) {
return nullptr;
}
return plan->GetTableRef();
}

static unique_ptr<TableRef> FromSubstraitBind(ClientContext &context, TableFunctionBindInput &input) {
return SubstraitBind(context, input, false);
static unique_ptr<TableRef> FromSubstraitBindReplace(ClientContext &context, TableFunctionBindInput &input) {
return SubstraitBindReplace(context, input, false);
}

static unique_ptr<TableRef> FromSubstraitBindReplaceJSON(ClientContext &context, TableFunctionBindInput &input) {
return SubstraitBindReplace(context, input, true);
}

struct FromSubstraitFunctionData : public TableFunctionData {
FromSubstraitFunctionData() = default;
shared_ptr<Relation> plan;
unique_ptr<QueryResult> res;
unique_ptr<Connection> conn;
};

static unique_ptr<FunctionData> SubstraitBind(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names, bool is_json) {
auto result = make_uniq<FromSubstraitFunctionData>();
result->conn = make_uniq<Connection>(*context.db);
if (input.inputs[0].IsNull()) {
throw BinderException("from_substrait cannot be called with a NULL parameter");
}
string serialized = input.inputs[0].GetValueUnsafe<string>();
shared_ptr<ClientContext> c_ptr(&context, do_nothing);
result->plan = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json);
for (auto &column : result->plan->Columns()) {
return_types.emplace_back(column.Type());
names.emplace_back(column.Name());
}
return std::move(result);
}

static unique_ptr<FunctionData> FromSubstraitBind(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {
return SubstraitBind(context, input, return_types, names, false);
}

static unique_ptr<TableRef> FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input) {
return SubstraitBind(context, input, true);
static unique_ptr<FunctionData> FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {
return SubstraitBind(context, input, return_types, names, true);
}

static void FromSubFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
auto &data = data_p.bind_data->CastNoConst<FromSubstraitFunctionData>();
if (!data.res) {
auto con = Connection(*context.db);
data.plan->context = make_shared_ptr<ClientContextWrapper>(con.context);
data.res = data.plan->Execute();
}
auto result_chunk = data.res->Fetch();
if (!result_chunk) {
return;
}
output.Move(*result_chunk);
}

void InitializeGetSubstrait(const Connection &con) {
Expand Down Expand Up @@ -304,8 +355,8 @@ void InitializeFromSubstrait(const Connection &con) {

// create the from_substrait table function that allows us to get a query
// result from a substrait plan
TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, nullptr, nullptr);
from_sub_func.bind_replace = FromSubstraitBind;
TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, FromSubFunction, FromSubstraitBind);
from_sub_func.bind_replace = FromSubstraitBindReplace;
CreateTableFunctionInfo from_sub_info(from_sub_func);
catalog.CreateTableFunction(*con.context, from_sub_info);
}
Expand All @@ -314,8 +365,8 @@ void InitializeFromSubstraitJSON(const Connection &con) {
auto &catalog = Catalog::GetSystemCatalog(*con.context);
// create the from_substrait table function that allows us to get a query
// result from a substrait plan
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, nullptr, nullptr);
from_sub_func_json.bind_replace = FromSubstraitBindJSON;
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, FromSubFunction, FromSubstraitBindJSON);
from_sub_func_json.bind_replace = FromSubstraitBindReplaceJSON;
CreateTableFunctionInfo from_sub_info_json(from_sub_func_json);
catalog.CreateTableFunction(*con.context, from_sub_info_json);
}
Expand Down
45 changes: 43 additions & 2 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ void DuckDBToSubstrait::TransformBetweenExpression(Expression &dexpr, substrait:
args_types.emplace_back(DuckToSubstraitType(dcomp.lower->return_type));
args_types.emplace_back(DuckToSubstraitType(dcomp.upper->return_type));
scalar_fun->set_function_reference(RegisterFunction("between", args_types));

auto sarg = scalar_fun->add_arguments();
TransformExpr(*dcomp.input, *sarg->mutable_value(), 0);
sarg = scalar_fun->add_arguments();
Expand Down Expand Up @@ -1381,7 +1381,8 @@ substrait::Rel *DuckDBToSubstrait::TransformDistinct(LogicalOperator &dop) {
set_op->set_op(substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_INTERSECTION_PRIMARY);
break;
default:
throw NotImplementedException("Found unexpected child type in Distinct operator");
throw NotImplementedException("Found unexpected child type in Distinct operator " +
LogicalOperatorToString(set_operation_p->type));
}
auto &set_operation = set_operation_p->Cast<LogicalSetOperation>();

Expand Down Expand Up @@ -1417,6 +1418,41 @@ substrait::Rel *DuckDBToSubstrait::TransformIntersect(LogicalOperator &dop) {
return rel;
}

substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &create_table = dop.Cast<LogicalCreateTable>();
auto &create_info = create_table.info.get()->Base();
if (create_table.children.size() != 1) {
if (create_table.children.size() == 0) {
throw NotImplementedException("Create table without children not implemented");
}
throw InternalException("Create table with more than one child is not supported");
}

auto schema = new substrait::NamedStruct();
auto type_info = new substrait::Type_Struct();
for (auto &name : create_info.columns.GetColumnNames()) {
schema->add_names(name);
}
for (auto &col_type : create_info.columns.GetColumnTypes()) {
auto s_type = DuckToSubstraitType(col_type, nullptr, false);
*type_info->add_types() = s_type;
}
schema->set_allocated_struct_(type_info);

// This is CreateTableAsSelect
substrait::Rel *input = TransformOp(*create_table.children[0]);
auto write = rel->mutable_write();
write->set_allocated_table_schema(schema);
write->set_allocated_input(input);
write->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS);
auto named_table = write->mutable_named_table();
named_table->add_names(create_info.schema);
named_table->add_names(create_info.table);

return rel;
}

substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
switch (dop.type) {
case LogicalOperatorType::LOGICAL_FILTER:
Expand Down Expand Up @@ -1447,6 +1483,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
return TransformIntersect(dop);
case LogicalOperatorType::LOGICAL_DUMMY_SCAN:
return TransformDummyScan();
case LogicalOperatorType::LOGICAL_CREATE_TABLE:
return TransformCreateTable(dop);
default:
throw NotImplementedException(LogicalOperatorToString(dop.type));
}
Expand Down Expand Up @@ -1477,6 +1515,9 @@ substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) {
continue;
}
if (current_op->children.size() != 1) {
if (current_op->type == LogicalOperatorType::LOGICAL_CREATE_TABLE) {
break;
}
throw InternalException("Root node has more than 1, or 0 children (%d) up to "
"reaching a projection node. Type %d",
current_op->children.size(), current_op->type);
Expand Down
Loading

0 comments on commit e84a785

Please sign in to comment.