From 1ba4fd70fc6685e73070791432e52ab478fb8b0b Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Fri, 18 Oct 2024 09:14:15 -0700 Subject: [PATCH] Add a consistency check to the parser to verify that has bits are valid on parse failure, and fix the issues found by tests. Keeping the invariant can help performance and future changes. PiperOrigin-RevId: 687326603 --- src/google/protobuf/BUILD.bazel | 3 + .../protobuf/generated_message_tctable_impl.h | 8 + .../generated_message_tctable_lite.cc | 175 +++++++++++++++++- .../generated_message_tctable_lite_test.cc | 39 ++++ 4 files changed, 224 insertions(+), 1 deletion(-) diff --git a/src/google/protobuf/BUILD.bazel b/src/google/protobuf/BUILD.bazel index 38aa70484594..08c2398f59ae 100644 --- a/src/google/protobuf/BUILD.bazel +++ b/src/google/protobuf/BUILD.bazel @@ -1583,9 +1583,12 @@ cc_test( }), deps = [ ":cc_test_protos", + ":descriptor_visitor", ":port", + ":protobuf", ":protobuf_lite", "//src/google/protobuf/io", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", diff --git a/src/google/protobuf/generated_message_tctable_impl.h b/src/google/protobuf/generated_message_tctable_impl.h index 1334822dd82e..d09993ecabf8 100644 --- a/src/google/protobuf/generated_message_tctable_impl.h +++ b/src/google/protobuf/generated_message_tctable_impl.h @@ -822,6 +822,9 @@ class PROTOBUF_EXPORT TcParser final { }; } + static void VerifyHasBitConsistency(const MessageLite* msg, + const TcParseTableBase* table); + private: // Optimized small tag varint parser for int32/int64 template @@ -1027,6 +1030,8 @@ class PROTOBUF_EXPORT TcParser final { static absl::string_view MessageName(const TcParseTableBase* table); static absl::string_view FieldName(const TcParseTableBase* table, const TcParseTableBase::FieldEntry*); + static int FieldNumber(const TcParseTableBase* table, + const TcParseTableBase::FieldEntry*); static bool 
ChangeOneof(const TcParseTableBase* table,
                           const TcParseTableBase::FieldEntry& entry,
                           uint32_t field_num, ParseContext* ctx,
@@ -1152,6 +1157,9 @@ inline PROTOBUF_ALWAYS_INLINE const char* TcParser::ParseLoop(
   if (ABSL_PREDICT_FALSE(table->has_post_loop_handler)) {
     return table->post_loop_handler(msg, ptr, ctx);
   }
+  if (ABSL_PREDICT_FALSE(PerformDebugChecks() && ptr == nullptr)) {
+    VerifyHasBitConsistency(msg, table);
+  }
   return ptr;
 }
 
diff --git a/src/google/protobuf/generated_message_tctable_lite.cc b/src/google/protobuf/generated_message_tctable_lite.cc
index 6ed1157ce2ce..10c434124205 100644
--- a/src/google/protobuf/generated_message_tctable_lite.cc
+++ b/src/google/protobuf/generated_message_tctable_lite.cc
@@ -20,7 +20,9 @@
 #include "absl/log/absl_log.h"
 #include "absl/numeric/bits.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "google/protobuf/arenastring.h"
 #include "google/protobuf/generated_enum_util.h"
 #include "google/protobuf/generated_message_tctable_decl.h"
@@ -73,6 +75,120 @@ const char* TcParser::GenericFallbackLite(PROTOBUF_TC_PARAM_DECL) {
       PROTOBUF_TC_PARAM_PASS);
 }
 
+namespace {
+bool ReadHas(const FieldEntry& entry, const MessageLite* msg) {
+  auto has_idx = static_cast<uint32_t>(entry.has_idx);
+  const auto& hasblock = TcParser::RefAt<uint32_t>(msg, has_idx / 32 * 4);
+  return (hasblock & (uint32_t{1} << (has_idx % 32))) != 0;
+}
+}  // namespace
+
+void TcParser::VerifyHasBitConsistency(const MessageLite* msg,
+                                       const TcParseTableBase* table) {
+  namespace fl = internal::field_layout;
+  if (table->has_bits_offset == 0) {
+    // Nothing to check
+    return;
+  }
+
+  for (const auto& entry : table->field_entries()) {  // Check every field.
+    const auto print_error = [&] {
+      return absl::StrFormat("Type=%s Field=%d\n", msg->GetTypeName(),
+                             FieldNumber(table, &entry));
+    };
+    if ((entry.type_card & fl::kFcMask) != fl::kFcOptional) continue;
+    const bool has_bit = ReadHas(entry, msg);
+    const void* base = msg;
+    const void* default_base = table->default_instance();
+    if ((entry.type_card & field_layout::kSplitMask) ==
+        field_layout::kSplitTrue) {
+      const size_t offset = table->field_aux(kSplitOffsetAuxIdx)->offset;
+      base = TcParser::RefAt<const void*>(base, offset);
+      default_base = TcParser::RefAt<const void*>(default_base, offset);
+    }
+    switch (entry.type_card & fl::kFkMask) {
+      case fl::kFkVarint:
+      case fl::kFkFixed:
+        // Numerics can have any value when the has bit is on.
+        if (has_bit) continue;  // Exempts this field only.
+        switch (entry.type_card & fl::kRepMask) {
+          case fl::kRep8Bits:
+            ABSL_CHECK_EQ(RefAt<bool>(base, entry.offset),
+                          RefAt<bool>(default_base, entry.offset))
+                << print_error();
+            break;
+          case fl::kRep32Bits:
+            ABSL_CHECK_EQ(RefAt<uint32_t>(base, entry.offset),
+                          RefAt<uint32_t>(default_base, entry.offset))
+                << print_error();
+            break;
+          case fl::kRep64Bits:
+            ABSL_CHECK_EQ(RefAt<uint64_t>(base, entry.offset),
+                          RefAt<uint64_t>(default_base, entry.offset))
+                << print_error();
+            break;
+        }
+        break;
+
+      case fl::kFkString:
+        switch (entry.type_card & fl::kRepMask) {
+          case field_layout::kRepAString:
+            if (has_bit) {
+              // Must not point to the default.
+              ABSL_CHECK(!RefAt<ArenaStringPtr>(base, entry.offset).IsDefault())
+                  << print_error();
+            } else {
+              // We should technically check that the value matches the default
+              // value of the field, but the prototype does not actually contain
+              // this value. Non-empty defaults are loaded on access.
+            }
+            break;
+          case field_layout::kRepCord:
+            if (!has_bit) {
+              // If the has bit is off, it must match the default.
+              ABSL_CHECK_EQ(RefAt<absl::Cord>(base, entry.offset),
+                            RefAt<absl::Cord>(default_base, entry.offset))
+                  << print_error();
+            }
+            break;
+          case field_layout::kRepIString:
+            if (!has_bit) {
+              // If the has bit is off, it must match the default.
+              ABSL_CHECK_EQ(
+                  RefAt<InlinedStringField>(base, entry.offset).Get(),
+                  RefAt<InlinedStringField>(default_base, entry.offset).Get())
+                  << print_error();
+            }
+            break;
+          case field_layout::kRepSString:
+            Unreachable();
+        }
+        break;
+      case fl::kFkMessage:
+        switch (entry.type_card & fl::kRepMask) {
+          case fl::kRepMessage:
+          case fl::kRepGroup:
+            if (has_bit) {
+              ABSL_CHECK(RefAt<const MessageLite*>(base, entry.offset) !=
+                         nullptr)
+                  << print_error();
+            } else {
+              // An off has_bit does not imply a null pointer.
+              // We might have a previous instance that we cached.
+            }
+            break;
+          default:
+            Unreachable();
+        }
+        break;
+
+      default:
+        // All other types are not `optional`.
+        Unreachable();
+    }
+  }
+}
+
 //////////////////////////////////////////////////////////////////////////////
 // Core fast parsing implementation:
 //////////////////////////////////////////////////////////////////////////////
@@ -226,6 +342,46 @@ absl::string_view TcParser::FieldName(const TcParseTableBase* table,
                   field_index + 1);
 }
 
