Skip to content

Commit

Permalink
pybind11_protobuf:
Browse files Browse the repository at this point in the history
`LOG(WARNING)` `FALL BACK TO PROTOBUF SERIALIZE/PARSE` based on `ExtensionsWithUnknownFieldsPolicy`.

Low-level change in preparation for PyCLIF-pybind11 rollout.

PiperOrigin-RevId: 604127759
  • Loading branch information
Ralf W. Grosse-Kunstleve authored and copybara-github committed Feb 5, 2024
1 parent 3b11990 commit ce2f2a6
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 24 deletions.
3 changes: 0 additions & 3 deletions pybind11_protobuf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ cc_library(
name = "check_unknown_fields",
srcs = ["check_unknown_fields.cc"],
hdrs = ["check_unknown_fields.h"],
visibility = [
"//visibility:private",
],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down
5 changes: 1 addition & 4 deletions pybind11_protobuf/check_unknown_fields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,

std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* message, bool build_error_message_if_any) {
const ::google::protobuf::Message* message) {
const auto* root_descriptor = message->GetDescriptor();
HasUnknownFields search{py_proto_api, root_descriptor};
if (!search.FindUnknownFieldsRecursive(message, 0u)) {
Expand All @@ -193,9 +193,6 @@ std::optional<std::string> CheckRecursively(
search.FieldFQN())) != 0) {
return std::nullopt;
}
if (!build_error_message_if_any) {
return ""; // This indicates that an unknown field was found.
}
return search.BuildErrorMessage();
}

Expand Down
2 changes: 1 addition & 1 deletion pybind11_protobuf/check_unknown_fields.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,

std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* top_message, bool build_error_message_if_any);
const ::google::protobuf::Message* top_message);

} // namespace pybind11_protobuf::check_unknown_fields

Expand Down
14 changes: 9 additions & 5 deletions pybind11_protobuf/proto_cast_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -828,14 +829,17 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,

std::optional<std::string> unknown_field_message =
check_unknown_fields::CheckRecursively(
GlobalState::instance()->py_proto_api(), src,
check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::
UnknownFieldsAreDisallowed());
GlobalState::instance()->py_proto_api(), src);
if (unknown_field_message) {
if (!unknown_field_message->empty()) {
if (check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::
UnknownFieldsAreDisallowed()) {
throw py::value_error(*unknown_field_message);
}
// Fall back to serialize/parse.
static auto fall_back_log_shown = new std::unordered_set<std::string>();
if (fall_back_log_shown->insert(*unknown_field_message).second) {
LOG(WARNING) << "FALL BACK TO PROTOBUF SERIALIZE/PARSE: "
<< *unknown_field_message;
}
return GenericPyProtoCast(src, policy, parent, is_const);
}

Expand Down
32 changes: 24 additions & 8 deletions pybind11_protobuf/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ pybind_extension(
],
)

EXTENSION_TEST_DEPS_COMMON = [
":extension_in_other_file_in_deps_py_pb2",
":extension_in_other_file_py_pb2",
":extension_nest_repeated_py_pb2",
":extension_py_pb2",
":test_py_pb2", # fixdeps: keep - Direct dependency needed in open-source version, see https://github.com/grpc/grpc/issues/22811
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
]

py_test(
name = "extension_test",
srcs = ["extension_test.py"],
Expand All @@ -170,14 +180,20 @@ py_test(
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":extension_in_other_file_in_deps_py_pb2",
":extension_in_other_file_py_pb2",
":extension_nest_repeated_py_pb2",
":extension_py_pb2",
":test_py_pb2", # fixdeps: keep - Direct dependency needed in open-source version, see https://github.com/grpc/grpc/issues/22811
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
deps = EXTENSION_TEST_DEPS_COMMON + ["@com_google_protobuf//:protobuf_python"],
)

py_test(
name = "extension_disallow_unknown_fields_test",
srcs = ["extension_test.py"],
data = [
":extension_module.so",
":proto_enum_module.so",
],
main = "extension_test.py",
python_version = "PY3",
srcs_version = "PY3",
deps = EXTENSION_TEST_DEPS_COMMON + [
"@com_google_protobuf//:protobuf_python",
],
)
Expand Down
5 changes: 5 additions & 0 deletions pybind11_protobuf/tests/extension_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ void DefReserialize(py::module_& m, const char* py_name) {
PYBIND11_MODULE(extension_module, m) {
pybind11_protobuf::ImportNativeProtoCasters();

m.def("extensions_with_unknown_fields_are_disallowed", []() {
return pybind11_protobuf::check_unknown_fields::
ExtensionsWithUnknownFieldsPolicy::UnknownFieldsAreDisallowed();
});

m.def("get_base_message", []() -> BaseMessage { return {}; });

m.def(
Expand Down
13 changes: 10 additions & 3 deletions pybind11_protobuf/tests/extension_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from pybind11_protobuf.tests import extension_pb2


def unknown_field_exception_is_expected():
return (
api_implementation.Type() == 'cpp'
and m.extensions_with_unknown_fields_are_disallowed()
)


def get_py_message(value=5,
in_other_file_in_deps_value=None,
in_other_file_value=None):
Expand Down Expand Up @@ -103,7 +110,7 @@ def test_extension_in_other_file_roundtrip(self):

def test_reserialize_base_message(self):
a = get_py_message(in_other_file_value=63)
if api_implementation.Type() == 'cpp':
if unknown_field_exception_is_expected():
with self.assertRaises(ValueError) as ctx:
m.reserialize_base_message(a)
self.assertStartsWith(
Expand All @@ -127,7 +134,7 @@ def test_reserialize_nest_level2(self):
a = extension_pb2.NestLevel2(
nest_lvl1=extension_pb2.NestLevel1(
base_msg=get_py_message(in_other_file_value=52)))
if api_implementation.Type() == 'cpp':
if unknown_field_exception_is_expected():
with self.assertRaises(ValueError) as ctx:
m.reserialize_nest_level2(a)
self.assertStartsWith(
Expand All @@ -154,7 +161,7 @@ def test_reserialize_nest_repeated(self):
get_py_message(in_other_file_value=74),
get_py_message(in_other_file_value=85)
])
if api_implementation.Type() == 'cpp':
if unknown_field_exception_is_expected():
with self.assertRaises(ValueError) as ctx:
m.reserialize_nest_repeated(a)
self.assertStartsWith(
Expand Down

0 comments on commit ce2f2a6

Please sign in to comment.