Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a consistency check to the parser to verify that has bits are valid on #18904

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/google/protobuf/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1583,9 +1583,12 @@ cc_test(
}),
deps = [
":cc_test_protos",
":descriptor_visitor",
":port",
":protobuf",
":protobuf_lite",
"//src/google/protobuf/io",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings",
Expand Down
8 changes: 8 additions & 0 deletions src/google/protobuf/generated_message_tctable_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,9 @@ class PROTOBUF_EXPORT TcParser final {
};
}

static void VerifyHasBitConsistency(const MessageLite* msg,
const TcParseTableBase* table);

private:
// Optimized small tag varint parser for int32/int64
template <typename FieldType>
Expand Down Expand Up @@ -1027,6 +1030,8 @@ class PROTOBUF_EXPORT TcParser final {
static absl::string_view MessageName(const TcParseTableBase* table);
static absl::string_view FieldName(const TcParseTableBase* table,
const TcParseTableBase::FieldEntry*);
static int FieldNumber(const TcParseTableBase* table,
const TcParseTableBase::FieldEntry*);
static bool ChangeOneof(const TcParseTableBase* table,
const TcParseTableBase::FieldEntry& entry,
uint32_t field_num, ParseContext* ctx,
Expand Down Expand Up @@ -1152,6 +1157,9 @@ inline PROTOBUF_ALWAYS_INLINE const char* TcParser::ParseLoop(
if (ABSL_PREDICT_FALSE(table->has_post_loop_handler)) {
return table->post_loop_handler(msg, ptr, ctx);
}
if (ABSL_PREDICT_FALSE(PerformDebugChecks() && ptr == nullptr)) {
VerifyHasBitConsistency(msg, table);
}
return ptr;
}

Expand Down
175 changes: 174 additions & 1 deletion src/google/protobuf/generated_message_tctable_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include "absl/log/absl_log.h"
#include "absl/numeric/bits.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "google/protobuf/arenastring.h"
#include "google/protobuf/generated_enum_util.h"
#include "google/protobuf/generated_message_tctable_decl.h"
Expand Down Expand Up @@ -73,6 +75,120 @@ const char* TcParser::GenericFallbackLite(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_TC_PARAM_PASS);
}

namespace {
bool ReadHas(const FieldEntry& entry, const MessageLite* msg) {
auto has_idx = static_cast<uint32_t>(entry.has_idx);
const auto& hasblock = TcParser::RefAt<const uint32_t>(msg, has_idx / 32 * 4);
return (hasblock & (uint32_t{1} << (has_idx % 32))) != 0;
}
} // namespace

void TcParser::VerifyHasBitConsistency(const MessageLite* msg,
const TcParseTableBase* table) {
namespace fl = internal::field_layout;
if (table->has_bits_offset == 0) {
// Nothing to check
return;
}

for (const auto& entry : table->field_entries()) {
const auto print_error = [&] {
return absl::StrFormat("Type=%s Field=%d\n", msg->GetTypeName(),
FieldNumber(table, &entry));
};
if ((entry.type_card & fl::kFcMask) != fl::kFcOptional) return;
const bool has_bit = ReadHas(entry, msg);
const void* base = msg;
const void* default_base = table->default_instance();
if ((entry.type_card & field_layout::kSplitMask) ==
field_layout::kSplitTrue) {
const size_t offset = table->field_aux(kSplitOffsetAuxIdx)->offset;
base = TcParser::RefAt<const void*>(base, offset);
default_base = TcParser::RefAt<const void*>(default_base, offset);
}
switch (entry.type_card & fl::kFkMask) {
case fl::kFkVarint:
case fl::kFkFixed:
// Numerics can have any value when the has bit is on.
if (has_bit) return;
switch (entry.type_card & fl::kRepMask) {
case fl::kRep8Bits:
ABSL_CHECK_EQ(RefAt<bool>(base, entry.offset),
RefAt<bool>(default_base, entry.offset))
<< print_error();
break;
case fl::kRep32Bits:
ABSL_CHECK_EQ(RefAt<uint32_t>(base, entry.offset),
RefAt<uint32_t>(default_base, entry.offset))
<< print_error();
break;
case fl::kRep64Bits:
ABSL_CHECK_EQ(RefAt<uint64_t>(base, entry.offset),
RefAt<uint64_t>(default_base, entry.offset))
<< print_error();
break;
}
break;

case fl::kFkString:
switch (entry.type_card & fl::kRepMask) {
case field_layout::kRepAString:
if (has_bit) {
// Must not point to the default.
ABSL_CHECK(!RefAt<ArenaStringPtr>(base, entry.offset).IsDefault())
<< print_error();
} else {
// We should technically check that the value matches the default
// value of the field, but the prototype does not actually contain
// this value. Non-empty defaults are loaded on access.
}
break;
case field_layout::kRepCord:
if (!has_bit) {
// If the has bit is off, it must match the default.
ABSL_CHECK_EQ(RefAt<absl::Cord>(base, entry.offset),
RefAt<absl::Cord>(default_base, entry.offset))
<< print_error();
}
break;
case field_layout::kRepIString:
if (!has_bit) {
// If the has bit is off, it must match the default.
ABSL_CHECK_EQ(
RefAt<InlinedStringField>(base, entry.offset).Get(),
RefAt<InlinedStringField>(default_base, entry.offset).Get())
<< print_error();
}
break;
case field_layout::kRepSString:
Unreachable();
}
break;
case fl::kFkMessage:
switch (entry.type_card & fl::kRepMask) {
case fl::kRepMessage:
case fl::kRepGroup:
if (has_bit) {
ABSL_CHECK(RefAt<const MessageLite*>(base, entry.offset) !=
nullptr)
<< print_error();
} else {
// An off has_bit does not imply a null pointer.
// We might have a previous instance that we cached.
}
break;
default:
Unreachable();
}
break;

default:
// All other types are not `optional`.
Unreachable();
}
}
}

//////////////////////////////////////////////////////////////////////////////
// Core fast parsing implementation:
//////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -226,6 +342,46 @@ absl::string_view TcParser::FieldName(const TcParseTableBase* table,
field_index + 1);
}