+int TcParser::FieldNumber(const TcParseTableBase* table,
+                          const TcParseTableBase::FieldEntry* entry) {
+  // The data structure was not designed to be queried in this direction, so
+  // we have to do a linear search over the entries to see which one matches
+  // while keeping track of the field number.
+  // But it is fine because we are only using this for debug check messages.
+  size_t need_to_skip = entry - table->field_entries_begin();
+  const auto visit_bitmap = [&](uint32_t field_bitmap,
+                                int base_field_number) -> absl::optional<int> {
+    for (; field_bitmap != 0; field_bitmap &= field_bitmap - 1) {
+      if (need_to_skip == 0) {
+        return absl::countr_zero(field_bitmap) + base_field_number;
+      }
+      --need_to_skip;
+    }
+    return absl::nullopt;
+  };
+  if (auto number = visit_bitmap(~table->skipmap32, 1)) {
+    return *number;
+  }
+
+  for (const uint16_t* lookup_table = table->field_lookup_begin();
+       lookup_table[0] != 0xFFFF || lookup_table[1] != 0xFFFF;) {
+    uint32_t fstart = lookup_table[0] | (uint32_t{lookup_table[1]} << 16);
+    lookup_table += 2;
+    const uint16_t num_skip_entries = *lookup_table++;
+    for (uint16_t i = 0; i < num_skip_entries; ++i) {
+      // For each group of 16 fields we have a 16-bit bitmap (inverted below
+      // before counting entries) followed by a 16-bit field-entry offset for
+      // the first of them, hence the stride of 2.
+      if (auto number = visit_bitmap(static_cast<uint16_t>(~*lookup_table),
+                                     fstart + 16 * i)) {
+        return *number;
+      }
+      lookup_table += 2;
+    }
+  }
+  Unreachable();
+}
+
 PROTOBUF_NOINLINE const char* TcParser::Error(PROTOBUF_TC_PARAM_NO_DATA_DECL) {
   (void)ctx;
   (void)ptr;
@@ -1403,6 +1559,19 @@ PROTOBUF_ALWAYS_INLINE inline bool IsValidUTF8(ArenaStringPtr& field) {
 }
 
 
+void EnsureArenaStringIsNotDefault(const MessageLite* msg,
+                                   ArenaStringPtr* field) {
+  // If we failed here we might have left the string in its IsDefault state, but
+  // already set the has bit which breaks the message invariants. We must make
+  // it consistent again. We do that by guaranteeing the string always exists.
+  if (field->IsDefault()) {
+    field->Set("", msg->GetArena());
+  }
+}
+// The rest do nothing.
+PROTOBUF_UNUSED void EnsureArenaStringIsNotDefault(const MessageLite* msg,
+                                                   void*) {}
+
 }  // namespace
 
 template
