diff --git a/README.md b/README.md index 91a376d..92351f1 100644 --- a/README.md +++ b/README.md @@ -78,13 +78,18 @@ protobuf extensions are involved, a well-known pitfall is that extensions are silently moved to the `proto2::UnknownFieldSet` when a message is deserialized in C++, but the `cc_proto_library` for the extensions is not linked in. The root -cause is an asymmetry in the handling of Python protos vs C++ protos: when -a Python proto is deserialized, both the Python descriptor pool and the C++ -descriptor pool are inspected, but when a C++ proto is deserialized, only +cause is an asymmetry in the handling of Python protos vs C++ +protos: +when a Python proto is deserialized, both the Python descriptor pool and the +C++ descriptor pool are inspected, but when a C++ proto is deserialized, only the C++ descriptor pool is inspected. Until this asymmetry is resolved, the `cc_proto_library` for all extensions involved must be added to the `deps` of -the relevant `pybind_library` or `pybind_extension`, but this is sufficiently -unobvious to be a setup for regular accidents, potentially with critical +the relevant `pybind_library` or `pybind_extension`, or if this is impractial, +`pybind11_protobuf::runtime_config::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse` +or `pybind11_protobuf::AllowUnknownFieldsFor` can be used. + +The pitfall is sufficiently unobvious to be a setup for regular accidents, +potentially with critical consequences. To guard against the most common type of accident, native_proto_caster.h @@ -97,14 +102,20 @@ in certain situations: * and the `proto2::UnknownFieldSet` for the message or any of its submessages is not empty. -`pybind11_protobuf::AllowUnknownFieldsFor` is an escape hatch for situations in -which +`pybind11_protobuf::runtime_config::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse` +is a **global** escape hatch trading off convenience and runtime overhead: the +convenience is that it is not necessary to determine what `cc_proto_library` +dependencies need to be added, the runtime overhead is that +`SerializePartialToString`/`ParseFromString` is used for messages with unknown +fields, instead of the much faster `CopyFrom`. -* unknown fields existed before the safety mechanism was - introduced. -* unknown fields are needed in the future. +Another escape hatch is `pybind11_protobuf::AllowUnknownFieldsFor`, which +simply disables the safety mechanism for **specific message types**, without +a runtime overhead. This is useful for situations in which unknown fields +are acceptable. -An example of a full error message (with lines breaks here for readability): +An example of a full error message generated by the safety mechanism +(with lines breaks here for readability): ``` Proto Message of type pybind11.test.NestRepeated has an Unknown Field with @@ -117,14 +128,13 @@ Only if there is no alternative to suppressing this error, use (Warning: suppressions may mask critical bugs.) ``` -The current implementation is a compromise solution, trading off simplicity -of implementation, runtime performance, and precision. Generally, the runtime -overhead is expected to be very small, but fields flagged as unknown may not -necessarily be in extensions. -Alerting developers of new code to unknown fields is assumed to be generally -helpful, but the unknown fields detection is limited to messages with -extensions, to avoid the runtime overhead for the presumably much more common -case that no extensions are involved. +Note that the current implementation of the safety mechanism is a compromise +solution, trading off simplicity of implementation, runtime performance, +and precision. Alerting developers of new code to unknown fields is assumed +to be generally helpful, but the unknown fields detection is limited to +messages with extensions, to avoid the runtime overhead for the presumably +much more common case that no extensions are involved. Because of this, +the runtime overhead for the safety mechanism is expected to be very small. ### Enumerations diff --git a/pybind11_protobuf/BUILD b/pybind11_protobuf/BUILD index 528db9e..b48dd54 100644 --- a/pybind11_protobuf/BUILD +++ b/pybind11_protobuf/BUILD @@ -5,6 +5,26 @@ load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") licenses(["notice"]) +cc_library( + name = "runtime_config", + hdrs = ["runtime_config.h"], + visibility = [ + "//visibility:public", + ], +) + +cc_library( + name = "disallow_extensions_with_unknown_fields", + srcs = ["disallow_extensions_with_unknown_fields.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":runtime_config", + ], + alwayslink = 1, +) + pybind_library( name = "enum_type_caster", hdrs = ["enum_type_caster.h"], @@ -58,6 +78,7 @@ pybind_library( }), deps = [ ":check_unknown_fields", + ":runtime_config", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/pybind11_protobuf/check_unknown_fields.cc b/pybind11_protobuf/check_unknown_fields.cc index 5b5e2c6..f0d1fab 100644 --- a/pybind11_protobuf/check_unknown_fields.cc +++ b/pybind11_protobuf/check_unknown_fields.cc @@ -181,9 +181,9 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, unknown_field_parent_message_fqn)); } -std::optional CheckAndBuildErrorMessageIfAny( +std::optional RunCheck( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* message) { + const ::google::protobuf::Message* message, bool build_error_message_if_any) { const auto* root_descriptor = message->GetDescriptor(); HasUnknownFields search{py_proto_api, root_descriptor}; if (!search.FindUnknownFieldsRecursive(message, 0u)) { @@ -193,6 +193,9 @@ std::optional CheckAndBuildErrorMessageIfAny( search.FieldFQN())) != 0) { return std::nullopt; } + if (!build_error_message_if_any) { + return ""; // This indicates that an unknown field was found. + } return search.BuildErrorMessage(); } diff --git a/pybind11_protobuf/check_unknown_fields.h b/pybind11_protobuf/check_unknown_fields.h index 4fc9771..1a00f43 100644 --- a/pybind11_protobuf/check_unknown_fields.h +++ b/pybind11_protobuf/check_unknown_fields.h @@ -12,9 +12,9 @@ namespace pybind11_protobuf::check_unknown_fields { void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, absl::string_view unknown_field_parent_message_fqn); -std::optional CheckAndBuildErrorMessageIfAny( +std::optional RunCheck( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* top_message); + const ::google::protobuf::Message* top_message, bool build_error_message_if_any); } // namespace pybind11_protobuf::check_unknown_fields diff --git a/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc b/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc new file mode 100644 index 0000000..9e6cd43 --- /dev/null +++ b/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc @@ -0,0 +1,12 @@ +#include "pybind11_protobuf/runtime_config.h" + +namespace pybind11_protobuf::runtime_config { +namespace { + +static int kSetConfigDone = []() { + ExtensionsWithUnknownFieldsPolicy::StrongSetDisallow(); + return 0; +}(); + +} // namespace +} // namespace pybind11_protobuf::runtime_config diff --git a/pybind11_protobuf/proto_cast_util.cc b/pybind11_protobuf/proto_cast_util.cc index a634429..86b3ab1 100644 --- a/pybind11_protobuf/proto_cast_util.cc +++ b/pybind11_protobuf/proto_cast_util.cc @@ -24,6 +24,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "pybind11_protobuf/check_unknown_fields.h" +#include "pybind11_protobuf/runtime_config.h" namespace py = pybind11; @@ -814,17 +815,22 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy, // 1. The binary does not have a py_proto_api instance, or // 2. a) the proto is from the default pool and // b) the binary is not using fast_cpp_protos. - if ((GlobalState::instance()->py_proto_api() == nullptr) || + if (GlobalState::instance()->py_proto_api() == nullptr || (src->GetDescriptor()->file()->pool() == DescriptorPool::generated_pool() && !GlobalState::instance()->using_fast_cpp())) { return GenericPyProtoCast(src, policy, parent, is_const); } - std::optional emsg = - check_unknown_fields::CheckAndBuildErrorMessageIfAny( - GlobalState::instance()->py_proto_api(), src); + bool build_error_message_if_any = runtime_config:: + ExtensionsWithUnknownFieldsPolicy::UnknownFieldsAreDisallowed(); + std::optional emsg = check_unknown_fields::RunCheck( + GlobalState::instance()->py_proto_api(), src, build_error_message_if_any); if (emsg) { + if (!build_error_message_if_any) { + // Found an unknown field: fall back to serialize/parse. + return GenericPyProtoCast(src, policy, parent, is_const); + } throw py::value_error(*emsg); } diff --git a/pybind11_protobuf/runtime_config.h b/pybind11_protobuf/runtime_config.h new file mode 100644 index 0000000..a3d1ee3 --- /dev/null +++ b/pybind11_protobuf/runtime_config.h @@ -0,0 +1,42 @@ +#ifndef PYBIND11_PROTOBUF_RUNTIME_CONFIG_H_ +#define PYBIND11_PROTOBUF_RUNTIME_CONFIG_H_ + +namespace pybind11_protobuf::runtime_config { + +// For background see "Protobuf Extensions" section in README.md. +class ExtensionsWithUnknownFieldsPolicy { + enum State { + // Initial state. + kWeakDisallow, + + // Primary use case: PyCLIF extensions might set this when being imported. + kWeakEnableFallbackToSerializeParse, + + // Primary use case: `:disallow_extensions_with_unknown_fields` in `deps` + // of a binary (or test). + kStrongDisallow + }; + + static State& GetStateSingleton() { + static State singleton = kWeakDisallow; + return singleton; + } + + public: + static void WeakEnableFallbackToSerializeParse() { + State& policy = GetStateSingleton(); + if (policy == kWeakDisallow) { + policy = kWeakEnableFallbackToSerializeParse; + } + } + + static void StrongSetDisallow() { GetStateSingleton() = kStrongDisallow; } + + static bool UnknownFieldsAreDisallowed() { + return GetStateSingleton() != kWeakEnableFallbackToSerializeParse; + } +}; + +} // namespace pybind11_protobuf::runtime_config + +#endif // PYBIND11_PROTOBUF_RUNTIME_CONFIG_H_