Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add explain TableFunction for substrait #114

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 67 additions & 9 deletions src/substrait_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

namespace duckdb {

//! No-op deleter: receives a ClientContext pointer and deliberately leaves it untouched.
//! It is passed as the custom deleter when a shared_ptr is built from `&context`, i.e. from
//! a ClientContext that is only available by reference.
void do_nothing(ClientContext *) {
}
//! No-op deleter for building a shared_ptr from a ClientContext that is only held by reference.
//! The shared_ptr is constructed from `&context`, so this deleter intentionally does nothing
//! with the pointer it receives.
void deleter_noop(ClientContext *) {}

struct ToSubstraitFunctionData : public TableFunctionData {
ToSubstraitFunctionData() = default;
Expand Down Expand Up @@ -264,7 +264,7 @@ static unique_ptr<TableRef> SubstraitBind(ClientContext &context, TableFunctionB
throw BinderException("from_substrait cannot be called with a NULL parameter");
}
string serialized = input.inputs[0].GetValueUnsafe<string>();
shared_ptr<ClientContext> c_ptr(&context, do_nothing);
shared_ptr<ClientContext> c_ptr(&context, deleter_noop);
auto plan = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json);
return plan->GetTableRef();
}
Expand All @@ -277,6 +277,46 @@ static unique_ptr<TableRef> FromSubstraitBindJSON(ClientContext &context, TableF
return SubstraitBind(context, input, true);
}

//! Bind data for the explain_substrait table function: filled in by BindFnExplainSubstrait
//! and read (and lazily extended) by TableFnExplainSubstrait.
struct FromSubstraitFunctionData : public TableFunctionData {
FromSubstraitFunctionData() = default;
//! Relation obtained from SubstraitPlanToDuckDBRel during bind
shared_ptr<Relation> plan;
//! Result of `plan->Explain()`; remains null until the table function runs for the first time
unique_ptr<QueryResult> res;
//! Connection opened on the database during bind
//! NOTE(review): nothing in the visible code reads `conn` after it is created - confirm it is still needed
unique_ptr<Connection> conn;
};

//! Bind callback for explain_substrait.
//! Reads the serialized (non-JSON) substrait plan from the first argument, translates it into
//! a DuckDB relation that is stored in the returned bind data, and declares the output schema:
//! a single VARCHAR column named "Explain Plan".
//! Throws a BinderException when the argument is NULL.
static unique_ptr<FunctionData> BindFnExplainSubstrait(ClientContext &context, TableFunctionBindInput &input,
                                                       vector<LogicalType> &return_types, vector<string> &names) {
	auto &plan_arg = input.inputs[0];
	if (plan_arg.IsNull()) {
		throw BinderException("explain_substrait cannot be called with a NULL parameter");
	}

	// Inputs for `SubstraitPlanToDuckDBRel`: the raw plan bytes and a shared_ptr view of the
	// borrowed context (the no-op deleter keeps the shared_ptr from deleting it)
	string plan_bytes = plan_arg.GetValueUnsafe<string>();
	shared_ptr<ClientContext> borrowed_context(&context, deleter_noop);

	auto bind_data = make_uniq<FromSubstraitFunctionData>();
	bind_data->conn = make_uniq<Connection>(*context.db);
	bind_data->plan = SubstraitPlanToDuckDBRel(borrowed_context, plan_bytes, /*is_json=*/false);

	// Output schema: one string column
	return_types.emplace_back(LogicalType::VARCHAR);
	names.emplace_back("Explain Plan");

	return std::move(bind_data);
}

//! Execution callback for explain_substrait.
//! On the first call it runs Explain() on the relation stored in the bind data and caches the
//! query result there; every call then fetches one chunk from that result and moves it into
//! `output`. When Fetch() yields no chunk, `output` is left untouched.
//! NOTE(review): bind declares a single VARCHAR column - confirm the chunks produced by
//! Explain() have that same layout before they are moved into `output`.
static void TableFnExplainSubstrait(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
	auto &bind_state = data_p.bind_data->CastNoConst<FromSubstraitFunctionData>();

	// Lazily produce the explain result the first time we are called
	if (!bind_state.res) {
		bind_state.res = bind_state.plan->Explain();
	}

	if (auto next_chunk = bind_state.res->Fetch()) {
		output.Move(*next_chunk);
	}
}