@@ -1423,6 +1592,7 @@ inline PROTOBUF_ALWAYS_INLINE const char* TcParser::SingularString(
     ptr = ReadStringNoArena(msg, ptr, ctx, data.aux_idx(), table, field);
   }
   if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
+    EnsureArenaStringIsNotDefault(msg, &field);
     PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
   }
   switch (utf8) {
@@ -2180,7 +2350,10 @@ PROTOBUF_NOINLINE const char* TcParser::MpString(PROTOBUF_TC_PARAM_DECL) {
         std::string* str = field.MutableNoCopy(nullptr);
         ptr = InlineGreedyStringParser(str, ptr, ctx);
       }
-      if (!ptr) break;
+      if (ABSL_PREDICT_FALSE(ptr == nullptr)) {
+        EnsureArenaStringIsNotDefault(msg, &field);
+        break;
+      }
       is_valid = MpVerifyUtf8(field.Get(), table, entry, xform_val);
       break;
     }
diff --git a/src/google/protobuf/generated_message_tctable_lite_test.cc b/src/google/protobuf/generated_message_tctable_lite_test.cc
index 136973f63a5b..1e4b40f572ab 100644
--- a/src/google/protobuf/generated_message_tctable_lite_test.cc
+++ b/src/google/protobuf/generated_message_tctable_lite_test.cc
@@ -11,14 +11,19 @@
 #include
 #include
 
+#include "absl/algorithm/container.h"
 #include "absl/log/absl_check.h"
 #include "absl/log/absl_log.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/descriptor_database.h"
+#include "google/protobuf/descriptor_visitor.h"
 #include "google/protobuf/generated_message_tctable_decl.h"
 #include "google/protobuf/generated_message_tctable_impl.h"
 #include "google/protobuf/io/coded_stream.h"
+#include "google/protobuf/message_lite.h"
 #include "google/protobuf/parse_context.h"
 #include "google/protobuf/unittest.pb.h"
 #include "google/protobuf/wire_format_lite.h"
@@ -333,6 +338,10 @@ class FindFieldEntryTest : public ::testing::Test {
     return TcParser::FieldName(&table.header, entry);
   }
 
+
static int FieldNumber(const TcParseTableBase* table, size_t index) {
+    return TcParser::FieldNumber(table, table->field_entries_begin() + index);
+  }
+
   // Calls the private `MessageName` function.
   template
@@ -346,6 +355,36 @@ class FindFieldEntryTest : public ::testing::Test {
   static constexpr int small_scan_size() { return TcParser::kMtSmallScanSize; }
 };
 
+TEST_F(FindFieldEntryTest, FieldNumberWorksForAllFields) {
+  // Look at all types registered in the binary and verify that field number
+  // calculation works for all the fields.
+  auto* gen_db = DescriptorPool::internal_generated_database();
+  std::vector<std::string> all_file_names;
+  gen_db->FindAllFileNames(&all_file_names);
+
+  for (const auto& filename : all_file_names) {
+    SCOPED_TRACE(filename);
+    const auto* file =
+        DescriptorPool::generated_pool()->FindFileByName(filename);
+    VisitDescriptors(*file, [&](const Descriptor& desc) {
+      SCOPED_TRACE(desc.full_name());
+      const auto* prototype =
+          MessageFactory::generated_factory()->GetPrototype(&desc);
+      const auto* tc_table = internal::GetClassData(*prototype)->tc_table;
+
+      std::vector<int> sorted_field_numbers;  // Entries are in number order.
+      for (auto* field : internal::FieldRange(&desc)) {
+        sorted_field_numbers.push_back(field->number());
+      }
+      absl::c_sort(sorted_field_numbers);
+
+      for (int i = 0; i < desc.field_count(); ++i) {
+        EXPECT_EQ(FieldNumber(tc_table, i), sorted_field_numbers[i]);
+      }
+    });
+  }
+}
+
 TEST_F(FindFieldEntryTest, SequentialFieldRange) {
   // Look up fields that are within the range of `lookup_table_offset`.
   // clang-format off