Skip to content

Commit

Permalink
Support delete rows in a table
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran committed Nov 19, 2024
1 parent e84a785 commit e72cc29
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 15 deletions.
32 changes: 18 additions & 14 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "duckdb/main/relation.hpp"
#include "duckdb/main/relation/create_table_relation.hpp"
#include <duckdb/main/relation/delete_relation.hpp>
#include "duckdb/main/relation/table_relation.hpp"
#include "duckdb/main/relation/table_function_relation.hpp"
#include "duckdb/main/relation/value_relation.hpp"
Expand Down Expand Up @@ -701,21 +702,24 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop

shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) {
auto &swrite = sop.write();
auto &nobj = swrite.named_table();
if (nobj.names_size() == 0) {
throw InvalidInputException("Named object must have at least one name");
}
auto table_idx = nobj.names_size() - 1;
auto table_name = nobj.names(table_idx);
string schema_name;
if (table_idx > 0) {
schema_name = nobj.names(0);
}
auto input = TransformOp(swrite.input());
switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
auto &nobj = swrite.named_table();
if (nobj.names_size() == 0) {
throw InvalidInputException("Named object must have at least one name");
}
auto table_idx = nobj.names_size() - 1;
auto table_name = nobj.names(table_idx);
string schema_name;
if (table_idx > 0) {
schema_name = nobj.names(0);
}

auto input = TransformOp(swrite.input());
return input->CreateRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
auto filter = std::move(input.get()->Cast<FilterRelation>());
auto context = filter.child->Cast<TableRelation>().context;
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name, table_name);
}
default:
throw NotImplementedException("Unsupported write operation " + to_string(swrite.op()));
Expand Down
3 changes: 3 additions & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class DuckDBToSubstrait {
static vector<string> DepthFirstNames(const LogicalType &type);
static void DepthFirstNamesRecurse(vector<string> &names, const LogicalType &type);
static substrait::Expression_Literal ToExpressionLiteral(const substrait::Expression &expr);
static void SetTableSchema(const TableCatalogEntry &table, substrait::NamedStruct *schema);
static void SetNamedTable(const TableCatalogEntry &table, substrait::WriteRel *writeRel);

//! Transforms Relation Root
substrait::RelRoot *TransformRootOp(LogicalOperator &dop);
Expand All @@ -67,6 +69,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformExcept(LogicalOperator &dop);
substrait::Rel *TransformIntersect(LogicalOperator &dop);
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
Expand Down
62 changes: 62 additions & 0 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,51 @@ substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
return rel;
}

//! Describes the columns of `table` in `schema`.
//! The column names are appended to schema's name list, and schema's struct type is set to one
//! entry per column type, with the struct itself marked NULLABILITY_REQUIRED.
//! \param table  Catalog entry whose columns are read
//! \param schema Substrait NamedStruct that receives the names and types
void DuckDBToSubstrait::SetTableSchema(const TableCatalogEntry &table, substrait::NamedStruct *schema) {
	for (const auto &name : table.GetColumns().GetColumnNames()) {
		schema->add_names(name);
	}
	// Build the struct inside `schema` instead of `new` + set_allocated_struct_():
	// the message owns it from the start, so it cannot leak if a type conversion throws.
	auto type_info = schema->mutable_struct_();
	// Start from an empty struct so any previously set struct is replaced, as set_allocated_struct_() did
	type_info->Clear();
	type_info->set_nullability(substrait::Type_Nullability_NULLABILITY_REQUIRED);
	for (const auto &col_type : table.GetColumns().GetColumnTypes()) {
		*type_info->add_types() = DuckToSubstraitType(col_type, nullptr, false);
	}
}

//! Records the target table of a write relation: appends the schema name and then the
//! table name of `table` to the named_table of `writeRel`.
void DuckDBToSubstrait::SetNamedTable(const TableCatalogEntry &table, substrait::WriteRel *writeRel) {
	auto &target = *writeRel->mutable_named_table();
	target.add_names(table.schema.name);
	target.add_names(table.name);
}

