Skip to content

Commit

Permalink
Riegeli usage cleanup:
Browse files Browse the repository at this point in the history
Use `riegeli::SharedPtr` instead of `std::shared_ptr`. Do not convert it from
`std::unique_ptr`: `riegeli::SharedPtr` always allocates the control block
together with the object, hence this is not supported, and this is more
efficient anyway.

Utilize CTAD more.

Let `MaskedReader` constructor take `riegeli::Reader&` instead of
`std::unique_ptr<riegeli::Reader>`. There is no ownership transfer, the
`riegeli::Reader` is accessed only during construction.

Use `riegeli::SharedBuffer` instead of `std::shared_ptr<std::string>` in
`MaskedReader`. It has fewer indirections. It does not track the exact size
but this is not needed here.

Use defaulted move constructor and move assignment of `MaskedReader`. They do
the right thing.

Add `MaskedReader::Reset()` to reset the instance to a state equivalent to a
newly constructed state in-place.

Use `std::optional` instead of `absl::optional`.

Use `std::optional` or storing the value directly instead of `std::unique_ptr`
when movability is not needed.

PiperOrigin-RevId: 694608145
  • Loading branch information
QrczakMK authored and copybara-github committed Nov 9, 2024
1 parent 4ea151d commit 6e7c29b
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 142 deletions.
8 changes: 4 additions & 4 deletions cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ cc_library(
":common",
":thread_pool",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
"@com_google_absl//absl/synchronization",
],
Expand Down Expand Up @@ -131,6 +130,7 @@ cc_library(
"@com_google_riegeli//riegeli/base:initializer",
"@com_google_riegeli//riegeli/base:object",
"@com_google_riegeli//riegeli/base:options_parser",
"@com_google_riegeli//riegeli/base:shared_ptr",
"@com_google_riegeli//riegeli/base:status",
"@com_google_riegeli//riegeli/bytes:chain_writer",
"@com_google_riegeli//riegeli/bytes:writer",
Expand All @@ -152,10 +152,8 @@ cc_library(
deps = [
":common",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
"@com_google_riegeli//riegeli/base:object",
"@com_google_riegeli//riegeli/base:shared_buffer",
"@com_google_riegeli//riegeli/base:status",
"@com_google_riegeli//riegeli/base:types",
"@com_google_riegeli//riegeli/bytes:reader",
Expand Down Expand Up @@ -185,6 +183,7 @@ cc_library(
"@com_google_riegeli//riegeli/base:initializer",
"@com_google_riegeli//riegeli/base:object",
"@com_google_riegeli//riegeli/base:options_parser",
"@com_google_riegeli//riegeli/base:shared_ptr",
"@com_google_riegeli//riegeli/base:status",
"@com_google_riegeli//riegeli/bytes:reader",
"@com_google_riegeli//riegeli/chunk_encoding:chunk",
Expand All @@ -207,6 +206,7 @@ cc_test(
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@com_google_riegeli//riegeli/base:initializer",
"@com_google_riegeli//riegeli/base:shared_ptr",
"@com_google_riegeli//riegeli/bytes:chain_writer",
"@com_google_riegeli//riegeli/bytes:cord_writer",
"@com_google_riegeli//riegeli/bytes:string_reader",
Expand Down
34 changes: 18 additions & 16 deletions cpp/array_record_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ limitations under the License.
#include "cpp/parallel_for.h"
#include "cpp/thread_pool.h"
#include "third_party/protobuf/message_lite.h"
#include "riegeli/base/maker.h"
#include "riegeli/base/object.h"
#include "riegeli/base/options_parser.h"
#include "riegeli/base/shared_ptr.h"
#include "riegeli/base/status.h"
#include "riegeli/bytes/reader.h"
#include "riegeli/chunk_encoding/chunk.h"
Expand Down Expand Up @@ -168,12 +170,12 @@ ChunkDecoder ReadChunk(Reader& reader, size_t pos, size_t len) {
decoder.Fail(reader.status());
return decoder;
}
MaskedReader masked_reader(reader.NewReader(pos), len);
MaskedReader masked_reader(*reader.NewReader(pos), len);
if (!masked_reader.ok()) {
decoder.Fail(masked_reader.status());
return decoder;
}
auto chunk_reader = riegeli::DefaultChunkReader<>(&masked_reader);
riegeli::DefaultChunkReader chunk_reader(&masked_reader);
Chunk chunk;
if (!chunk_reader.ReadChunk(chunk)) {
decoder.Fail(chunk_reader.status());
Expand Down Expand Up @@ -343,15 +345,15 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecords(
MaskedReader masked_reader(riegeli::kClosed);
{
AR_ENDO_SCOPE("MaskedReader");
masked_reader = MaskedReader(
reader->NewReader(state_->chunk_offsets[chunk_idx_start]),
masked_reader.Reset(
*reader->NewReader(state_->chunk_offsets[chunk_idx_start]),
buf_len);
}
for (uint64_t chunk_idx = chunk_idx_start; chunk_idx <= last_chunk_idx;
++chunk_idx) {
AR_ENDO_SCOPE("ChunkReader+ChunkDecoder");
masked_reader.Seek(state_->chunk_offsets[chunk_idx]);
riegeli::DefaultChunkReader<> chunk_reader(&masked_reader);
riegeli::DefaultChunkReader chunk_reader(&masked_reader);
Chunk chunk;
if (ABSL_PREDICT_FALSE(!chunk_reader.ReadChunk(chunk))) {
return chunk_reader.status();
Expand Down Expand Up @@ -420,15 +422,15 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsInRange(
MaskedReader masked_reader(riegeli::kClosed);
{
AR_ENDO_SCOPE("MaskedReader");
masked_reader = MaskedReader(
reader->NewReader(state_->chunk_offsets[chunk_idx_start]),
masked_reader.Reset(
*reader->NewReader(state_->chunk_offsets[chunk_idx_start]),
buf_len);
}
for (uint64_t chunk_idx = chunk_idx_start; chunk_idx <= last_chunk_idx;
++chunk_idx) {
AR_ENDO_SCOPE("ChunkReader+ChunkDecoder");
masked_reader.Seek(state_->chunk_offsets[chunk_idx]);
riegeli::DefaultChunkReader<> chunk_reader(&masked_reader);
riegeli::DefaultChunkReader chunk_reader(&masked_reader);
Chunk chunk;
if (ABSL_PREDICT_FALSE(!chunk_reader.ReadChunk(chunk))) {
return chunk_reader.status();
Expand Down Expand Up @@ -539,14 +541,14 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsWithIndices(
MaskedReader masked_reader(riegeli::kClosed);
{
AR_ENDO_SCOPE("MaskedReader");
masked_reader = MaskedReader(
reader->NewReader(state_->chunk_offsets[buffer_chunks[0]]),
masked_reader.Reset(
*reader->NewReader(state_->chunk_offsets[buffer_chunks[0]]),
buf_len);
}
for (auto chunk_idx : buffer_chunks) {
AR_ENDO_SCOPE("ChunkReader+ChunkDecoder");
masked_reader.Seek(state_->chunk_offsets[chunk_idx]);
riegeli::DefaultChunkReader<> chunk_reader(&masked_reader);
riegeli::DefaultChunkReader chunk_reader(&masked_reader);
Chunk chunk;
if (ABSL_PREDICT_FALSE(!chunk_reader.ReadChunk(chunk))) {
return chunk_reader.status();
Expand Down Expand Up @@ -686,8 +688,8 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) {
// movable, OSS ThreadPool only takes std::function which requires all the
// captures to be copyable. Therefore we must wrap the promise in a
// shared_ptr to copy it over to the scheduled task.
auto decoder_promise =
std::make_shared<std::promise<std::vector<ChunkDecoder>>>();
riegeli::SharedPtr decoder_promise(
riegeli::Maker<std::promise<std::vector<ChunkDecoder>>>());
state_->future_decoders.push(
{buffer_to_add, decoder_promise->get_future()});
const auto reader = get_backing_reader();
Expand Down Expand Up @@ -719,8 +721,8 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) {
MaskedReader masked_reader(riegeli::kClosed);
{
AR_ENDO_SCOPE("MaskedReader");
masked_reader =
MaskedReader(reader->NewReader(chunk_offsets.front()), buffer_len);
masked_reader.Reset(*reader->NewReader(chunk_offsets.front()),
buffer_len);
}
if (!masked_reader.ok()) {
for (auto& decoder : decoders) {
Expand All @@ -733,7 +735,7 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) {
AR_ENDO_SCOPE("ChunkReader+ChunkDecoder");
for (auto local_chunk_idx : IndicesOf(chunk_offsets)) {
masked_reader.Seek(chunk_offsets[local_chunk_idx]);
auto chunk_reader = riegeli::DefaultChunkReader<>(&masked_reader);
riegeli::DefaultChunkReader chunk_reader(&masked_reader);
Chunk chunk;
if (!chunk_reader.ReadChunk(chunk)) {
decoders[local_chunk_idx].Fail(chunk_reader.status());
Expand Down
41 changes: 23 additions & 18 deletions cpp/array_record_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ limitations under the License.
#include "cpp/sequenced_chunk_writer.h"
#include "cpp/thread_pool.h"
#include "third_party/protobuf/message_lite.h"
#include "riegeli/base/maker.h"
#include "riegeli/base/object.h"
#include "riegeli/base/options_parser.h"
#include "riegeli/base/shared_ptr.h"
#include "riegeli/base/status.h"
#include "riegeli/bytes/chain_writer.h"
#include "riegeli/chunk_encoding/chunk.h"
Expand All @@ -66,7 +68,7 @@ constexpr uint32_t kZstdDefaultWindowLog = 20;
// Generated from `echo 'ArrayRecord' | md5sum | cut -b 1-16`
constexpr uint64_t kMagic = 0x71930e704fdae05eULL;

// zstd:3 gives a good trade-off for both the compression and decomopression
// zstd:3 gives a good trade-off for both the compression and decompression
// speed.
constexpr char kArrayRecordDefaultCompression[] = "zstd:3";

Expand Down Expand Up @@ -169,7 +171,7 @@ ArrayRecordWriterBase::Options::FromString(absl::string_view text) {
return options_parser.status();
}
// From our benchmarks we figured zstd:3 gives the best trade-off for both the
// compression and decomopression speed.
// compression and decompression speed.
if (text == "default" ||
(!absl::StrContains(compressor_text, "uncompressed") &&
!absl::StrContains(compressor_text, "brotli") &&
Expand Down Expand Up @@ -388,23 +390,28 @@ void ArrayRecordWriterBase::Done() {
submit_chunk_callback_->WriteFooterAndPostscript(writer.get());
}

std::unique_ptr<riegeli::ChunkEncoder> ArrayRecordWriterBase::CreateEncoder() {
std::unique_ptr<riegeli::ChunkEncoder> encoder;
riegeli::SharedPtr<riegeli::ChunkEncoder>
ArrayRecordWriterBase::CreateEncoder() {
auto wrap_encoder =
[this](auto encoder) -> riegeli::SharedPtr<riegeli::ChunkEncoder> {
if (pool_) {
return riegeli::SharedPtr(
riegeli::Maker<riegeli::DeferredEncoder>(std::move(encoder)));
} else {
return riegeli::SharedPtr(std::move(encoder));
}
};
if (options_.transpose()) {
encoder = std::make_unique<riegeli::TransposeEncoder>(
return wrap_encoder(riegeli::Maker<riegeli::TransposeEncoder>(
options_.compressor_options(),
riegeli::TransposeEncoder::TuningOptions().set_bucket_size(
options_.transpose_bucket_size()));
options_.transpose_bucket_size())));
} else {
encoder = std::make_unique<riegeli::SimpleEncoder>(
return wrap_encoder(riegeli::Maker<riegeli::SimpleEncoder>(
options_.compressor_options(),
riegeli::SimpleEncoder::TuningOptions().set_size_hint(
submit_chunk_callback_->get_last_decoded_data_size()));
}
if (pool_) {
return std::make_unique<riegeli::DeferredEncoder>(std::move(encoder));
submit_chunk_callback_->get_last_decoded_data_size())));
}
return encoder;
}

bool ArrayRecordWriterBase::WriteRecord(const google::protobuf::MessageLite& record) {
Expand Down Expand Up @@ -432,20 +439,18 @@ bool ArrayRecordWriterBase::WriteRecordImpl(Record&& record) {
if (chunk_encoder_->num_records() >= options_.group_size()) {
auto writer = get_writer();
auto encoder = std::move(chunk_encoder_);
auto chunk_promise =
std::make_shared<std::promise<absl::StatusOr<Chunk>>>();
riegeli::SharedPtr chunk_promise(
riegeli::Maker<std::promise<absl::StatusOr<Chunk>>>());
if (!writer->CommitFutureChunk(chunk_promise->get_future())) {
Fail(writer->status());
return false;
}
chunk_encoder_ = CreateEncoder();
if (pool_ && options_.max_parallelism().value() > 1) {
std::shared_ptr<riegeli::ChunkEncoder> shared_encoder =
std::move(encoder);
submit_chunk_callback_->TrackConcurrentChunkWriters();
pool_->Schedule([writer, shared_encoder, chunk_promise]() mutable {
pool_->Schedule([writer, encoder, chunk_promise]() mutable {
AR_ENDO_TASK("Encode riegeli chunk");
chunk_promise->set_value(EncodeChunk(shared_encoder.get()));
chunk_promise->set_value(EncodeChunk(encoder.get()));
writer->SubmitFutureChunks(false);
});
return true;
Expand Down
5 changes: 3 additions & 2 deletions cpp/array_record_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ limitations under the License.
#include "cpp/thread_pool.h"
#include "riegeli/base/initializer.h"
#include "riegeli/base/object.h"
#include "riegeli/base/shared_ptr.h"
#include "riegeli/bytes/writer.h"
#include "riegeli/chunk_encoding/chunk_encoder.h"
#include "riegeli/chunk_encoding/compressor_options.h"
Expand Down Expand Up @@ -326,7 +327,7 @@ class ArrayRecordWriterBase : public riegeli::Object {
void Done() override;

private:
std::unique_ptr<riegeli::ChunkEncoder> CreateEncoder();
riegeli::SharedPtr<riegeli::ChunkEncoder> CreateEncoder();
template <typename Record>
bool WriteRecordImpl(Record&& record);

Expand All @@ -336,7 +337,7 @@ class ArrayRecordWriterBase : public riegeli::Object {

Options options_;
ARThreadPool* pool_;
std::unique_ptr<riegeli::ChunkEncoder> chunk_encoder_;
riegeli::SharedPtr<riegeli::ChunkEncoder> chunk_encoder_;
std::unique_ptr<SubmitChunkCallback> submit_chunk_callback_;
};

Expand Down
Loading

0 comments on commit 6e7c29b

Please sign in to comment.