From 5707b5339379405dfc2813d1b236fc2ab36aa5d1 Mon Sep 17 00:00:00 2001 From: Aldrin Montana Date: Wed, 13 Nov 2024 17:10:12 -0800 Subject: [PATCH] fix(bad-rebase): fixed typos and added test I had to re-add FromSubstraitFunctionData, so in the process I forgot to return unique_ptr instead. Also re-added missing closing brace. This also adds a test, but I haven't been able to actually test it. Additionally, it seems that calling explain is truncating the result, so I need to figure out how to prevent that from happening in order to create a proper expected output (and to make the function itself useful). --- src/substrait_extension.cpp | 5 +++-- test/python/test_substrait_explain.py | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 test/python/test_substrait_explain.py diff --git a/src/substrait_extension.cpp b/src/substrait_extension.cpp index 7361623..c6bbb7f 100644 --- a/src/substrait_extension.cpp +++ b/src/substrait_extension.cpp @@ -275,6 +275,7 @@ static unique_ptr FromSubstraitBind(ClientContext &context, TableFunct static unique_ptr FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input) { return SubstraitBind(context, input, true); +} //! Container for TableFnExplainSubstrait to get data from BindFnExplainSubstrait struct FromSubstraitFunctionData : public TableFunctionData { @@ -284,8 +285,8 @@ struct FromSubstraitFunctionData : public TableFunctionData { unique_ptr conn; }; -static unique_ptr BindFnExplainSubstrait(ClientContext &context, TableFunctionBindInput &input - vector &return_types, vector &names) { +static unique_ptr BindFnExplainSubstrait(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { if (input.inputs[0].IsNull()) { throw BinderException("explain_substrait cannot be called with a NULL parameter"); } diff --git a/test/python/test_substrait_explain.py b/test/python/test_substrait_explain.py new file mode 100644 index 0000000..b4ef39d --- /dev/null +++ b/test/python/test_substrait_explain.py @@ -0,0 +1,26 @@ +import pandas as pd +import duckdb + +EXPECTED_RESULT = ''' +┌───────────────┬──────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ explain_key │ explain_value │ +│ varchar │ varchar │ +├───────────────┼──────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ physical_plan │ ┌───────────────────────────┐\n│ STREAMING_LIMIT │\n└─────────────┬─────────────┘\n┌────… │ +└───────────────┴──────────────────────────────────────────────────────────────────────────────────────────────────────┘ + +''' + +def test_roundtrip_substrait(require): + connection = require('substrait') + connection.execute('CREATE TABLE integers (i integer)') + connection.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') + + translate_result = connection.get_substrait('SELECT * FROM integers LIMIT 5') + proto_bytes = translate_result.fetchone()[0] + + expected = pd.Series([EXPECTED_RESULT], name='Explain Plan', dtype='str') + actual = connection.table_function('explain_substrait', proto_bytes).execute() + + pd.testing.assert_series_equal(actual.df()['Explain Plan'], expected) +