Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add validate function for a substrait::Plan message #281

Merged
merged 4 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,23 @@ if(SUBSTRAIT_VALIDATOR_BUILD_TESTS)
substrait-validator-c-test
${CMAKE_CURRENT_SOURCE_DIR}/tests/test.cc
)

# For OSX, link in CoreFoundation, needed for some system symbols like
# _CFRelease, _CFStringGetBytes, …
Comment on lines +52 to +53
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was outside the intent of this PR, but after running into this error and attempting to reproduce locally, I found this option helped. I know very little about CMake, so any recommendations here welcome 🙇

set(EXTRA_FRAMEWORK_OPTIONS "")
if(APPLE)
set(EXTRA_FRAMEWORK_OPTIONS "-framework CoreFoundation")
endif()


target_link_libraries(
substrait-validator-c-test
gtest_main
substrait_validator_c
"${EXTRA_FRAMEWORK_OPTIONS}"
)


include(GoogleTest)
gtest_discover_tests(substrait-validator-c-test)

Expand Down
17 changes: 11 additions & 6 deletions c/tests/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,17 @@ TEST(BasicTest, BasicTest) {
EXPECT_EQ(strlen(reinterpret_cast<const char *>(data_ptr)), data_size);
EXPECT_EQ(
reinterpret_cast<const char *>(data_ptr),
std::string("Error at plan: failed to parse as substrait.Plan: "
"failed to decode Protobuf message: "
"invalid wire type value: 7 (code 1001) (code 1001)\n"
"Error at plan: failed to parse as substrait.PlanVersion: "
"failed to decode Protobuf message: "
"invalid wire type value: 7 (code 1001) (code 1001)\n"));
std::string(
"Info at plan: this version of the validator is EXPERIMENTAL. "
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that before the EXPERIMENTAL info was not printed when the substrait plan failed to decode into either a Plan or PlanVersion. Given that it is otherwise always printed, it seemed more consistent to have it always be printed (even under full decode failures) and update this test, rathen than to make the code match this test.

"Please report issues via "
"https://github.com/substrait-io/substrait-validator/issues/new "
"(code 0999)\n"
"Error at plan: failed to parse as substrait.Plan: "
"failed to decode Protobuf message: "
"invalid wire type value: 7 (code 1001) (code 1001)\n"
"Error at plan: failed to parse as substrait.PlanVersion: "
"failed to decode Protobuf message: "
"invalid wire type value: 7 (code 1001) (code 1001)\n"));

// Free the buffer.
substrait_validator_free_exported(data_ptr);
Expand Down
25 changes: 16 additions & 9 deletions rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
//! 1) Build a [`Config`] structure to configure the validator. You can also
//! just use [`std::default::Default`] if you don't need to configure
//! anything, but you might want to at least call
//! `Config::add_curl_uri_resolver()` (if you're using the `curl`
//! feature).
//! 2) Parse the incoming `substrait.Plan` message using [`parse()`]. This
//! creates a [ParseResult], containing a [tree](output::tree) structure
//! corresponding to the query plan that also contains diagnostics and
//! other annotations added by the validator.
//! 3) You can traverse the tree yourself using [ParseResult::root], or you
//! can use one of the methods associated with [ParseResult] to obtain the
//! `Config::add_curl_uri_resolver()` (if you're using the `curl` feature).
//! 2) Parse the incoming `substrait.Plan` message using [`parse()`] or
//! [`validate()`]. This creates a [ParseResult], containing a
//! [tree](output::tree) structure corresponding to the query plan that also
//! contains diagnostics and other annotations added by the validator.
//! 3) You can traverse the tree yourself using [ParseResult::root], or you can
//! use one of the methods associated with [ParseResult] to obtain the
//! validation results you need.
//!
//! Note that only the binary protobuf serialization format is supported at the
Expand Down Expand Up @@ -168,6 +167,7 @@ mod util;

use std::str::FromStr;

use input::proto::substrait::Plan;
use strum::IntoEnumIterator;

// Aliases for common types used on the crate interface.
Expand All @@ -180,11 +180,18 @@ pub use output::diagnostic::Level;
pub use output::parse_result::ParseResult;
pub use output::parse_result::Validity;

/// Validates the given substrait.Plan message and returns the parse tree.
/// Parses and validates the given substrait [Plan] message and returns the
/// parse tree and diagnostic results.
pub fn parse<B: prost::bytes::Buf + Clone>(buffer: B, config: &Config) -> ParseResult {
parse::parse(buffer, config)
}