void InitializeGetSubstrait(const Connection &con) {
auto &catalog = Catalog::GetSystemCatalog(*con.context);
// create the get_substrait table function that allows us to get a substrait
Expand All @@ -299,22 +339,39 @@ void InitializeGetSubstraitJSON(const Connection &con) {
catalog.CreateTableFunction(*con.context, get_substrait_json_info);
}

//! Define and register a TableFunction ("from_substrait") that returns a TableRef.
//! The function takes a single BLOB argument (the serialized substrait plan) and is
//! registered in the system catalog reached through `con`.
void InitializeFromSubstrait(const Connection &con) {
	// `FromSubstraitBind` translates a substrait plan and returns a `TableRef`;
	// to return a `TableRef` we use `bind_replace` instead of `bind`, so no
	// execution function is supplied (nullptr).
	TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, nullptr);
	from_sub_func.bind_replace = FromSubstraitBind;

	// register the TableFunction in the system catalog
	// (`catalog` and `from_sub_func` are each declared exactly once in this scope)
	auto &catalog = Catalog::GetSystemCatalog(*con.context);
	CreateTableFunctionInfo from_sub_info(from_sub_func);
	catalog.CreateTableFunction(*con.context, from_sub_info);
}

//! Define and register the "explain_substrait" TableFunction in the system catalog.
//! It takes a single BLOB argument and is wired to BindFnExplainSubstrait (bind) and
//! TableFnExplainSubstrait (execution).
void InitializeExplainSubstrait(const Connection &con) {
	// bind:     translates the plan and sets the return schema to a single string column
	// function: runs Explain() on the translated plan and emits the resulting chunks
	TableFunction explain_sub_func("explain_substrait", {LogicalType::BLOB}, TableFnExplainSubstrait,
	                               BindFnExplainSubstrait);

	// register the TableFunction in the system catalog
	auto &system_catalog = Catalog::GetSystemCatalog(*con.context);
	CreateTableFunctionInfo explain_sub_info(explain_sub_func);
	system_catalog.CreateTableFunction(*con.context, explain_sub_info);
}

void InitializeFromSubstraitJSON(const Connection &con) {
auto &catalog = Catalog::GetSystemCatalog(*con.context);
// create the from_substrait table function that allows us to get a query
// result from a substrait plan
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, nullptr, nullptr);
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, nullptr);
from_sub_func_json.bind_replace = FromSubstraitBindJSON;
CreateTableFunctionInfo from_sub_info_json(from_sub_func_json);
catalog.CreateTableFunction(*con.context, from_sub_info_json);
Expand All @@ -329,6 +386,7 @@ void SubstraitExtension::Load(DuckDB &db) {

InitializeFromSubstrait(con);
InitializeFromSubstraitJSON(con);
InitializeExplainSubstrait(con);

con.Commit();
}
Expand Down
26 changes: 26 additions & 0 deletions test/python/test_substrait_explain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pandas as pd
import duckdb

EXPECTED_RESULT = '''
┌───────────────┬──────────────────────────────────────────────────────────────────────────────────────────────────────┐
│ explain_key │ explain_value │
│ varchar │ varchar │
├───────────────┼──────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ physical_plan │ ┌───────────────────────────┐\n│ STREAMING_LIMIT │\n└─────────────┬─────────────┘\n┌────… │
└───────────────┴──────────────────────────────────────────────────────────────────────────────────────────────────────┘

'''

def test_roundtrip_substrait(require):
    """Serialize a LIMIT query to substrait and check what explain_substrait returns.

    `require` is the fixture used to obtain a connection with the 'substrait'
    extension loaded. The 'Explain Plan' column produced by the
    `explain_substrait` table function is compared against EXPECTED_RESULT.
    """
    con = require('substrait')
    con.execute('CREATE TABLE integers (i integer)')
    con.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)')

    # First column of the first row holds the serialized plan
    plan_blob = con.get_substrait('SELECT * FROM integers LIMIT 5').fetchone()[0]

    expected_column = pd.Series([EXPECTED_RESULT], name='Explain Plan', dtype='str')
    explained = con.table_function('explain_substrait', plan_blob).execute()
    explain_column = explained.df()['Explain Plan']

    pd.testing.assert_series_equal(explain_column, expected_column)

Loading