Skip to content

Commit

Permalink
Create TriStatePtr to replace SharableDependency which is difficult
Browse files Browse the repository at this point in the history
to debug threading issues in OSS builds

PiperOrigin-RevId: 707138236
  • Loading branch information
dryman authored and copybara-github committed Dec 19, 2024
1 parent fb97d67 commit b276ac6
Show file tree
Hide file tree
Showing 9 changed files with 347 additions and 430 deletions.
24 changes: 9 additions & 15 deletions cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,12 @@ cc_library(
)

cc_library(
name = "shareable_dependency",
hdrs = ["shareable_dependency.h"],
name = "tri_state_ptr",
hdrs = ["tri_state_ptr.h"],
deps = [
":common",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/synchronization",
"@com_google_riegeli//riegeli/base:initializer",
"@com_google_riegeli//riegeli/base:shared_ptr",
"@com_google_riegeli//riegeli/base:stable_dependency",
"@com_google_riegeli//riegeli/base:type_traits",
],
)

Expand Down Expand Up @@ -113,9 +108,10 @@ cc_library(
":common",
":layout_cc_proto",
":sequenced_chunk_writer",
":shareable_dependency",
":thread_pool",
":tri_state_ptr",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand All @@ -125,7 +121,6 @@ cc_library(
"@com_google_riegeli//riegeli/base:initializer",
"@com_google_riegeli//riegeli/base:object",
"@com_google_riegeli//riegeli/base:options_parser",
"@com_google_riegeli//riegeli/base:shared_ptr",
"@com_google_riegeli//riegeli/base:status",
"@com_google_riegeli//riegeli/bytes:chain_writer",
"@com_google_riegeli//riegeli/bytes:writer",
Expand Down Expand Up @@ -166,8 +161,8 @@ cc_library(
":layout_cc_proto",
":masked_reader",
":parallel_for",
":shareable_dependency",
":thread_pool",
":tri_state_ptr",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/functional:function_ref",
Expand All @@ -180,7 +175,6 @@ cc_library(
"@com_google_riegeli//riegeli/base:initializer",
"@com_google_riegeli//riegeli/base:object",
"@com_google_riegeli//riegeli/base:options_parser",
"@com_google_riegeli//riegeli/base:shared_ptr",
"@com_google_riegeli//riegeli/base:status",
"@com_google_riegeli//riegeli/bytes:reader",
"@com_google_riegeli//riegeli/chunk_encoding:chunk",
Expand Down Expand Up @@ -217,12 +211,12 @@ cc_test(
)

cc_test(
name = "shareable_dependency_test",
srcs = ["shareable_dependency_test.cc"],
name = "tri_state_ptr_test",
srcs = ["tri_state_ptr_test.cc"],
deps = [
":common",
":shareable_dependency",
":thread_pool",
":tri_state_ptr",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
Expand Down
47 changes: 30 additions & 17 deletions cpp/array_record_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,18 +187,28 @@ void ArrayRecordReaderBase::Initialize() {
if (!ok()) {
return;
}
auto reader = get_backing_reader();
if (!reader->ok()) {
Fail(reader->status());
return;
}
if (!reader->SupportsNewReader()) {
Fail(InvalidArgumentError(
"ArrayRecordReader only work on inputs with random access support."));
return;
}

uint32_t max_parallelism = 1;
if (state_->pool) {
max_parallelism = state_->pool->NumThreads();
if (state_->options.max_parallelism().has_value()) {
max_parallelism =
std::min(max_parallelism, state_->options.max_parallelism().value());
max_parallelism = std::min<uint32_t>(
max_parallelism, state_->options.max_parallelism().value());
}
}
state_->options.set_max_parallelism(max_parallelism);

AR_ENDO_TASK("Reading ArrayRecord footer");
const auto reader = get_backing_reader();
RiegeliFooterMetadata footer_metadata;
ChunkDecoder footer_decoder;
{
Expand Down Expand Up @@ -324,9 +334,9 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecords(
if (state_->chunk_offsets.empty()) {
return absl::OkStatus();
}
uint64_t num_chunk_groups =
CeilOfRatio(state_->chunk_offsets.size(), state_->chunk_group_size);
const auto reader = get_backing_reader();
uint64_t num_chunk_groups = CeilOfRatio<uint64_t>(
state_->chunk_offsets.size(), state_->chunk_group_size);
auto reader = get_backing_reader();
auto status = ParallelForWithStatus<1>(
Seq(num_chunk_groups), state_->pool, [&](size_t buf_idx) -> absl::Status {
uint64_t chunk_idx_start = buf_idx * state_->chunk_group_size;
Expand Down Expand Up @@ -398,17 +408,19 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsInRange(
begin, end, NumRecords());
}
uint64_t chunk_idx_begin = begin / state_->record_group_size;
uint64_t chunk_idx_end = CeilOfRatio(end, state_->record_group_size);
uint64_t chunk_idx_end =
CeilOfRatio<uint64_t>(end, state_->record_group_size);
uint64_t num_chunks = chunk_idx_end - chunk_idx_begin;
uint64_t num_chunk_groups = CeilOfRatio(num_chunks, state_->chunk_group_size);
uint64_t num_chunk_groups =
CeilOfRatio<uint64_t>(num_chunks, state_->chunk_group_size);

const auto reader = get_backing_reader();
auto reader = get_backing_reader();
auto status = ParallelForWithStatus<1>(
Seq(num_chunk_groups), state_->pool, [&](size_t buf_idx) -> absl::Status {
uint64_t chunk_idx_start =
chunk_idx_begin + buf_idx * state_->chunk_group_size;
// inclusive index, not the conventional exclusive index.
uint64_t last_chunk_idx = std::min(
uint64_t last_chunk_idx = std::min<uint64_t>(
chunk_idx_begin + (buf_idx + 1) * state_->chunk_group_size - 1,
chunk_idx_end - 1);
uint64_t buf_len = state_->ChunkEndOffset(last_chunk_idx) -
Expand Down Expand Up @@ -525,7 +537,7 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsWithIndices(
}
}
}
const auto reader = get_backing_reader();
auto reader = get_backing_reader();
auto status = ParallelForWithStatus<1>(
IndicesOf(chunk_indices_per_buffer), state_->pool,
[&](size_t buf_idx) -> absl::Status {
Expand Down Expand Up @@ -604,7 +616,7 @@ bool ArrayRecordReaderBase::SeekRecord(uint64_t record_index) {
if (!ok()) {
return false;
}
state_->record_idx = std::min(record_index, state_->num_records);
state_->record_idx = std::min<uint64_t>(record_index, state_->num_records);
return true;
}

Expand Down Expand Up @@ -654,9 +666,10 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) {
std::vector<ChunkDecoder> decoders;
decoders.reserve(state_->chunk_group_size);
uint64_t chunk_start = buffer_idx * state_->chunk_group_size;
uint64_t chunk_end = std::min(state_->chunk_offsets.size(),
(buffer_idx + 1) * state_->chunk_group_size);
const auto reader = get_backing_reader();
uint64_t chunk_end =
std::min<uint64_t>(state_->chunk_offsets.size(),
(buffer_idx + 1) * state_->chunk_group_size);
auto reader = get_backing_reader();
for (uint64_t chunk_idx = chunk_start; chunk_idx < chunk_end; ++chunk_idx) {
uint64_t chunk_offset = state_->chunk_offsets[chunk_idx];
uint64_t chunk_end_offset = state_->ChunkEndOffset(chunk_idx);
Expand Down Expand Up @@ -690,7 +703,7 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) {
std::make_shared<std::promise<std::vector<ChunkDecoder>>>();
state_->future_decoders.push(
{buffer_to_add, decoder_promise->get_future()});
const auto reader = get_backing_reader();
auto reader = get_backing_reader();
std::vector<uint64_t> chunk_offsets;
chunk_offsets.reserve(state_->chunk_group_size);
uint64_t chunk_start = buffer_to_add * state_->chunk_group_size;
Expand All @@ -704,7 +717,7 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) {
state_->ChunkEndOffset(chunk_end - 1) - chunk_offsets[0];

auto task = [reader, decoder_promise, chunk_offsets, buffer_to_add,
buffer_len] {
buffer_len]() mutable {
AR_ENDO_JOB("ArrayRecordReaderBase::ReadAheadFromBuffer",
absl::StrCat("buffer_idx: ", buffer_to_add,
" buffer_len: ", buffer_len));
Expand Down
33 changes: 12 additions & 21 deletions cpp/array_record_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "cpp/common.h"
#include "cpp/shareable_dependency.h"
#include "cpp/tri_state_ptr.h"
#include "cpp/thread_pool.h"
#include "google/protobuf/message_lite.h"
#include "riegeli/base/initializer.h"
Expand Down Expand Up @@ -293,7 +293,8 @@ class ArrayRecordReaderBase : public riegeli::Object {

void Initialize();

virtual DependencyShare<riegeli::Reader*> get_backing_reader() const = 0;
virtual TriStatePtr<riegeli::Reader>::SharedRef get_backing_reader()
const = 0;

private:
bool ReadAheadFromBuffer(uint64_t buffer_idx);
Expand Down Expand Up @@ -334,7 +335,7 @@ class ArrayRecordReaderBase : public riegeli::Object {
// if (!reader.Close()) return reader.status();
//
// ArrayRecordReader is thread compatible, not thread-safe.
template <typename Src = riegeli::Reader*>
template <typename Src = riegeli::Reader>
class ArrayRecordReader : public ArrayRecordReaderBase {
public:
DECLARE_MOVE_ONLY_CLASS(ArrayRecordReader);
Expand All @@ -344,34 +345,24 @@ class ArrayRecordReader : public ArrayRecordReaderBase {
Options options = Options(),
ARThreadPool* pool = nullptr)
: ArrayRecordReaderBase(std::move(options), pool),
main_reader_(std::move(src)) {
auto& unique = main_reader_.WaitUntilUnique();
if (!unique->ok()) {
Fail(unique->status());
return;
}
if (!unique->SupportsNewReader()) {
Fail(InvalidArgumentError(
"ArrayRecordReader only work on inputs with random access support."));
return;
}
main_reader_(std::make_unique<TriStatePtr<riegeli::Reader>>(
std::move(src))) {
Initialize();
}

protected:
DependencyShare<riegeli::Reader*> get_backing_reader() const override {
return main_reader_.Share();
TriStatePtr<riegeli::Reader>::SharedRef get_backing_reader() const override {
return main_reader_->MakeShared();
}

void Done() override {
auto& unique = main_reader_.WaitUntilUnique();
if (unique.IsOwning()) {
if (!unique->Close()) Fail(unique->status());
}
if (main_reader_ == nullptr) return;
auto unique = main_reader_->WaitAndMakeUnique();
if (!unique->Close()) Fail(unique->status());
}

private:
ShareableDependency<riegeli::Reader*, Src> main_reader_;
std::unique_ptr<TriStatePtr<riegeli::Reader>> main_reader_;
};

template <typename Src>
Expand Down
9 changes: 6 additions & 3 deletions cpp/array_record_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
Expand All @@ -38,6 +39,7 @@ limitations under the License.
#include "cpp/common.h"
#include "cpp/layout.pb.h"
#include "cpp/sequenced_chunk_writer.h"
#include "cpp/tri_state_ptr.h"
#include "cpp/thread_pool.h"
#include "google/protobuf/message_lite.h"
#include "riegeli/base/object.h"
Expand Down Expand Up @@ -262,7 +264,8 @@ class ArrayRecordWriterBase::SubmitChunkCallback
}

// Aggregate the offsets information and write it to the file.
void WriteFooterAndPostscript(SequencedChunkWriterBase* writer);
void WriteFooterAndPostscript(
TriStatePtr<SequencedChunkWriterBase>::SharedRef writer);

private:
const Options options_;
Expand Down Expand Up @@ -385,7 +388,7 @@ void ArrayRecordWriterBase::Done() {
}
chunk_promise.set_value(EncodeChunk(chunk_encoder_.get()));
}
submit_chunk_callback_->WriteFooterAndPostscript(writer.get());
submit_chunk_callback_->WriteFooterAndPostscript(std::move(writer));
}

std::unique_ptr<riegeli::ChunkEncoder> ArrayRecordWriterBase::CreateEncoder() {
Expand Down Expand Up @@ -486,7 +489,7 @@ void ArrayRecordWriterBase::SubmitChunkCallback::operator()(
}

void ArrayRecordWriterBase::SubmitChunkCallback::WriteFooterAndPostscript(
SequencedChunkWriterBase* writer) {
TriStatePtr<SequencedChunkWriterBase>::SharedRef writer) {
// Flushes prior chunks
writer->SubmitFutureChunks(true);
// Footer and postscript must pad to block boundary
Expand Down
31 changes: 14 additions & 17 deletions cpp/array_record_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "cpp/common.h"
#include "cpp/sequenced_chunk_writer.h"
#include "cpp/shareable_dependency.h"
#include "cpp/tri_state_ptr.h"
#include "cpp/thread_pool.h"
#include "riegeli/base/initializer.h"
#include "riegeli/base/object.h"
Expand Down Expand Up @@ -317,7 +317,7 @@ class ArrayRecordWriterBase : public riegeli::Object {
ArrayRecordWriterBase(ArrayRecordWriterBase&& other) noexcept;
ArrayRecordWriterBase& operator=(ArrayRecordWriterBase&& other) noexcept;

virtual DependencyShare<SequencedChunkWriterBase*> get_writer() = 0;
virtual TriStatePtr<SequencedChunkWriterBase>::SharedRef get_writer() = 0;

// Initializes and validates the underlying writer states.
void Initialize();
Expand Down Expand Up @@ -380,37 +380,34 @@ class ArrayRecordWriter : public ArrayRecordWriterBase {
Options options = Options(),
ARThreadPool* pool = nullptr)
: ArrayRecordWriterBase(std::move(options), pool),
main_writer_(
std::make_unique<SequencedChunkWriter<Dest>>(std::move(dest))) {
auto& unique = main_writer_.WaitUntilUnique();
if (!unique->ok()) {
Fail(unique->status());
main_writer_(std::make_unique<TriStatePtr<SequencedChunkWriterBase>>(
std::make_unique<SequencedChunkWriter<Dest>>(std::move(dest)))) {
auto writer = get_writer();
if (!writer->ok()) {
Fail(writer->status());
return;
}
Initialize();
}

protected:
DependencyShare<SequencedChunkWriterBase*> get_writer() final {
return main_writer_.Share();
TriStatePtr<SequencedChunkWriterBase>::SharedRef get_writer() final {
return main_writer_->MakeShared();
}

void Done() override {
// WaitUntilUnique ensures all pending tasks are finished.
auto& unique = main_writer_.WaitUntilUnique();
if (main_writer_ == nullptr) return;
ArrayRecordWriterBase::Done();
if (!ok()) {
return;
}
if (unique.IsOwning()) {
if (!unique->Close()) Fail(unique->status());
}
// Ensures all pending tasks are finished.
auto unique = main_writer_->WaitAndMakeUnique();
if (!unique->Close()) Fail(unique->status());
}

private:
ShareableDependency<SequencedChunkWriterBase*,
std::unique_ptr<SequencedChunkWriter<Dest>>>
main_writer_;
std::unique_ptr<TriStatePtr<SequencedChunkWriterBase>> main_writer_;
};

template <typename Dest>
Expand Down
Loading

0 comments on commit b276ac6

Please sign in to comment.