diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp index 5679815..2925dda 100644 --- a/src/from_substrait.cpp +++ b/src/from_substrait.cpp @@ -22,6 +22,7 @@ #include "duckdb/common/helper.hpp" #include "duckdb/main/relation.hpp" +#include "duckdb/main/relation/create_table_relation.hpp" #include "duckdb/main/relation/table_relation.hpp" #include "duckdb/main/relation/table_function_relation.hpp" #include "duckdb/main/relation/value_relation.hpp" @@ -698,6 +699,29 @@ shared_ptr SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop return make_shared_ptr(std::move(lhs), std::move(rhs), type); } +shared_ptr SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) { + auto &swrite = sop.write(); + switch (swrite.op()) { + case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: { + auto &nobj = swrite.named_table(); + if (nobj.names_size() == 0) { + throw InvalidInputException("Named object must have at least one name"); + } + auto table_idx = nobj.names_size() - 1; + auto table_name = nobj.names(table_idx); + string schema_name; + if (table_idx > 0) { + schema_name = nobj.names(0); + } + + auto input = TransformOp(swrite.input()); + return input->CreateRel(schema_name, table_name); + } + default: + throw NotImplementedException("Unsupported write operation " + to_string(swrite.op())); + } +} + shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop, const google::protobuf::RepeatedPtrField *names) { switch (sop.rel_type_case()) { @@ -719,6 +743,8 @@ shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop, return TransformSortOp(sop, names); case substrait::Rel::RelTypeCase::kSet: return TransformSetOp(sop, names); + case substrait::Rel::RelTypeCase::kWrite: + return TransformWriteOp(sop); default: throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case())); } @@ -778,6 +804,19 @@ shared_ptr SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot } } + if (sop.input().rel_type_case() == substrait::Rel::RelTypeCase::kWrite) { + auto write = sop.input().write(); + switch (write.op()) { + case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: { + const auto create_table = static_cast(child.get()); + auto proj = make_shared_ptr(create_table->child, std::move(expressions), aliases); + return proj->CreateRel(create_table->schema_name, create_table->table_name); + } + default: + return child; + } + } + return make_shared_ptr(child, std::move(expressions), aliases); } diff --git a/src/include/from_substrait.hpp b/src/include/from_substrait.hpp index 95a4884..4ed20f4 100644 --- a/src/include/from_substrait.hpp +++ b/src/include/from_substrait.hpp @@ -78,6 +78,7 @@ class SubstraitToDuckDB { const google::protobuf::RepeatedPtrField *names = nullptr); shared_ptr TransformSetOp(const substrait::Rel &sop, const google::protobuf::RepeatedPtrField *names = nullptr); + shared_ptr TransformWriteOp(const substrait::Rel &sop); //! Transform Substrait Expressions to DuckDB Expressions unique_ptr TransformExpr(const substrait::Expression &sexpr, diff --git a/src/include/to_substrait.hpp b/src/include/to_substrait.hpp index aa063b8..7c8d1a0 100644 --- a/src/include/to_substrait.hpp +++ b/src/include/to_substrait.hpp @@ -46,6 +46,7 @@ class DuckDBToSubstrait { //! In case of struct types we might we do DFS to get all names static vector DepthFirstNames(const LogicalType &type); static void DepthFirstNamesRecurse(vector &names, const LogicalType &type); + static substrait::Expression_Literal ToExpressionLiteral(const substrait::Expression &expr); //! Transforms Relation Root substrait::RelRoot *TransformRootOp(LogicalOperator &dop); @@ -65,6 +66,7 @@ class DuckDBToSubstrait { substrait::Rel *TransformDistinct(LogicalOperator &dop); substrait::Rel *TransformExcept(LogicalOperator &dop); substrait::Rel *TransformIntersect(LogicalOperator &dop); + substrait::Rel *TransformCreateTable(LogicalOperator &dop); static substrait::Rel *TransformDummyScan(); //! Methods to transform different LogicalGet Types (e.g., Table, Parquet) //! To Substrait; diff --git a/src/substrait_extension.cpp b/src/substrait_extension.cpp index 0ff7de8..6533afa 100644 --- a/src/substrait_extension.cpp +++ b/src/substrait_extension.cpp @@ -259,22 +259,73 @@ static void ToJsonFunction(ClientContext &context, TableFunctionInput &data_p, D VerifyBlobRoundtrip(query_plan, context, data, serialized); } -static unique_ptr SubstraitBind(ClientContext &context, TableFunctionBindInput &input, bool is_json) { +static unique_ptr SubstraitBindReplace(ClientContext &context, TableFunctionBindInput &input, bool is_json) { if (input.inputs[0].IsNull()) { throw BinderException("from_substrait cannot be called with a NULL parameter"); } string serialized = input.inputs[0].GetValueUnsafe(); shared_ptr c_ptr(&context, do_nothing); auto plan = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json); + if (!plan.get()->IsReadOnly()) { + return nullptr; + } return plan->GetTableRef(); } -static unique_ptr FromSubstraitBind(ClientContext &context, TableFunctionBindInput &input) { - return SubstraitBind(context, input, false); +static unique_ptr FromSubstraitBindReplace(ClientContext &context, TableFunctionBindInput &input) { + return SubstraitBindReplace(context, input, false); +} + +static unique_ptr FromSubstraitBindReplaceJSON(ClientContext &context, TableFunctionBindInput &input) { + return SubstraitBindReplace(context, input, true); +} + +struct FromSubstraitFunctionData : public TableFunctionData { + FromSubstraitFunctionData() = default; + shared_ptr plan; + unique_ptr res; + unique_ptr conn; +}; + +static unique_ptr SubstraitBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names, bool is_json) { + auto result = make_uniq(); + result->conn = make_uniq(*context.db); + if (input.inputs[0].IsNull()) { + throw BinderException("from_substrait cannot be called with a NULL parameter"); + } + string serialized = input.inputs[0].GetValueUnsafe(); + shared_ptr c_ptr(&context, do_nothing); + result->plan = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json); + for (auto &column : result->plan->Columns()) { + return_types.emplace_back(column.Type()); + names.emplace_back(column.Name()); + } + return std::move(result); +} + +static unique_ptr FromSubstraitBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + return SubstraitBind(context, input, return_types, names, false); } -static unique_ptr FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input) { - return SubstraitBind(context, input, true); +static unique_ptr FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + return SubstraitBind(context, input, return_types, names, true); +} + +static void FromSubFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.bind_data->CastNoConst(); + if (!data.res) { + auto con = Connection(*context.db); + data.plan->context = make_shared_ptr(con.context); + data.res = data.plan->Execute(); + } + auto result_chunk = data.res->Fetch(); + if (!result_chunk) { + return; + } + output.Move(*result_chunk); } void InitializeGetSubstrait(const Connection &con) { @@ -304,8 +355,8 @@ void InitializeFromSubstrait(const Connection &con) { // create the from_substrait table function that allows us to get a query // result from a substrait plan - TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, nullptr, nullptr); - from_sub_func.bind_replace = FromSubstraitBind; + TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, FromSubFunction, FromSubstraitBind); + from_sub_func.bind_replace = FromSubstraitBindReplace; CreateTableFunctionInfo from_sub_info(from_sub_func); catalog.CreateTableFunction(*con.context, from_sub_info); } @@ -314,8 +365,8 @@ void InitializeFromSubstraitJSON(const Connection &con) { auto &catalog = Catalog::GetSystemCatalog(*con.context); // create the from_substrait table function that allows us to get a query // result from a substrait plan - TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, nullptr, nullptr); - from_sub_func_json.bind_replace = FromSubstraitBindJSON; + TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, FromSubFunction, FromSubstraitBindJSON); + from_sub_func_json.bind_replace = FromSubstraitBindReplaceJSON; CreateTableFunctionInfo from_sub_info_json(from_sub_func_json); catalog.CreateTableFunction(*con.context, from_sub_info_json); } diff --git a/src/to_substrait.cpp b/src/to_substrait.cpp index 5fb7f2b..c691629 100644 --- a/src/to_substrait.cpp +++ b/src/to_substrait.cpp @@ -426,7 +426,7 @@ void DuckDBToSubstrait::TransformBetweenExpression(Expression &dexpr, substrait: args_types.emplace_back(DuckToSubstraitType(dcomp.lower->return_type)); args_types.emplace_back(DuckToSubstraitType(dcomp.upper->return_type)); scalar_fun->set_function_reference(RegisterFunction("between", args_types)); - + auto sarg = scalar_fun->add_arguments(); TransformExpr(*dcomp.input, *sarg->mutable_value(), 0); sarg = scalar_fun->add_arguments(); @@ -1381,7 +1381,8 @@ substrait::Rel *DuckDBToSubstrait::TransformDistinct(LogicalOperator &dop) { set_op->set_op(substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_INTERSECTION_PRIMARY); break; default: - throw NotImplementedException("Found unexpected child type in Distinct operator"); + throw NotImplementedException("Found unexpected child type in Distinct operator " + + LogicalOperatorToString(set_operation_p->type)); } auto &set_operation = set_operation_p->Cast(); @@ -1417,6 +1418,41 @@ substrait::Rel *DuckDBToSubstrait::TransformIntersect(LogicalOperator &dop) { return rel; } +substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) { + auto rel = new substrait::Rel(); + auto &create_table = dop.Cast(); + auto &create_info = create_table.info.get()->Base(); + if (create_table.children.size() != 1) { + if (create_table.children.size() == 0) { + throw NotImplementedException("Create table without children not implemented"); + } + throw InternalException("Create table with more than one child is not supported"); + } + + auto schema = new substrait::NamedStruct(); + auto type_info = new substrait::Type_Struct(); + for (auto &name : create_info.columns.GetColumnNames()) { + schema->add_names(name); + } + for (auto &col_type : create_info.columns.GetColumnTypes()) { + auto s_type = DuckToSubstraitType(col_type, nullptr, false); + *type_info->add_types() = s_type; + } + schema->set_allocated_struct_(type_info); + + // This is CreateTableAsSelect + substrait::Rel *input = TransformOp(*create_table.children[0]); + auto write = rel->mutable_write(); + write->set_allocated_table_schema(schema); + write->set_allocated_input(input); + write->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS); + auto named_table = write->mutable_named_table(); + named_table->add_names(create_info.schema); + named_table->add_names(create_info.table); + + return rel; +} + substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { switch (dop.type) { case LogicalOperatorType::LOGICAL_FILTER: @@ -1447,6 +1483,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { return TransformIntersect(dop); case LogicalOperatorType::LOGICAL_DUMMY_SCAN: return TransformDummyScan(); + case LogicalOperatorType::LOGICAL_CREATE_TABLE: + return TransformCreateTable(dop); default: throw NotImplementedException(LogicalOperatorToString(dop.type)); } @@ -1477,6 +1515,9 @@ substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) { continue; } if (current_op->children.size() != 1) { + if (current_op->type == LogicalOperatorType::LOGICAL_CREATE_TABLE) { + break; + } throw InternalException("Root node has more than 1, or 0 children (%d) up to " "reaching a projection node. Type %d", current_op->children.size(), current_op->type); diff --git a/test/c/test_substrait_c_api.cpp b/test/c/test_substrait_c_api.cpp index fe79d65..6999c03 100644 --- a/test/c/test_substrait_c_api.cpp +++ b/test/c/test_substrait_c_api.cpp @@ -45,3 +45,216 @@ TEST_CASE("Test C Get and To Json-Substrait API", "[substrait-api]") { REQUIRE_THROWS(con.FromSubstraitJSON("this is not valid")); } + +duckdb::unique_ptr ExecuteViaSubstrait(Connection &con, const string &sql) { + auto proto = con.GetSubstrait(sql); + return con.FromSubstrait(proto); +} + +duckdb::unique_ptr ExecuteViaSubstraitJSON(Connection &con, const string &sql) { + auto json_str = con.GetSubstraitJSON(sql); + return con.FromSubstraitJSON(json_str); +} + +void CreateEmployeeTable(Connection& con) { + REQUIRE_NO_FAIL(con.Query("CREATE TABLE employees (" + "employee_id INTEGER PRIMARY KEY, " + "name VARCHAR(100), " + "department_id INTEGER, " + "salary DECIMAL(10, 2))")); + + REQUIRE_NO_FAIL(con.Query("INSERT INTO employees VALUES " + "(1, 'John Doe', 1, 120000), " + "(2, 'Jane Smith', 2, 80000), " + "(3, 'Alice Johnson', 1, 50000), " + "(4, 'Bob Brown', 3, 95000), " + "(5, 'Charlie Black', 2, 60000)")); +} + +void CreatePartTimeEmployeeTable(Connection& con) { + REQUIRE_NO_FAIL(con.Query("CREATE TABLE part_time_employees (" + "id INTEGER PRIMARY KEY, " + "name VARCHAR(100), " + "department_id INTEGER, " + "hourly_rate DECIMAL(10, 2))")); + + REQUIRE_NO_FAIL(con.Query("INSERT INTO part_time_employees VALUES " + "(6, 'David White', 1, 30000), " + "(7, 'Eve Green', 2, 40000)")); +} + +void CreateDepartmentsTable(Connection& con) { + REQUIRE_NO_FAIL(con.Query("CREATE TABLE departments (department_id INTEGER PRIMARY KEY, department_name VARCHAR(100))")); + + REQUIRE_NO_FAIL(con.Query("INSERT INTO departments VALUES " + "(1, 'HR'), " + "(2, 'Engineering'), " + "(3, 'Finance')")); +} + +TEST_CASE("Test C CTAS Select columns with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_salaries AS " + "SELECT name, salary FROM employees" + ); + + auto result = ExecuteViaSubstrait(con, "SELECT * from employee_salaries"); + REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"})); + REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000})); +} + +TEST_CASE("Test C CTAS Filter with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "CREATE TABLE filtered_employees AS " + "SELECT * FROM employees " + "WHERE salary > 80000;" + ); + + auto result = ExecuteViaSubstrait(con, "SELECT * from filtered_employees"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 4})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Bob Brown"})); + REQUIRE(CHECK_COLUMN(result, 2, {1, 3})); + REQUIRE(CHECK_COLUMN(result, 3, {120000, 95000})); +} + +TEST_CASE("Test C CTAS Case_When with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "CREATE TABLE categorized_employees AS " + "SELECT name, " + "CASE " + "WHEN salary > 100000 THEN 'High' " + "WHEN salary BETWEEN 60000 AND 100000 THEN 'Medium' " + "ELSE 'Low' " + "END AS salary_category " + "FROM employees" + ); + + auto result = ExecuteViaSubstrait(con, "SELECT * from categorized_employees"); + REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"})); + REQUIRE(CHECK_COLUMN(result, 1, {"High", "Medium", "Low", "Medium", "Medium"})); +} + +TEST_CASE("Test C CTAS OrderBy with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "CREATE TABLE ordered_employees AS " + "SELECT * FROM employees " + "ORDER BY salary DESC" + ); + + auto result = ExecuteViaSubstrait(con, "SELECT * from ordered_employees"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 4, 2, 5, 3})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Bob Brown", "Jane Smith", "Charlie Black", "Alice Johnson"})); + REQUIRE(CHECK_COLUMN(result, 2, {1, 3, 2, 2, 1})); + REQUIRE(CHECK_COLUMN(result, 3, {120000, 95000, 80000, 60000, 50000})); +} + +TEST_CASE("Test C CTAS SubQuery with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "CREATE TABLE high_salary_employees AS " + "SELECT * " + "FROM ( " + "SELECT employee_id, name, salary " + "FROM employees " + "WHERE salary > 100000)" + ); + + auto result = ExecuteViaSubstrait(con, "SELECT * from high_salary_employees"); + REQUIRE(CHECK_COLUMN(result, 0, {1})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe"})); + REQUIRE(CHECK_COLUMN(result, 2, {120000})); +} + +TEST_CASE("Test C CTAS Distinct with Substrait API", "[substrait-api]") { + SKIP_TEST("SKIP: Distinct operator has unsupported child type"); // TODO fix TransformDistinct + return; + + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + ExecuteViaSubstraitJSON(con, "CREATE TABLE unique_departments AS " + "SELECT DISTINCT department_id FROM employees" + ); + + auto result = ExecuteViaSubstrait(con, "SELECT * from unique_departments"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3})); +} + +TEST_CASE("Test C CTAS Aggregation with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "CREATE TABLE department_summary AS " + "SELECT department_id, COUNT(*) AS employee_count " + "FROM employees " + "GROUP BY department_id" + ); + + auto result = ExecuteViaSubstrait(con, "SELECT * from department_summary"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3})); + REQUIRE(CHECK_COLUMN(result, 1, {2, 2, 1})); +} + +TEST_CASE("Test C CTAS Join with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + CreateDepartmentsTable(con); + + ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_departments AS " + "SELECT e.employee_id, e.name, d.department_name " + "FROM employees e " + "JOIN departments d " + "ON e.department_id = d.department_id" + ); + + auto result = ExecuteViaSubstrait(con, "SELECT * from employee_departments"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"})); + REQUIRE(CHECK_COLUMN(result, 2, {"HR", "Engineering", "HR", "Finance", "Engineering"})); +} + +TEST_CASE("Test C CTAS Union with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + CreatePartTimeEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "CREATE TABLE all_employees AS " + "SELECT employee_id, name, department_id, salary " + "FROM employees " + "UNION " + "SELECT id, name, department_id, hourly_rate * 2000 AS salary " + "FROM part_time_employees " + "ORDER BY employee_id" + ); + + auto result = ExecuteViaSubstrait(con, "SELECT * from all_employees"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5, 6, 7})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black", "David White", "Eve Green"})); + REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2, 1, 2})); +} \ No newline at end of file diff --git a/test/python/test_substrait.py b/test/python/test_substrait.py index b9ae456..38590fa 100644 --- a/test/python/test_substrait.py +++ b/test/python/test_substrait.py @@ -1,6 +1,7 @@ import pandas as pd import duckdb + def test_roundtrip_substrait(require): connection = require('substrait') @@ -15,3 +16,198 @@ def test_roundtrip_substrait(require): pd.testing.assert_series_equal(query_result.df()["i"], expected) + +def test_ctas_with_select_columns(require): + connection = require('substrait') + create_employee_table(connection) + _ = execute_via_substrait(connection, """CREATE TABLE employee_salaries AS + SELECT name, salary FROM employees""") + + expected = pd.DataFrame({"name": ["John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"], + "salary": pd.Series([120000, 80000, 50000, 95000, 60000], dtype="float64")}) + + query_result = execute_via_substrait(connection, "SELECT * FROM employee_salaries") + pd.testing.assert_frame_equal(query_result.df(), expected) + + +def test_ctas_with_filter(require): + connection = require('substrait') + create_employee_table(connection) + _ = execute_via_substrait(connection, """CREATE TABLE high_earners AS + SELECT * FROM employees WHERE salary > 80000""") + + expected = pd.DataFrame({"employee_id": pd.Series([1, 4], dtype="int32"), + "name": ["John Doe", "Bob Brown"], + "department_id": pd.Series([1, 3], dtype="int32"), + "salary": pd.Series([120000, 95000], dtype="float64")}) + + query_result = execute_via_substrait(connection, "SELECT * FROM high_earners") + pd.testing.assert_frame_equal(query_result.df(), expected) + + +def test_ctas_with_case_and_when(require): + connection = require('substrait') + create_employee_table(connection) + _ = execute_via_substrait(connection, """CREATE TABLE categorized_employees AS + SELECT name, + CASE + WHEN salary > 100000 THEN 'HIGH' + WHEN salary BETWEEN 60000 AND 100000 THEN 'Medium' + ELSE 'Low' + END AS salary_category + FROM employees""") + + expected = pd.DataFrame({"name": ["John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"], + "salary_category": ["HIGH", "Medium", "Low", "Medium", "Medium"]}) + + query_result = execute_via_substrait(connection, "SELECT * FROM categorized_employees") + pd.testing.assert_frame_equal(query_result.df(), expected) + + +def test_ctas_with_order_by(require): + connection = require('substrait') + create_employee_table(connection) + _ = execute_via_substrait(connection, """CREATE TABLE sorted_employees AS + SELECT * FROM employees ORDER BY salary DESC""") + + expected = pd.DataFrame({"employee_id": pd.Series([1, 4, 2, 5, 3], dtype="int32"), + "name": ["John Doe", "Bob Brown", "Jane Smith", "Charlie Black", "Alice Johnson"], + "department_id": pd.Series([1, 3, 2, 2, 1], dtype="int32"), + "salary": pd.Series([120000, 95000, 80000, 60000, 50000], dtype="float64")}) + + query_result = execute_via_substrait(connection, "SELECT * FROM sorted_employees") + pd.testing.assert_frame_equal(query_result.df(), expected) + + +def test_ctas_with_subquery(require): + connection = require('substrait') + create_employee_table(connection) + _ = execute_via_substrait(connection, """CREATE TABLE high_salary_employees AS + SELECT * FROM ( + SELECT employee_id, name, salary + FROM employees + WHERE salary > 100000) + """) + + expected = pd.DataFrame({"employee_id": pd.Series([1], dtype="int32"), + "name": ["John Doe"], "salary": pd.Series([120000], dtype="float64")}) + query_result = execute_via_substrait(connection, "SELECT * FROM high_salary_employees") + pd.testing.assert_frame_equal(query_result.df(), expected) + + +def test_ctas_with_aggregation(require): + connection = require('substrait') + create_employee_table(connection) + _ = execute_via_substrait(connection, """CREATE TABLE department_summary AS + SELECT department_id, COUNT(*) AS employee_count + FROM employees + GROUP BY department_id + """) + + expected = pd.DataFrame({"department_id": pd.Series([1, 2, 3], dtype="int32"), + "employee_count": [2, 2, 1]}) + query_result = execute_via_substrait(connection, "SELECT * FROM department_summary") + pd.testing.assert_frame_equal(query_result.df(), expected) + + +def test_ctas_with_join(require): + connection = require('substrait') + create_employee_table(connection) + create_departments_table(connection) + _ = execute_via_substrait(connection, """CREATE TABLE employee_department AS + SELECT e.employee_id, e.name, d.name AS department_name + FROM employees e + JOIN departments d + ON e.department_id = d.department_id + """) + + expected = pd.DataFrame({"employee_id": pd.Series([1, 2, 3, 4, 5], dtype="int32"), + "name": ["John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"], + "department_name": ["HR", "Engineering", "HR", "Finance", "Engineering"]}) + + query_result = execute_via_substrait(connection, "SELECT * FROM employee_department") + pd.testing.assert_frame_equal(query_result.df(), expected) + + +def test_ctas_with_union(require): + connection = require('substrait') + create_employee_table(connection) + create_part_time_employee_table(connection) + _ = execute_via_substrait(connection, """CREATE TABLE all_employees AS + SELECT employee_id, name, department_id, salary + FROM employees + UNION ALL + SELECT id as employee_id, name, department_id, hourly_rate * 2000 AS salary + FROM part_time_employees + ORDER BY employee_id + """) + + expected = pd.DataFrame({"employee_id": pd.Series([1, 2, 3, 4, 5, 6, 7], dtype="int32"), + "name": ["John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black", + "David White", "Eve Green"], + "department_id": pd.Series([1, 2, 1, 3, 2, 1, 2], dtype="int32"), + "salary": pd.Series([120000, 80000, 50000, 95000, 60000, 30000, 40000], dtype="float64")}) + + query_result = execute_via_substrait(connection, "SELECT * FROM all_employees") + pd.testing.assert_frame_equal(query_result.df(), expected) + + +def execute_via_substrait(connection, query): + res = connection.get_substrait(query) + proto_bytes = res.fetchone()[0] + con = connection.from_substrait(proto_bytes) + con.fetchall() # this is needed to force the execution + return con + + +def create_employee_table(connection): + connection.execute(""" + CREATE TABLE employees ( + employee_id INTEGER PRIMARY KEY, + name VARCHAR(100), + department_id INTEGER, + salary DECIMAL(10, 2) + ) + """) + + connection.execute(""" + INSERT INTO employees VALUES + (1, 'John Doe', 1, 120000), + (2, 'Jane Smith', 2, 80000), + (3, 'Alice Johnson', 1, 50000), + (4, 'Bob Brown', 3, 95000), + (5, 'Charlie Black', 2, 60000) + """) + + +def create_part_time_employee_table(connection): + connection.execute(""" + CREATE TABLE part_time_employees ( + id INTEGER PRIMARY KEY, + name VARCHAR(100), + department_id INTEGER, + hourly_rate DECIMAL(10, 2) + ) + """) + + connection.execute(""" + INSERT INTO part_time_employees VALUES + (6, 'David White', 1, 15), + (7, 'Eve Green', 2, 20) + """) + + +def create_departments_table(connection): + connection.execute(""" + CREATE TABLE departments ( + department_id INTEGER PRIMARY KEY, + name VARCHAR(100) + ) + """) + + connection.execute(""" + INSERT INTO departments VALUES + (1, 'HR'), + (2, 'Engineering'), + (3, 'Finance'), + """)