/// Validates the given substrait [Plan] message and returns the parse tree and
/// diagnostic results.
pub fn validate(plan: &Plan, config: &Config) -> ParseResult {
parse::validate(plan, config)
}

/// Returns an iterator that yields all known diagnostic classes.
pub fn iter_diagnostics() -> impl Iterator<Item = Classification> {
Classification::iter()
Expand Down
104 changes: 81 additions & 23 deletions rs/src/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,44 +197,86 @@ use crate::output::diagnostic;
use crate::output::parse_result;
use crate::output::path;

/// Validates the given substrait.Plan message and returns the parse tree.
use prost::Message;

/// Represents the state of attempting to decode a buffer as a plan, or failing
/// that, as a plan version.
enum ParsedProtoResult {
/// Successful parsing
Parsed(proto::substrait::Plan),

/// Failed to decode into `substrait::Plan`, possibly due to a version
/// conflict; see `Cause`. Succeeded in decoding the version.
VersionParsed(proto::substrait::PlanVersion, diagnostic::Cause),

/// Failed to decode into `substrait::Plan` (first cause) and
/// `substrait::PlanVersion` (second cause).
Failed(diagnostic::Cause, diagnostic::Cause),
}

/// Parse the given buffer as a Plan if possible, PlanVersion if not, returning
/// errors as they are generated.
fn parse_proto<B: prost::bytes::Buf + Clone>(buffer: B) -> ParsedProtoResult {
// Attempt to parse the buffer as a Plan.
let err1 = match proto::substrait::Plan::decode(buffer.clone()) {
Ok(plan) => return ParsedProtoResult::Parsed(plan),
Err(e) => ecause!(ProtoParseFailed, e),
};

// Attempt to parse the buffer as a PlanVersion, as fallback.
match proto::substrait::PlanVersion::decode(buffer) {
Ok(version) => ParsedProtoResult::VersionParsed(version, err1),
Err(e2) => ParsedProtoResult::Failed(err1, ecause!(ProtoParseFailed, e2)),
}
}

/// Parses the given [`proto::substrait::Plan`] message, validates it, and
/// returns the parse tree with diagnostic results.
pub fn parse<B: prost::bytes::Buf + Clone>(
buffer: B,
config: &config::Config,
) -> parse_result::ParseResult {
let mut state = context::State::default();

// Parse the normal way.
let err1 = match traversal::parse_proto::<proto::substrait::Plan, _, _>(
buffer.clone(),
"plan",
plan::parse_plan,
&mut state,
config,
) {
Ok(parse_result) => return parse_result,
Err(err) => err,
};
let (err1, err2) = match parse_proto(buffer) {
ParsedProtoResult::Parsed(ref plan) => {
return traversal::validate::<proto::substrait::Plan, _>(
plan,
"plan",
plan::parse_plan,
&mut state,
config,
);
}
ParsedProtoResult::VersionParsed(ref version, err1) => {
return traversal::validate::<proto::substrait::PlanVersion, _>(
version,
"plan",
|tree, ctx| {
// We have a PlanVersion, but the Plan itself failed to
// decode - so include that error.
diagnostic!(ctx, Error, err1);

// Parse the fallback PlanVersion message that only includes the version
// information.
let err2 = match traversal::parse_proto::<proto::substrait::PlanVersion, _, _>(
buffer,
"plan",
|x, y| plan::parse_plan_version(x, y, err1.clone()),
&mut state,
config,
) {
Ok(parse_result) => return parse_result,
Err(err) => err,
plan::parse_plan_version(tree, ctx)
},
&mut state,
config,
);
}
ParsedProtoResult::Failed(err1, err2) => (err1, err2),
};

// --------------------------------------------------------------------------------
// The parser failed to decode the buffer as either a plan or plan version, so now
// we create a minimal parse tree with the error diagnostics from the parser.

// Create a minimal root node with just the decode error
// diagnostic.
let mut root = proto::substrait::Plan::type_to_node();

// Create a root context for it.
let mut context = context::Context::new("plan", &mut root, &mut state, config);
plan::mark_experimental(&mut context);

// Push the earlier diagnostic.
context.push_diagnostic(diagnostic::RawDiagnostic {
Expand Down Expand Up @@ -264,3 +306,19 @@ pub fn parse<B: prost::bytes::Buf + Clone>(

parse_result::ParseResult { root }
}

/// Validate the given [`proto::substrait::Plan`] message, returning the parse
/// tree with diagnostic results.
pub fn validate(
plan: &proto::substrait::Plan,
config: &config::Config,
) -> parse_result::ParseResult {
let mut state = context::State::default();
traversal::validate::<proto::substrait::Plan, _>(
plan,
"plan",
plan::parse_plan,
&mut state,
config,
)
}
39 changes: 14 additions & 25 deletions rs/src/parse/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ fn parse_version(x: &substrait::Version, y: &mut context::Context) -> diagnostic
}

/// Report the "validator is experimental" diagnostic.
fn mark_experimental(y: &mut context::Context) {
pub fn mark_experimental(ctx: &mut context::Context) {
diagnostic!(
y,
ctx,
Info,
Experimental,
"this version of the validator is EXPERIMENTAL. Please report issues \
Expand All @@ -168,23 +168,23 @@ fn mark_experimental(y: &mut context::Context) {
}

/// Toplevel parse function for a plan.
pub fn parse_plan(x: &substrait::Plan, y: &mut context::Context) -> diagnostic::Result<()> {
mark_experimental(y);
pub fn parse_plan(plan: &substrait::Plan, ctx: &mut context::Context) {
mark_experimental(ctx);

// Parse the version.
proto_required_field!(x, y, version, parse_version);
proto_required_field!(plan, ctx, version, parse_version);

// Handle extensions first, because we'll need their declarations to
// correctly interpret the relations.
extensions::parse_plan(x, y);
extensions::parse_plan(plan, ctx);

// Handle the relations.
let num_relations = proto_repeated_field!(x, y, relations, parse_plan_rel)
let num_relations = proto_repeated_field!(plan, ctx, relations, parse_plan_rel)
.0
.len();
if num_relations == 0 {
diagnostic!(
y,
ctx,
Error,
RelationRootMissing,
"a plan must have at least one relation"
Expand All @@ -193,25 +193,14 @@ pub fn parse_plan(x: &substrait::Plan, y: &mut context::Context) -> diagnostic::

// Generate an Info diagnostic for every extension definition that wasn't
// used at any point, and can thus be safely removed.
extensions::check_unused_definitions(y);

Ok(())
extensions::check_unused_definitions(ctx);
}

/// Toplevel parse function for a plan.
pub fn parse_plan_version(
x: &substrait::PlanVersion,
y: &mut context::Context,
e: diagnostic::Cause,
) -> diagnostic::Result<()> {
mark_experimental(y);

// Push the diagnostic that the caller got while parsing as a complete Plan
// before.
diagnostic!(y, Error, e);
/// Toplevel validation function for a plan. Validates that the PlanVersion
/// matches expected format, pushing errors to the `Context`.
pub fn parse_plan_version(tree: &substrait::PlanVersion, ctx: &mut context::Context) {
mark_experimental(ctx);

// Parse the version.
proto_required_field!(x, y, version, parse_version);

Ok(())
proto_required_field!(tree, ctx, version, parse_version);
}
30 changes: 22 additions & 8 deletions rs/src/parse/traversal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,31 +587,45 @@ pub fn parse_proto<T, F, B>(
) -> diagnostic::Result<parse_result::ParseResult>
where
T: prost::Message + InputNode + Default,
F: FnOnce(&T, &mut context::Context) -> diagnostic::Result<()>,
F: FnOnce(&T, &mut context::Context),
B: prost::bytes::Buf,
{
// Run protobuf deserialization.
let input = T::decode(buffer).map_err(|e| ecause!(ProtoParseFailed, e))?;

Ok(validate(&input, root_name, root_parser, state, config))
}

/// Validates a serialized protobuf message using the given `root_parser`
/// function, initial state, and configuration, pushing any errors from the
/// root_validator or unhandled children into the returned
/// [`parse_result::ParseResult`].
pub fn validate<T, F>(
input: &T,
root_name: &'static str,
root_validator: F,
state: &mut context::State,
config: &config::Config,
) -> parse_result::ParseResult
where
T: prost::Message + InputNode + Default,
F: FnOnce(&T, &mut context::Context),
{
// Create the root node.
let mut root = input.data_to_node();

// Create the root context.
let mut context = context::Context::new(root_name, &mut root, state, config);

// Call the provided parser function.
let success = root_parser(&input, &mut context)
.map_err(|cause| {
diagnostic!(&mut context, Error, cause);
})
.is_ok();
root_validator(input, &mut context);

// Handle any fields not handled by the provided parse function.
// Only generate a warning diagnostic for unhandled children if the
// parse function succeeded.
handle_unknown_children(&input, &mut context, success);
handle_unknown_children(input, &mut context, true);

Ok(parse_result::ParseResult { root })
parse_result::ParseResult { root }
}

//=============================================================================
Expand Down
Loading