diff --git a/cpp/BUILD b/cpp/BUILD index 8213174..2a500d3 100644 --- a/cpp/BUILD +++ b/cpp/BUILD @@ -72,17 +72,12 @@ cc_library( ) cc_library( - name = "shareable_dependency", - hdrs = ["shareable_dependency.h"], + name = "tri_state_ptr", + hdrs = ["tri_state_ptr.h"], deps = [ + ":common", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/synchronization", - "@com_google_riegeli//riegeli/base:initializer", - "@com_google_riegeli//riegeli/base:shared_ptr", - "@com_google_riegeli//riegeli/base:stable_dependency", - "@com_google_riegeli//riegeli/base:type_traits", ], ) @@ -113,9 +108,10 @@ cc_library( ":common", ":layout_cc_proto", ":sequenced_chunk_writer", - ":shareable_dependency", ":thread_pool", + ":tri_state_ptr", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -125,7 +121,6 @@ cc_library( "@com_google_riegeli//riegeli/base:initializer", "@com_google_riegeli//riegeli/base:object", "@com_google_riegeli//riegeli/base:options_parser", - "@com_google_riegeli//riegeli/base:shared_ptr", "@com_google_riegeli//riegeli/base:status", "@com_google_riegeli//riegeli/bytes:chain_writer", "@com_google_riegeli//riegeli/bytes:writer", @@ -166,8 +161,8 @@ cc_library( ":layout_cc_proto", ":masked_reader", ":parallel_for", - ":shareable_dependency", ":thread_pool", + ":tri_state_ptr", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/functional:function_ref", @@ -180,7 +175,6 @@ cc_library( "@com_google_riegeli//riegeli/base:initializer", "@com_google_riegeli//riegeli/base:object", "@com_google_riegeli//riegeli/base:options_parser", - "@com_google_riegeli//riegeli/base:shared_ptr", "@com_google_riegeli//riegeli/base:status", "@com_google_riegeli//riegeli/bytes:reader", "@com_google_riegeli//riegeli/chunk_encoding:chunk", @@ -217,14 +211,13 @@ cc_test( ) cc_test( - name = "shareable_dependency_test", - srcs = ["shareable_dependency_test.cc"], + name = "tri_state_ptr_test", + srcs = ["tri_state_ptr_test.cc"], deps = [ ":common", - ":shareable_dependency", ":thread_pool", + ":tri_state_ptr", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", "@com_google_riegeli//riegeli/base:initializer", ], diff --git a/cpp/array_record_reader.cc b/cpp/array_record_reader.cc index 2794ec9..c57aa57 100644 --- a/cpp/array_record_reader.cc +++ b/cpp/array_record_reader.cc @@ -187,18 +187,28 @@ void ArrayRecordReaderBase::Initialize() { if (!ok()) { return; } + auto reader = get_backing_reader(); + if (!reader->ok()) { + Fail(reader->status()); + return; + } + if (!reader->SupportsNewReader()) { + Fail(InvalidArgumentError( + "ArrayRecordReader only work on inputs with random access support.")); + return; + } + uint32_t max_parallelism = 1; if (state_->pool) { max_parallelism = state_->pool->NumThreads(); if (state_->options.max_parallelism().has_value()) { - max_parallelism = - std::min(max_parallelism, state_->options.max_parallelism().value()); + max_parallelism = std::min( + max_parallelism, state_->options.max_parallelism().value()); } } state_->options.set_max_parallelism(max_parallelism); AR_ENDO_TASK("Reading ArrayRecord footer"); - const auto reader = get_backing_reader(); RiegeliFooterMetadata footer_metadata; ChunkDecoder footer_decoder; { @@ -324,9 +334,9 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecords( if (state_->chunk_offsets.empty()) { return absl::OkStatus(); } - uint64_t num_chunk_groups = - CeilOfRatio(state_->chunk_offsets.size(), state_->chunk_group_size); - const auto reader = get_backing_reader(); + uint64_t num_chunk_groups = CeilOfRatio( + state_->chunk_offsets.size(), state_->chunk_group_size); + auto reader = get_backing_reader(); auto status = ParallelForWithStatus<1>( Seq(num_chunk_groups), state_->pool, [&](size_t buf_idx) -> absl::Status { uint64_t chunk_idx_start = buf_idx * state_->chunk_group_size; @@ -398,17 +408,19 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsInRange( begin, end, NumRecords()); } uint64_t chunk_idx_begin = begin / state_->record_group_size; - uint64_t chunk_idx_end = CeilOfRatio(end, state_->record_group_size); + uint64_t chunk_idx_end = + CeilOfRatio(end, state_->record_group_size); uint64_t num_chunks = chunk_idx_end - chunk_idx_begin; - uint64_t num_chunk_groups = CeilOfRatio(num_chunks, state_->chunk_group_size); + uint64_t num_chunk_groups = + CeilOfRatio(num_chunks, state_->chunk_group_size); - const auto reader = get_backing_reader(); + auto reader = get_backing_reader(); auto status = ParallelForWithStatus<1>( Seq(num_chunk_groups), state_->pool, [&](size_t buf_idx) -> absl::Status { uint64_t chunk_idx_start = chunk_idx_begin + buf_idx * state_->chunk_group_size; // inclusive index, not the conventional exclusive index. - uint64_t last_chunk_idx = std::min( + uint64_t last_chunk_idx = std::min( chunk_idx_begin + (buf_idx + 1) * state_->chunk_group_size - 1, chunk_idx_end - 1); uint64_t buf_len = state_->ChunkEndOffset(last_chunk_idx) - @@ -525,7 +537,7 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsWithIndices( } } } - const auto reader = get_backing_reader(); + auto reader = get_backing_reader(); auto status = ParallelForWithStatus<1>( IndicesOf(chunk_indices_per_buffer), state_->pool, [&](size_t buf_idx) -> absl::Status { @@ -604,7 +616,7 @@ bool ArrayRecordReaderBase::SeekRecord(uint64_t record_index) { if (!ok()) { return false; } - state_->record_idx = std::min(record_index, state_->num_records); + state_->record_idx = std::min(record_index, state_->num_records); return true; } @@ -654,9 +666,10 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) { std::vector decoders; decoders.reserve(state_->chunk_group_size); uint64_t chunk_start = buffer_idx * state_->chunk_group_size; - uint64_t chunk_end = std::min(state_->chunk_offsets.size(), - (buffer_idx + 1) * state_->chunk_group_size); - const auto reader = get_backing_reader(); + uint64_t chunk_end = + std::min(state_->chunk_offsets.size(), + (buffer_idx + 1) * state_->chunk_group_size); + auto reader = get_backing_reader(); for (uint64_t chunk_idx = chunk_start; chunk_idx < chunk_end; ++chunk_idx) { uint64_t chunk_offset = state_->chunk_offsets[chunk_idx]; uint64_t chunk_end_offset = state_->ChunkEndOffset(chunk_idx); @@ -690,7 +703,7 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) { std::make_shared>>(); state_->future_decoders.push( {buffer_to_add, decoder_promise->get_future()}); - const auto reader = get_backing_reader(); + auto reader = get_backing_reader(); std::vector chunk_offsets; chunk_offsets.reserve(state_->chunk_group_size); uint64_t chunk_start = buffer_to_add * state_->chunk_group_size; @@ -704,7 +717,7 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) { state_->ChunkEndOffset(chunk_end - 1) - chunk_offsets[0]; auto task = [reader, decoder_promise, chunk_offsets, buffer_to_add, - buffer_len] { + buffer_len]() mutable { AR_ENDO_JOB("ArrayRecordReaderBase::ReadAheadFromBuffer", absl::StrCat("buffer_idx: ", buffer_to_add, " buffer_len: ", buffer_len)); diff --git a/cpp/array_record_reader.h b/cpp/array_record_reader.h index 05717e6..962a67e 100644 --- a/cpp/array_record_reader.h +++ b/cpp/array_record_reader.h @@ -46,7 +46,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "cpp/common.h" -#include "cpp/shareable_dependency.h" +#include "cpp/tri_state_ptr.h" #include "cpp/thread_pool.h" #include "google/protobuf/message_lite.h" #include "riegeli/base/initializer.h" @@ -293,7 +293,8 @@ class ArrayRecordReaderBase : public riegeli::Object { void Initialize(); - virtual DependencyShare get_backing_reader() const = 0; + virtual TriStatePtr::SharedRef get_backing_reader() + const = 0; private: bool ReadAheadFromBuffer(uint64_t buffer_idx); @@ -334,7 +335,7 @@ class ArrayRecordReaderBase : public riegeli::Object { // if (!reader.Close()) return reader.status(); // // ArrayRecordReader is thread compatible, not thread-safe. -template +template class ArrayRecordReader : public ArrayRecordReaderBase { public: DECLARE_MOVE_ONLY_CLASS(ArrayRecordReader); @@ -344,34 +345,24 @@ class ArrayRecordReader : public ArrayRecordReaderBase { Options options = Options(), ARThreadPool* pool = nullptr) : ArrayRecordReaderBase(std::move(options), pool), - main_reader_(std::move(src)) { - auto& unique = main_reader_.WaitUntilUnique(); - if (!unique->ok()) { - Fail(unique->status()); - return; - } - if (!unique->SupportsNewReader()) { - Fail(InvalidArgumentError( - "ArrayRecordReader only work on inputs with random access support.")); - return; - } + main_reader_(std::make_unique>( + std::move(src))) { Initialize(); } protected: - DependencyShare get_backing_reader() const override { - return main_reader_.Share(); + TriStatePtr::SharedRef get_backing_reader() const override { + return main_reader_->MakeShared(); } void Done() override { - auto& unique = main_reader_.WaitUntilUnique(); - if (unique.IsOwning()) { - if (!unique->Close()) Fail(unique->status()); - } + if (main_reader_ == nullptr) return; + auto unique = main_reader_->WaitAndMakeUnique(); + if (!unique->Close()) Fail(unique->status()); } private: - ShareableDependency main_reader_; + std::unique_ptr> main_reader_; }; template diff --git a/cpp/array_record_writer.cc b/cpp/array_record_writer.cc index 4e28c4c..48691ba 100644 --- a/cpp/array_record_writer.cc +++ b/cpp/array_record_writer.cc @@ -28,6 +28,7 @@ limitations under the License. #include #include "absl/base/thread_annotations.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -38,6 +39,7 @@ limitations under the License. #include "cpp/common.h" #include "cpp/layout.pb.h" #include "cpp/sequenced_chunk_writer.h" +#include "cpp/tri_state_ptr.h" #include "cpp/thread_pool.h" #include "google/protobuf/message_lite.h" #include "riegeli/base/object.h" @@ -262,7 +264,8 @@ class ArrayRecordWriterBase::SubmitChunkCallback } // Aggregate the offsets information and write it to the file. - void WriteFooterAndPostscript(SequencedChunkWriterBase* writer); + void WriteFooterAndPostscript( + TriStatePtr::SharedRef writer); private: const Options options_; @@ -385,7 +388,7 @@ void ArrayRecordWriterBase::Done() { } chunk_promise.set_value(EncodeChunk(chunk_encoder_.get())); } - submit_chunk_callback_->WriteFooterAndPostscript(writer.get()); + submit_chunk_callback_->WriteFooterAndPostscript(std::move(writer)); } std::unique_ptr ArrayRecordWriterBase::CreateEncoder() { @@ -486,7 +489,7 @@ void ArrayRecordWriterBase::SubmitChunkCallback::operator()( } void ArrayRecordWriterBase::SubmitChunkCallback::WriteFooterAndPostscript( - SequencedChunkWriterBase* writer) { + TriStatePtr::SharedRef writer) { // Flushes prior chunks writer->SubmitFutureChunks(true); // Footer and postscript must pad to block boundary diff --git a/cpp/array_record_writer.h b/cpp/array_record_writer.h index 4ad0d05..1f374e9 100644 --- a/cpp/array_record_writer.h +++ b/cpp/array_record_writer.h @@ -70,7 +70,7 @@ limitations under the License. #include "absl/types/span.h" #include "cpp/common.h" #include "cpp/sequenced_chunk_writer.h" -#include "cpp/shareable_dependency.h" +#include "cpp/tri_state_ptr.h" #include "cpp/thread_pool.h" #include "riegeli/base/initializer.h" #include "riegeli/base/object.h" @@ -317,7 +317,7 @@ class ArrayRecordWriterBase : public riegeli::Object { ArrayRecordWriterBase(ArrayRecordWriterBase&& other) noexcept; ArrayRecordWriterBase& operator=(ArrayRecordWriterBase&& other) noexcept; - virtual DependencyShare get_writer() = 0; + virtual TriStatePtr::SharedRef get_writer() = 0; // Initializes and validates the underlying writer states. void Initialize(); @@ -380,37 +380,34 @@ class ArrayRecordWriter : public ArrayRecordWriterBase { Options options = Options(), ARThreadPool* pool = nullptr) : ArrayRecordWriterBase(std::move(options), pool), - main_writer_( - std::make_unique>(std::move(dest))) { - auto& unique = main_writer_.WaitUntilUnique(); - if (!unique->ok()) { - Fail(unique->status()); + main_writer_(std::make_unique>( + std::make_unique>(std::move(dest)))) { + auto writer = get_writer(); + if (!writer->ok()) { + Fail(writer->status()); return; } Initialize(); } protected: - DependencyShare get_writer() final { - return main_writer_.Share(); + TriStatePtr::SharedRef get_writer() final { + return main_writer_->MakeShared(); } void Done() override { - // WaitUntilUnique ensures all pending tasks are finished. - auto& unique = main_writer_.WaitUntilUnique(); + if (main_writer_ == nullptr) return; ArrayRecordWriterBase::Done(); if (!ok()) { return; } - if (unique.IsOwning()) { - if (!unique->Close()) Fail(unique->status()); - } + // Ensures all pending tasks are finished. + auto unique = main_writer_->WaitAndMakeUnique(); + if (!unique->Close()) Fail(unique->status()); } private: - ShareableDependency>> - main_writer_; + std::unique_ptr> main_writer_; }; template diff --git a/cpp/shareable_dependency.h b/cpp/shareable_dependency.h deleted file mode 100644 index 691f866..0000000 --- a/cpp/shareable_dependency.h +++ /dev/null @@ -1,240 +0,0 @@ -/* Copyright 2022 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// `ShareableDependency` wraps a -// `riegeli::StableDependency`. It allows creating concurrent -// shares of it of type `DependencyShare`, and waiting until all shares -// are no longer in use. It waits explicitly or implicitly when the -// `ShareableDependency` is destroyed or reassigned. -// -// The main use case of `ShareableDependency` is to create a share for a -// detached thread, and allow the owner object to invoke non-const methods after -// all detached threads are finished. -// -// This is especially important for riegeli objects because we must call the -// non-const `Close()` on exit, and we cannot do that while other threads are -// accessing the object. -// -// Example usage: -// -// ShareableDependency main(riegeli::Maker(...)); -// -// // Detached thread with a refobj which increased the refcnt by 1. -// pool->Schedule([refobj = main.Share()] { -// refobj->FooMethod(...); -// }); -// -// // Blocks until refobj goes out of scope. -// auto& unique = main.WaitUntilUnique(); -// if (unique.IsOwning()) unique->Close(); -// -// `main` blocks on destruction when it is not unique. Therefore prevents -// refobj to be a dangling pointer. - -#ifndef ARRAY_RECORD_CPP_SHAREABLE_DEPENDENCY_H_ -#define ARRAY_RECORD_CPP_SHAREABLE_DEPENDENCY_H_ - -#include - -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/synchronization/mutex.h" -#include "riegeli/base/initializer.h" -#include "riegeli/base/intrusive_shared_ptr.h" -#include "riegeli/base/ref_count.h" -#include "riegeli/base/stable_dependency.h" -#include "riegeli/base/type_traits.h" - -namespace array_record { - -// `DependencyShare` wraps a `Handle` and tracks the lifetime of all -// shares from the given `ShareableDependency`. -template -class DependencyShare { - public: - DependencyShare() = default; - - DependencyShare(const DependencyShare& other) = default; - DependencyShare& operator=(const DependencyShare& other) = default; - - // The source is left empty. - DependencyShare(DependencyShare&& other) = default; - // The source is left empty. - DependencyShare& operator=(DependencyShare&& other) = default; - - Handle get() const { - DCHECK(sharing_ != nullptr); - return sharing_->handle; - } - template ::value, - int> = 0> - decltype(*std::declval()) operator*() const { - return *get(); - } - template < - typename DependentHandle = Handle, - std::enable_if_t::value, int> = 0> - Handle operator->() const { - return get(); - } - - private: - template - friend class ShareableDependency; - - struct Sharing; - - explicit DependencyShare(Sharing* sharing); - - riegeli::IntrusiveSharedPtr sharing_; -}; - -// `ShareableDependency` wraps a -// `riegeli::StableDependency`. It allows creating concurrent -// shares of it of type `DependencyShare`, and waiting until all shares -// are no longer in use. It waits explicitly or implicitly when the -// `ShareableDependency` is destroyed or reassigned. -template -class ShareableDependency { - public: - // Creates an empty `ShareableDependency`. - ShareableDependency() = default; - - // Creates a `ShareableDependency` storing the `Manager`. - explicit ShareableDependency(riegeli::Initializer manager) - : dependency_(std::move(manager)), - sharing_(new Sharing(dependency_.get())) {} - - // The source is left empty. - ShareableDependency(ShareableDependency&& other) = default; - // Waits until `*this` is empty or unique. The source is left empty. - ShareableDependency& operator=(ShareableDependency&& other) = default; - - // Waits until `*this` is empty or unique. - ~ShareableDependency() = default; - - // Makes `*this` equivalent to a newly constructed `ShareableDependency`. This - // avoids constructing a temporary `ShareableDependency` and moving from it. - ABSL_ATTRIBUTE_REINITIALIZES void Reset(); - ABSL_ATTRIBUTE_REINITIALIZES void Reset( - riegeli::Initializer manager); - - // Creates a `DependencyShare` sharing a pointer from `*this`. - // - // As long as the `DependencyShare` is alive, `*this` will wait in its - // destructor, assignment, and `WaitUntilUnique()`. - // - // An empty `ShareableDependency` yields an empty `DependencyShare`. - DependencyShare Share() const; - - // Waits until `*this` is empty or unique. Returns a reference to a - // `StableDependency` storing the `Manager`. - riegeli::StableDependency& WaitUntilUnique(); - - // Returns `true` if there are no alive shares of `*this`. - bool IsUnique() const; - - private: - using Sharing = typename DependencyShare::Sharing; - - struct Deleter; - - riegeli::StableDependency dependency_; - std::unique_ptr sharing_; -}; - -//////////////////////////////////////////////////////////////////////////////// -// IMPLEMENTATION DETAILS -//////////////////////////////////////////////////////////////////////////////// - -template -struct DependencyShare::Sharing { - explicit Sharing(Handle handle) : handle(std::move(handle)) {} - - void Ref() const { ref_count.Ref(); } - void Unref() const { - // Notify the `ShareableDependency` if there are no more shares. - absl::MutexLock l(&mu); - if (ref_count.Unref()) { - DLOG(FATAL) - << "The last DependencyShare outlived the ShareableDependency"; - } - } - bool HasUniqueOwner() const { return ref_count.HasUniqueOwner(); } - void WaitUntilUnique() const { - absl::MutexLock l(&mu, absl::Condition(this, &Sharing::HasUniqueOwner)); - } - - Handle handle; - mutable absl::Mutex mu; - riegeli::RefCount ref_count; -}; - -template -DependencyShare::DependencyShare(Sharing* sharing) : sharing_(sharing) { - if (sharing_ != nullptr) sharing_->Ref(); -} - -template -struct ShareableDependency::Deleter { - void operator()(Sharing* sharing) const { - sharing->WaitUntilUnique(); - delete sharing; - } -}; - -template -void ShareableDependency::Reset() { - sharing_.reset(); - dependency_.Reset(); -} - -template -void ShareableDependency::Reset( - riegeli::Initializer manager) { - WaitUntilUnique().Reset(std::move(manager)); - if (sharing_ == nullptr) { - sharing_.reset(new Sharing(dependency_.get())); - } else { - sharing_->handle = dependency_.get(); - } -} - -template -DependencyShare ShareableDependency::Share() const { - return DependencyShare(sharing_.get()); -} - -template -riegeli::StableDependency& -ShareableDependency::WaitUntilUnique() { - if (sharing_ != nullptr) sharing_->WaitUntilUnique(); - return dependency_; -} - -template -bool ShareableDependency::IsUnique() const { - return sharing_ != nullptr && sharing_->HasUniqueOwner(); -} - -} // namespace array_record - -#endif // ARRAY_RECORD_CPP_SHAREABLE_DEPENDENCY_H_ diff --git a/cpp/shareable_dependency_test.cc b/cpp/shareable_dependency_test.cc deleted file mode 100644 index 2083dd7..0000000 --- a/cpp/shareable_dependency_test.cc +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright 2022 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "cpp/shareable_dependency.h" - -#include -#include - -#include "gtest/gtest.h" -#include "absl/synchronization/notification.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "cpp/common.h" -#include "cpp/thread_pool.h" -#include "riegeli/base/maker.h" - -namespace array_record { -namespace { - -class FooBase { - public: - virtual ~FooBase() = default; - virtual int value() const = 0; - virtual void add_value(int v) = 0; - virtual void mul_value(int v) = 0; -}; - -class Foo : public FooBase { - public: - explicit Foo(int v) : value_(v) {} - DECLARE_MOVE_ONLY_CLASS(Foo); - - int value() const override { return value_; }; - void add_value(int v) override { value_ += v; } - void mul_value(int v) override { value_ *= v; } - - private: - int value_; -}; - -class ShareableDependencyTest : public testing::Test { - public: - ShareableDependencyTest() : pool_(ArrayRecordGlobalPool()) {} - - protected: - ARThreadPool* pool_; -}; - -TEST_F(ShareableDependencyTest, SanityTest) { - ShareableDependency main(riegeli::Maker(1)); - EXPECT_TRUE(main.IsUnique()); - - auto new_main = std::move(main); - EXPECT_TRUE(new_main.IsUnique()); - // Not owning after move - EXPECT_FALSE(main.IsUnique()); // NOLINT(bugprone-use-after-move) - - main = std::move(new_main); - EXPECT_TRUE(main.IsUnique()); - // Not owning after move - EXPECT_FALSE(new_main.IsUnique()); // NOLINT(bugprone-use-after-move) - - absl::Notification notification; - pool_->Schedule([refobj = main.Share(), ¬ification] { - notification.WaitForNotification(); - absl::SleepFor(absl::Milliseconds(10)); - EXPECT_EQ(refobj->value(), 1); - const auto second_ref = refobj; - refobj->add_value(1); - }); - EXPECT_FALSE(main.IsUnique()); - notification.Notify(); - auto& unique = main.WaitUntilUnique(); - // Value is now 2 - unique->mul_value(3); - EXPECT_EQ(unique->value(), 6); - // Destruction blocks until thread is executed -} - -TEST_F(ShareableDependencyTest, SanityTestWithReset) { - ShareableDependency> main; - EXPECT_FALSE(main.IsUnique()); - - main.Reset(riegeli::Maker(1)); - EXPECT_TRUE(main.IsUnique()); - - absl::Notification notification; - pool_->Schedule([refobj = main.Share(), ¬ification] { - notification.WaitForNotification(); - absl::SleepFor(absl::Milliseconds(10)); - EXPECT_EQ(refobj->value(), 1); - const auto second_ref = refobj; - refobj->add_value(1); - }); - EXPECT_FALSE(main.IsUnique()); - notification.Notify(); - auto& unique = main.WaitUntilUnique(); - // Value is now 2 - unique->mul_value(3); - EXPECT_EQ(unique->value(), 6); - // Destruction blocks until thread is executed -} - -} // namespace -} // namespace array_record diff --git a/cpp/tri_state_ptr.h b/cpp/tri_state_ptr.h new file mode 100644 index 0000000..197e129 --- /dev/null +++ b/cpp/tri_state_ptr.h @@ -0,0 +1,191 @@ +/* Copyright 2024 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef ARRAY_RECORD_CPP_TRI_STATE_PTR_H_ +#define ARRAY_RECORD_CPP_TRI_STATE_PTR_H_ + +#include + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "cpp/common.h" + +namespace array_record { + +/** TriStatePtr is a wrapper around a pointer that allows for a unique and + * shared reference. + * + * There are three states: + * + * - NoRef: The object does not have shared or unique references. + * - Sharing: The object is shared. + * - Unique: The object is referenced by a unique pointer wrapper. + * + * The state transition from NoRef to Shared when MakeShared is called. + * An internal refernce count is incremented when a SharedRef is created. + * + * SharedRef ref = MakeShared(); -- + * NoRef ----------------------------> Sharing / | MakeShared() + * All SharedRef deallocated <-- + * <---------------------------- + * + * The state can also transition to Unique when WaitAndMakeUnique is called. + * We can only hold one unique reference at a time. + * + * UniqueRef ref = WaitAndMakeUnique(); + * NoRef ----------------------------> Unique + * The UniqueRef is deallocated + * <---------------------------- + * + * Other than the state transition above, state transitions methods would block + * until the specified state is possible. On deallocation, the destructor blocks + * until the state is NoRef. + * + * Example usage: + * + * TriStatePtr main(riegeli::Maker(...)); + * // Create a shared reference to work on other threads. + * pool->Schedule([refobj = foo_ptr.MakeShared()] { + * refobj->FooMethod(); + * }); + * + * // Blocks until refobj is out of scope. + * auto unique_ref = main.WaitAndMakeUnique(); + * unique_ref->CleanupFoo(); + * + */ +template +class TriStatePtr { + public: + DECLARE_IMMOBILE_CLASS(TriStatePtr); + TriStatePtr() = default; + + ~TriStatePtr() { + absl::MutexLock l(&mu_); + mu_.Await(absl::Condition( + +[](State* sharing_state) { return *sharing_state == State::kNoRef; }, + &state_)); + } + + // explicit TriStatePtr(std::unique_ptr ptr) : ptr_(std::move(ptr)) {} + explicit TriStatePtr(std::unique_ptr ptr) : ptr_(std::move(ptr)) {} + + class SharedRef { + public: + SharedRef(TriStatePtr* parent) : parent_(parent) {} + + SharedRef(const SharedRef& other) : parent_(other.parent_) { + parent_->ref_count_++; + } + SharedRef& operator=(const SharedRef& other) { + this->parent_ = other.parent_; + this->parent_->ref_count_++; + return *this; + } + + SharedRef(SharedRef&& other) : parent_(other.parent_) { + other.parent_ = nullptr; + } + SharedRef& operator=(SharedRef&& other) { + this->parent_ = other.parent_; + other.parent_ = nullptr; + return *this; + } + + ~SharedRef() { + if (parent_ == nullptr) { + return; + } + int32_t ref_count = + parent_->ref_count_.fetch_sub(1, std::memory_order_acq_rel) - 1; + if (ref_count == 0) { + absl::MutexLock l(&parent_->mu_); + parent_->state_ = State::kNoRef; + } + } + + const BaseT& operator*() const { return *parent_->ptr_.get(); } + const BaseT* operator->() const { return parent_->ptr_.get(); } + BaseT& operator*() { return *parent_->ptr_.get(); } + BaseT* operator->() { return parent_->ptr_.get(); } + + private: + TriStatePtr* parent_ = nullptr; + }; + + class UniqueRef { + public: + DECLARE_MOVE_ONLY_CLASS(UniqueRef); + UniqueRef(TriStatePtr* parent) : parent_(parent) {} + + ~UniqueRef() { + absl::MutexLock l(&parent_->mu_); + parent_->state_ = State::kNoRef; + } + + const BaseT& operator*() const { return *parent_->ptr_.get(); } + const BaseT* operator->() const { return parent_->ptr_.get(); } + BaseT& operator*() { return *parent_->ptr_.get(); } + BaseT* operator->() { return parent_->ptr_.get(); } + + private: + TriStatePtr* parent_; + }; + + SharedRef MakeShared() { + absl::MutexLock l(&mu_); + mu_.Await(absl::Condition( + +[](State* sharing_state) { return *sharing_state != State::kUnique; }, + &state_)); + state_ = State::kSharing; + ref_count_++; + return SharedRef(this); + } + + UniqueRef WaitAndMakeUnique() { + absl::MutexLock l(&mu_); + mu_.Await(absl::Condition( + +[](State* sharing_state) { return *sharing_state == State::kNoRef; }, + &state_)); + state_ = State::kUnique; + return UniqueRef(this); + } + + enum class State { + kNoRef = 0, + kSharing = 1, + kUnique = 2, + }; + + State state() const { + absl::MutexLock l(&mu_); + return state_; + } + + private: + mutable absl::Mutex mu_; + std::atomic_int32_t ref_count_ = 0; + State state_ ABSL_GUARDED_BY(mu_) = State::kNoRef; + std::unique_ptr ptr_; +}; + +} // namespace array_record + +#endif // ARRAY_RECORD_CPP_TRI_STATE_PTR_H_ diff --git a/cpp/tri_state_ptr_test.cc b/cpp/tri_state_ptr_test.cc new file mode 100644 index 0000000..b596150 --- /dev/null +++ b/cpp/tri_state_ptr_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2022 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cpp/tri_state_ptr.h" + +#include "gtest/gtest.h" +#include "absl/synchronization/notification.h" +#include "cpp/common.h" +#include "cpp/thread_pool.h" +#include "riegeli/base/maker.h" + +namespace array_record { +namespace { + +class FooBase { + public: + virtual ~FooBase() = default; + virtual int value() const = 0; + virtual void add_value(int v) = 0; + virtual void mul_value(int v) = 0; +}; + +class Foo : public FooBase { + public: + explicit Foo(int v) : value_(v) {} + DECLARE_MOVE_ONLY_CLASS(Foo); + + int value() const override { return value_; }; + void add_value(int v) override { value_ += v; } + void mul_value(int v) override { value_ *= v; } + + private: + int value_; +}; + +class TriStatePtrTest : public testing::Test { + public: + TriStatePtrTest() : pool_(ArrayRecordGlobalPool()) {} + + protected: + ARThreadPool* pool_; +}; + +TEST_F(TriStatePtrTest, SanityTest) { + TriStatePtr foo_main(riegeli::Maker(1).UniquePtr()); + EXPECT_EQ(foo_main.state(), TriStatePtr::State::kNoRef); + absl::Notification notification; + { + pool_->Schedule( + [foo_shared = foo_main.MakeShared(), ¬ification]() mutable { + notification.WaitForNotification(); + EXPECT_EQ(foo_shared->value(), 1); + const auto second_foo_shared = foo_shared; + foo_shared->add_value(1); + EXPECT_EQ(second_foo_shared->value(), 2); + }); + } + EXPECT_EQ(foo_main.state(), TriStatePtr::State::kSharing); + notification.Notify(); + auto foo_unique = foo_main.WaitAndMakeUnique(); + foo_unique->mul_value(3); + EXPECT_EQ(foo_unique->value(), 6); + EXPECT_EQ(foo_main.state(), TriStatePtr::State::kUnique); +} + +} // namespace +} // namespace array_record