Skip to content

Commit

Permalink
fix(bad-rebase): fixed typos and added test
Browse files Browse the repository at this point in the history
I had to re-add FromSubstraitFunctionData, so in the process I forgot to
return unique_ptr<FunctionData> instead.

Also re-added missing closing brace.

This also adds a test, but I haven't been able to actually test it.
Additionally, it seems that calling explain is truncating the result, so
I need to figure out how to prevent that from happening in order to
create a proper expected output (and to make the function itself
useful).
  • Loading branch information
drin committed Nov 14, 2024
1 parent 99752e4 commit 5707b53
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/substrait_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ static unique_ptr<TableRef> FromSubstraitBind(ClientContext &context, TableFunct

static unique_ptr<TableRef> FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input) {
return SubstraitBind(context, input, true);
}

//! Container for TableFnExplainSubstrait to get data from BindFnExplainSubstrait
struct FromSubstraitFunctionData : public TableFunctionData {
Expand All @@ -284,8 +285,8 @@ struct FromSubstraitFunctionData : public TableFunctionData {
unique_ptr<Connection> conn;
};

static unique_ptr<TableRef> BindFnExplainSubstrait(ClientContext &context, TableFunctionBindInput &input
vector<LogicalType> &return_types, vector<string> &names) {
static unique_ptr<FunctionData> BindFnExplainSubstrait(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {
if (input.inputs[0].IsNull()) {
throw BinderException("explain_substrait cannot be called with a NULL parameter");
}
Expand Down
26 changes: 26 additions & 0 deletions test/python/test_substrait_explain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pandas as pd
import duckdb

EXPECTED_RESULT = '''
┌───────────────┬──────────────────────────────────────────────────────────────────────────────────────────────────────┐
│ explain_key │ explain_value │
│ varchar │ varchar │
├───────────────┼──────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ physical_plan │ ┌───────────────────────────┐\n│ STREAMING_LIMIT │\n└─────────────┬─────────────┘\n┌────… │
└───────────────┴──────────────────────────────────────────────────────────────────────────────────────────────────────┘
'''

def test_roundtrip_substrait(require):
connection = require('substrait')
connection.execute('CREATE TABLE integers (i integer)')
connection.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)')

translate_result = connection.get_substrait('SELECT * FROM integers LIMIT 5')
proto_bytes = translate_result.fetchone()[0]

expected = pd.Series([EXPECTED_RESULT], name='Explain Plan', dtype='str')
actual = connection.table_function('explain_substrait', proto_bytes).execute()

pd.testing.assert_series_equal(actual.df()['Explain Plan'], expected)

0 comments on commit 5707b53

Please sign in to comment.