//! Transforms a LogicalDelete into a Substrait WriteRel with WRITE_OP_DELETE.
//! The WriteRel carries the target table (schema name, table name), the table's schema,
//! and the transformed single child of the delete operator as its input.
//! \param dop Logical operator; cast to LogicalDelete
//! \return Newly allocated Rel owned by the caller
//! \throws InternalException if the delete operator does not have exactly one child
substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) {
	auto &logical_delete = dop.Cast<LogicalDelete>();
	// Validate before allocating anything so the error path has nothing to leak
	if (logical_delete.children.size() != 1) {
		throw InternalException("Delete table expected one child, found " + to_string(logical_delete.children.size()));
	}

	auto rel = new substrait::Rel();
	auto writeRel = rel->mutable_write();
	writeRel->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE);
	writeRel->set_output(substrait::WriteRel::OUTPUT_MODE_NO_OUTPUT);

	// SetNamedTable is the only place that adds the names; adding them here as well
	// would emit [schema, table, schema, table] in the plan.
	SetNamedTable(logical_delete.table, writeRel);
	// Fill the schema message owned by writeRel directly instead of allocating one and handing it over
	SetTableSchema(logical_delete.table, writeRel->mutable_table_schema());

	substrait::Rel *input = TransformOp(*logical_delete.children[0]);
	writeRel->set_allocated_input(input);
	return rel;
}

substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
switch (dop.type) {
case LogicalOperatorType::LOGICAL_FILTER:
Expand Down Expand Up @@ -1485,6 +1530,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
return TransformDummyScan();
case LogicalOperatorType::LOGICAL_CREATE_TABLE:
return TransformCreateTable(dop);
case LogicalOperatorType::LOGICAL_DELETE:
return TransformDeleteTable(dop);
default:
throw NotImplementedException(LogicalOperatorToString(dop.type));
}
Expand All @@ -1495,8 +1542,23 @@ static bool IsSetOperation(const LogicalOperator &op) {
op.type == LogicalOperatorType::LOGICAL_INTERSECT;
}

//! Returns true when `op` is a LOGICAL_INSERT, LOGICAL_DELETE or LOGICAL_UPDATE operator,
//! false for every other operator type.
static bool IsRowModificationOperator(const LogicalOperator &op) {
	const auto kind = op.type;
	return kind == LogicalOperatorType::LOGICAL_INSERT || kind == LogicalOperatorType::LOGICAL_DELETE ||
	       kind == LogicalOperatorType::LOGICAL_UPDATE;
}

substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) {
auto root_rel = new substrait::RelRoot();
if (IsRowModificationOperator(dop)) {
root_rel->set_allocated_input(TransformOp(dop));
return root_rel;
}
LogicalOperator *current_op = &dop;
bool weird_scenario = current_op->type == LogicalOperatorType::LOGICAL_PROJECTION &&
current_op->children[0]->type == LogicalOperatorType::LOGICAL_TOP_N;
Expand Down
16 changes: 15 additions & 1 deletion test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,18 @@ TEST_CASE("Test C CTAS Union with Substrait API", "[substrait-api]") {
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5, 6, 7}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black", "David White", "Eve Green"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2, 1, 2}));
}
}

TEST_CASE("Test C DeleteRows with Substrait API", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);
	CreateEmployeeTable(con);

	// Issue the DELETE through the Substrait JSON helper; its return value is not inspected here
	ExecuteViaSubstraitJSON(con, "DELETE FROM employees WHERE salary < 80000");

	// Read the table back through Substrait and check each of the four columns
	auto remaining = ExecuteViaSubstrait(con, "SELECT * from employees");
	REQUIRE(CHECK_COLUMN(remaining, 0, {1, 2, 4}));
	REQUIRE(CHECK_COLUMN(remaining, 1, {"John Doe", "Jane Smith", "Bob Brown"}));
	REQUIRE(CHECK_COLUMN(remaining, 2, {1, 2, 3}));
	REQUIRE(CHECK_COLUMN(remaining, 3, {120000, 80000, 95000}));
}
12 changes: 12 additions & 0 deletions test/python/test_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,18 @@ def test_ctas_with_union(require):
pd.testing.assert_frame_equal(query_result.df(), expected)


def test_delete_rows_in_table(require):
    """Delete employees with salary below 80000, then verify the remaining rows."""
    connection = require('substrait')
    create_employee_table(connection)

    # NOTE(review): the DELETE is issued as plain SQL on the connection; only the
    # SELECT below goes through execute_via_substrait. Confirm whether the delete
    # is meant to go through Substrait as well.
    connection.execute("DELETE FROM employees WHERE salary < 80000")

    remaining = execute_via_substrait(connection, "SELECT * FROM employees")

    expected_columns = {
        "employee_id": pd.Series([1, 2, 4], dtype="int32"),
        "name": ["John Doe", "Jane Smith", "Bob Brown"],
        "department_id": pd.Series([1, 2, 3], dtype="int32"),
        "salary": pd.Series([120000, 80000, 95000], dtype="float64"),
    }
    expected = pd.DataFrame(expected_columns)
    pd.testing.assert_frame_equal(remaining.df(), expected)


def execute_via_substrait(connection, query):
res = connection.get_substrait(query)
proto_bytes = res.fetchone()[0]
Expand Down

0 comments on commit e72cc29

Please sign in to comment.