Skip to content

Commit

Permalink
Factor out a method "GetMatchedStream".
Browse files Browse the repository at this point in the history
This simplifies the logic, but on the first call, this slightly slows down the reading (because if we want to read, we first call size to provoke the matching, which reads at least some part of the ciphertext. afterwards we read ciphertext from the start again).

This also fixes the bug that calling size() would fail if PRead isn't called before this.

PiperOrigin-RevId: 584556418
Change-Id: I47957a63c45b4e9a293b6e5ec70abf462bded9b9
  • Loading branch information
tholenst authored and copybara-github committed Nov 22, 2023
1 parent fed9289 commit de6dcf9
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 41 deletions.
3 changes: 3 additions & 0 deletions tink/streamingaead/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ cc_library(
"//tink/util:statusor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
],
)
Expand Down Expand Up @@ -411,8 +412,10 @@ cc_test(
"//proto:tink_cc_proto",
"//tink/subtle:random",
"//tink/subtle:test_util",
"//tink/util:buffer",
"//tink/util:ostream_output_stream",
"//tink/util:status",
"//tink/util:statusor",
"//tink/util:test_matchers",
"//tink/util:test_util",
"@com_google_absl//absl/memory",
Expand Down
3 changes: 3 additions & 0 deletions tink/streamingaead/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ tink_cc_library(
tink::streamingaead::shared_random_access_stream
absl::memory
absl::status
absl::string_view
absl::synchronization
tink::core::primitive_set
tink::core::random_access_stream
Expand Down Expand Up @@ -392,8 +393,10 @@ tink_cc_test(
tink::internal::test_random_access_stream
tink::subtle::random
tink::subtle::test_util
tink::util::buffer
tink::util::ostream_output_stream
tink::util::status
tink::util::statusor
tink::util::test_matchers
tink::util::test_util
tink::proto::tink_cc_proto
Expand Down
94 changes: 55 additions & 39 deletions tink/streamingaead/decrypting_random_access_stream.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019 Google Inc.
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,12 +16,14 @@

#include "tink/streamingaead/decrypting_random_access_stream.h"

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tink/primitive_set.h"
#include "tink/random_access_stream.h"
Expand Down Expand Up @@ -61,30 +63,38 @@ StatusOr<std::unique_ptr<RandomAccessStream>> DecryptingRandomAccessStream::New(
}

util::Status DecryptingRandomAccessStream::PRead(
int64_t position, int count,
crypto::tink::util::Buffer* dest_buffer) {
{ // "fast-track": quickly proceed if matching has been attempted/found.
if (dest_buffer == nullptr) {
return util::Status(absl::StatusCode::kInvalidArgument,
"dest_buffer must be non-null");
}
if (count < 0) {
return util::Status(absl::StatusCode::kInvalidArgument,
"count cannot be negative");
}
if (count > dest_buffer->allocated_size()) {
return util::Status(absl::StatusCode::kInvalidArgument,
"buffer too small");
}
if (position < 0) {
return util::Status(absl::StatusCode::kInvalidArgument,
"position cannot be negative");
}
int64_t position, int count, crypto::tink::util::Buffer* dest_buffer) {
if (dest_buffer == nullptr) {
return util::Status(absl::StatusCode::kInvalidArgument,
"dest_buffer must be non-null");
}
if (count < 0) {
return util::Status(absl::StatusCode::kInvalidArgument,
"count cannot be negative");
}
if (count > dest_buffer->allocated_size()) {
return util::Status(absl::StatusCode::kInvalidArgument, "buffer too small");
}
if (position < 0) {
return util::Status(absl::StatusCode::kInvalidArgument,
"position cannot be negative");
}
crypto::tink::util::StatusOr<crypto::tink::RandomAccessStream*>
matched_stream = GetMatchedStream();
if (!matched_stream.ok()) {
return matched_stream.status();
}
return (*matched_stream)->PRead(position, count, dest_buffer);
}

crypto::tink::util::StatusOr<crypto::tink::RandomAccessStream*>
DecryptingRandomAccessStream::GetMatchedStream() const {
{
absl::ReaderMutexLock lock(&matching_mutex_);
if (matching_stream_ != nullptr) {
return matching_stream_->PRead(position, count, dest_buffer);
}
if (attempted_matching_) {
if (matching_stream_ != nullptr) {
return matching_stream_.get();
}
return Status(absl::StatusCode::kInvalidArgument,
"Did not find a decrypter matching the ciphertext stream.");
}
Expand All @@ -93,29 +103,35 @@ util::Status DecryptingRandomAccessStream::PRead(
absl::MutexLock lock(&matching_mutex_);

// Re-check that matching hasn't been attempted in the meantime.
if (matching_stream_ != nullptr) {
return matching_stream_->PRead(position, count, dest_buffer);
}
if (attempted_matching_) {
if (matching_stream_ != nullptr) {
return matching_stream_.get();
}
return Status(absl::StatusCode::kInvalidArgument,
"Did not find a decrypter matching the ciphertext stream.");
}

attempted_matching_ = true;
std::vector<StreamingAeadEntry*> all_primitives = primitives_->get_all();
util::StatusOr<std::unique_ptr<crypto::tink::util::Buffer>> buffer =
crypto::tink::util::Buffer::New(1);
if (!buffer.ok()) {
return buffer.status();
}
for (const StreamingAeadEntry* entry : all_primitives) {
StreamingAead& streaming_aead = entry->get_primitive();
auto shared_ct = absl::make_unique<SharedRandomAccessStream>(
ciphertext_source_.get());
auto shared_ct =
absl::make_unique<SharedRandomAccessStream>(ciphertext_source_.get());
auto decrypting_stream_result =
streaming_aead.NewDecryptingRandomAccessStream(
std::move(shared_ct), associated_data_);
streaming_aead.NewDecryptingRandomAccessStream(std::move(shared_ct),
associated_data_);
if (decrypting_stream_result.ok()) {
auto status =
decrypting_stream_result.value()->PRead(position, count, dest_buffer);
if (status.ok() || status.code() == absl::StatusCode::kOutOfRange) {
Status read_result =
decrypting_stream_result.value()->PRead(0, 1, buffer->get());
if (read_result.ok() || absl::IsOutOfRange(read_result)) {
// Found a match.
matching_stream_ = std::move(decrypting_stream_result.value());
return status;
return matching_stream_.get();
}
}
// Not a match, try the next primitive.
Expand All @@ -125,12 +141,12 @@ util::Status DecryptingRandomAccessStream::PRead(
}

StatusOr<int64_t> DecryptingRandomAccessStream::size() {
absl::ReaderMutexLock lock(&matching_mutex_);
if (matching_stream_ != nullptr) {
return matching_stream_->size();
crypto::tink::util::StatusOr<crypto::tink::RandomAccessStream*>
matched_stream = GetMatchedStream();
if (!matched_stream.ok()) {
return matched_stream.status();
}
// TODO(b/139722894): attempt matching here?
return Status(absl::StatusCode::kUnavailable, "no matching found yet");
return (*matched_stream)->size();
}

} // namespace streamingaead
Expand Down
8 changes: 6 additions & 2 deletions tink/streamingaead/decrypting_random_access_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,17 @@ class DecryptingRandomAccessStream : public crypto::tink::RandomAccessStream {
associated_data_(associated_data),
attempted_matching_(false),
matching_stream_(nullptr) {}

crypto::tink::util::StatusOr<crypto::tink::RandomAccessStream*>
GetMatchedStream() const;

std::shared_ptr<
crypto::tink::PrimitiveSet<crypto::tink::StreamingAead>> primitives_;
std::unique_ptr<crypto::tink::RandomAccessStream> ciphertext_source_;
std::string associated_data_;
mutable absl::Mutex matching_mutex_;
bool attempted_matching_ ABSL_GUARDED_BY(matching_mutex_);
std::unique_ptr<crypto::tink::RandomAccessStream> matching_stream_
mutable bool attempted_matching_ ABSL_GUARDED_BY(matching_mutex_);
mutable std::unique_ptr<crypto::tink::RandomAccessStream> matching_stream_
ABSL_GUARDED_BY(matching_mutex_);
};

Expand Down
24 changes: 24 additions & 0 deletions tink/streamingaead/decrypting_random_access_stream_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@
#include "tink/streaming_aead.h"
#include "tink/subtle/random.h"
#include "tink/subtle/test_util.h"
#include "tink/util/buffer.h"
#include "tink/util/ostream_output_stream.h"
#include "tink/util/status.h"
#include "tink/util/statusor.h"
#include "tink/util/test_matchers.h"
#include "tink/util/test_util.h"
#include "proto/tink.pb.h"
Expand All @@ -48,6 +50,7 @@ namespace {

using crypto::tink::test::DummyStreamingAead;
using crypto::tink::test::IsOk;
using crypto::tink::test::IsOkAndHolds;
using crypto::tink::test::StatusIs;
using google::crypto::tink::KeysetInfo;
using google::crypto::tink::KeyStatusType;
Expand Down Expand Up @@ -356,6 +359,27 @@ TEST(DecryptingRandomAccessStreamTest, NullCiphertextSource) {
HasSubstr("ciphertext_source must be non-null")));
}

TEST(DecryptingRandomAccessStreamTest, CallSizeBeforePReadWorks) {
uint32_t key_id = 1234543;
std::string saead_name = "streaming_aead";
auto saead_set = GetTestStreamingAeadSet({{key_id, saead_name}});

std::string associated_data = "associated_data";

for (int pt_size : {0, 1, 100}) {
std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
std::unique_ptr<RandomAccessStream> ciphertext =
GetCiphertextSource(&(saead_set->get_primary()->get_primitive()),
plaintext, associated_data);

util::StatusOr<std::unique_ptr<RandomAccessStream>> dec_stream =
DecryptingRandomAccessStream::New(saead_set, std::move(ciphertext),
associated_data);
ASSERT_THAT(dec_stream, IsOk());
EXPECT_THAT((*dec_stream)->size(), IsOkAndHolds(pt_size));
}
}

} // namespace
} // namespace streamingaead
} // namespace tink
Expand Down

0 comments on commit de6dcf9

Please sign in to comment.