int TcParser::FieldNumber(const TcParseTableBase* table,
const TcParseTableBase::FieldEntry* entry) {
// The data structure was not designed to be queried in this direction, so
// we have to do a linear search over the entries to see which one matches
// while keeping track of the field number.
// But it is fine because we are only using this for debug check messages.
size_t need_to_skip = entry - table->field_entries_begin();
const auto visit_bitmap = [&](uint32_t field_bitmap,
int base_field_number) -> absl::optional<int> {
for (; field_bitmap != 0; field_bitmap &= field_bitmap - 1) {
if (need_to_skip == 0) {
return absl::countr_zero(field_bitmap) + base_field_number;
}
--need_to_skip;
}
return absl::nullopt;
};
if (auto number = visit_bitmap(~table->skipmap32, 1)) {
return *number;
}

for (const uint16_t* lookup_table = table->field_lookup_begin();
lookup_table[0] != 0xFFFF || lookup_table[1] != 0xFFFF;) {
uint32_t fstart = lookup_table[0] | (lookup_table[1] << 16);
lookup_table += 2;
const uint16_t num_skip_entries = *lookup_table++;
for (uint16_t i = 0; i < num_skip_entries; ++i) {
// for each group of 16 fields we have: a
// bitmap of 16 bits a 16-bit field-entry
// offset for the first of them.
if (auto number = visit_bitmap(static_cast<uint16_t>(~*lookup_table),
fstart + 16 * i)) {
return *number;
}
lookup_table += 2;
}
}
Unreachable();
}

PROTOBUF_NOINLINE const char* TcParser::Error(PROTOBUF_TC_PARAM_NO_DATA_DECL) {
(void)ctx;
(void)ptr;
Expand Down Expand Up @@ -1403,6 +1559,19 @@ PROTOBUF_ALWAYS_INLINE inline bool IsValidUTF8(ArenaStringPtr& field) {
}


void EnsureArenaStringIsNotDefault(const MessageLite* msg,
ArenaStringPtr* field) {
// If we failed here we might have left the string in its IsDefault state, but
// already set the has bit which breaks the message invariants. We must make
// it consistent again. We do that by guaranteeing the string always exists.
if (field->IsDefault()) {
field->Set("", msg->GetArena());
}
}
// The rest do nothing.
PROTOBUF_UNUSED void EnsureArenaStringIsNotDefault(const MessageLite* msg,
void*) {}

} // namespace

template <typename TagType, typename FieldType, TcParser::Utf8Type utf8>
Expand All @@ -1423,6 +1592,7 @@ inline PROTOBUF_ALWAYS_INLINE const char* TcParser::SingularString(
ptr = ReadStringNoArena(msg, ptr, ctx, data.aux_idx(), table, field);
}
if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
EnsureArenaStringIsNotDefault(msg, &field);
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
switch (utf8) {
Expand Down Expand Up @@ -2180,7 +2350,10 @@ PROTOBUF_NOINLINE const char* TcParser::MpString(PROTOBUF_TC_PARAM_DECL) {
std::string* str = field.MutableNoCopy(nullptr);
ptr = InlineGreedyStringParser(str, ptr, ctx);
}
if (!ptr) break;
if (ABSL_PREDICT_FALSE(ptr == nullptr)) {
EnsureArenaStringIsNotDefault(msg, &field);
break;
}
is_valid = MpVerifyUtf8(field.Get(), table, entry, xform_val);
break;
}
Expand Down
39 changes: 39 additions & 0 deletions src/google/protobuf/generated_message_tctable_lite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/algorithm/container.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/descriptor_visitor.h"
#include "google/protobuf/generated_message_tctable_decl.h"
#include "google/protobuf/generated_message_tctable_impl.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/parse_context.h"
#include "google/protobuf/unittest.pb.h"
#include "google/protobuf/wire_format_lite.h"
Expand Down Expand Up @@ -333,6 +338,10 @@ class FindFieldEntryTest : public ::testing::Test {
return TcParser::FieldName(&table.header, entry);
}

static int FieldNumber(const TcParseTableBase* table, size_t index) {
return TcParser::FieldNumber(table, table->field_entries_begin() + index);
}

// Calls the private `MessageName` function.
template <size_t kFastTableSizeLog2, size_t kNumEntries, size_t kNumFieldAux,
size_t kNameTableSize, size_t kFieldLookupTableSize>
Expand All @@ -346,6 +355,36 @@ class FindFieldEntryTest : public ::testing::Test {
static constexpr int small_scan_size() { return TcParser::kMtSmallScanSize; }
};

TEST_F(FindFieldEntryTest, FieldNumberWorksForAllFields) {
// Look at all types registered in the binary and verify that field number
// calculation works for all the fields.
auto* gen_db = DescriptorPool::internal_generated_database();
std::vector<std::string> all_file_names;
gen_db->FindAllFileNames(&all_file_names);

for (const auto& filename : all_file_names) {
SCOPED_TRACE(filename);
const auto* file =
DescriptorPool::generated_pool()->FindFileByName(filename);
VisitDescriptors(*file, [&](const Descriptor& desc) {
SCOPED_TRACE(desc.full_name());
const auto* prototype =
MessageFactory::generated_factory()->GetPrototype(&desc);
const auto* tc_table = internal::GetClassData(*prototype)->tc_table;

std::vector<int> sorted_field_numbers;
for (auto* field : internal::FieldRange(&desc)) {
sorted_field_numbers.push_back(field->number());
}
absl::c_sort(sorted_field_numbers);

for (int i = 0; i < desc.field_count(); ++i) {
EXPECT_EQ(FieldNumber(tc_table, i), sorted_field_numbers[i]);
}
});
}
}

TEST_F(FindFieldEntryTest, SequentialFieldRange) {
// Look up fields that are within the range of `lookup_table_offset`.
// clang-format off
Expand Down
Loading