Skip to content

Commit

Permalink
coroutine: experimental: generator: implement move and swap
Browse files Browse the repository at this point in the history
The coroutine generator move ctor, move assignment operator,
and swap override were never fully implemented.

This patch completes their implementation
and adds unit tests to cover move and swap
and also moving not-drained generator (to test
moving of their buffered value(s)).

Fixes scylladb#1789

Signed-off-by: Benny Halevy <[email protected]>
  • Loading branch information
bhalevy committed Jul 17, 2024
1 parent b8fc54d commit 1d7a876
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 8 deletions.
52 changes: 48 additions & 4 deletions include/seastar/coroutine/generator.hh
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ public:

generator_type get_return_object() noexcept;
void set_generator(generator_type* g) noexcept {
assert(!_generator);
_generator = g;
}

Expand Down Expand Up @@ -190,7 +189,6 @@ public:

auto get_return_object() noexcept -> generator_type;
void set_generator(generator_type* g) noexcept {
assert(!_generator);
_generator = g;
}

Expand Down Expand Up @@ -335,13 +333,27 @@ public:
generator(const generator&) = delete;
generator(generator&& other) noexcept
: _coro{std::exchange(other._coro, {})}
, _buffer_capacity{other._buffer_capacity} {}
, _promise(std::exchange(other._promise, nullptr))
, _values(std::move(other._values))
, _buffer_capacity{other._buffer_capacity}
, _exception(std::exchange(other._exception, nullptr)) {
if (_promise) {
_promise->set_generator(this);
}
}
generator& operator=(generator&& other) noexcept {
if (std::addressof(other) != this) {
auto old_coro = std::exchange(_coro, std::exchange(other._coro, {}));
if (old_coro) {
old_coro.destroy();
}
_promise = std::exchange(other._promise, nullptr);
if (_promise) {
_promise->set_generator(this);
}
_values = std::move(other._values);
const_cast<size_t&>(_buffer_capacity) = other._buffer_capacity;
_exception = std::exchange(other._exception, nullptr);
}
return *this;
}
Expand All @@ -353,6 +365,16 @@ public:

void swap(generator& other) noexcept {
std::swap(_coro, other._coro);
std::swap(_promise, other._promise);
if (_promise) {
_promise->set_generator(this);
}
if (other._promise) {
other._promise->set_generator(&other);
}
std::swap(_values, other._values);
std::swap(const_cast<size_t&>(_buffer_capacity), const_cast<size_t&>(other._buffer_capacity));
std::swap(_exception, other._exception);
}

internal::next_awaiter<T, generator> operator()() noexcept {
Expand Down Expand Up @@ -425,13 +447,26 @@ public:
}
generator(const generator&) = delete;
generator(generator&& other) noexcept
: _coro{std::exchange(other._coro, {})} {}
: _coro{std::exchange(other._coro, {})}
, _promise(std::exchange(other._promise, nullptr))
, _maybe_value(std::exchange(other._maybe_value, std::nullopt))
, _exception(std::exchange(other._exception, nullptr)) {
if (_promise) {
_promise->set_generator(this);
}
}
generator& operator=(generator&& other) noexcept {
if (std::addressof(other) != this) {
auto old_coro = std::exchange(_coro, std::exchange(other._coro, {}));
if (old_coro) {
old_coro.destroy();
}
_promise = std::exchange(other._promise, nullptr);
if (_promise) {
_promise->set_generator(this);
}
_maybe_value = std::exchange(other._maybe_value, std::nullopt);
_exception = std::exchange(other._exception, nullptr);
}
return *this;
}
Expand All @@ -443,6 +478,15 @@ public:

void swap(generator& other) noexcept {
std::swap(_coro, other._coro);
std::swap(_promise, other._promise);
if (_promise) {
_promise->set_generator(this);
}
if (other._promise) {
other._promise->set_generator(&other);
}
std::swap(_maybe_value, other._maybe_value);
std::swap(_exception, other._exception);
}

internal::next_awaiter<T, generator> operator()() noexcept {
Expand Down
94 changes: 90 additions & 4 deletions tests/unit/coroutines_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,14 @@ SEASTAR_TEST_CASE(test_as_future_preemption) {
BOOST_REQUIRE_THROW(f0.get(), std::runtime_error);
}

std::vector<int> gen_expected_fibs(unsigned count) {
std::vector<int> expected_fibs = {0, 1};
for (unsigned i = 2; i < count; ++i) {
expected_fibs.emplace_back(expected_fibs[i-2] + expected_fibs[i-1]);
}
return expected_fibs;
}

template<template<typename> class Container>
coroutine::experimental::generator<int, Container>
fibonacci_sequence(coroutine::experimental::buffer_size_t size, unsigned count) {
Expand All @@ -755,10 +763,8 @@ fibonacci_sequence(coroutine::experimental::buffer_size_t size, unsigned count)
}

template<template<typename> class Container>
seastar::future<> test_async_generator_drained() {
auto expected_fibs = {0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55};
auto fib = fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2},
std::size(expected_fibs));
seastar::future<> test_async_generator_drained(coroutine::experimental::generator<int, Container> fib, unsigned count) {
auto expected_fibs = gen_expected_fibs(count);
for (auto expected_fib : expected_fibs) {
auto actual_fib = co_await fib();
BOOST_REQUIRE(actual_fib.has_value());
Expand All @@ -768,6 +774,37 @@ seastar::future<> test_async_generator_drained() {
BOOST_REQUIRE(!sentinel.has_value());
}

template<template<typename> class Container>
seastar::future<> test_async_generator_drained() {
unsigned count = 11;
co_return co_await test_async_generator_drained(fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2}, count), count);
}

template<template<typename> class Container>
seastar::future<> test_move_async_generator_drained() {
unsigned count = 11;
auto fib0 = fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2}, count);
co_await test_async_generator_drained(std::move(fib0), count);
auto fib1 = fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2}, ++count);
fib0 = std::move(fib1);
co_await test_async_generator_drained(std::move(fib0), count);
fib0 = fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2}, ++count);
fib1 = fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2}, ++count);
fib0 = std::move(fib1);
co_await test_async_generator_drained(std::move(fib0), count);
}

template<template<typename> class Container>
seastar::future<> test_swap_async_generator_drained() {
unsigned count[2] = {11, 17};
auto fib0 = fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2}, count[0]);
auto fib1 = fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2}, count[1]);
std::swap(fib0, fib1);
std::swap(count[0], count[1]);
co_await test_async_generator_drained(std::move(fib0), count[0]);
co_await test_async_generator_drained(std::move(fib1), count[1]);
}

template<typename T>
using buffered_container = circular_buffer<T>;

Expand All @@ -779,6 +816,55 @@ SEASTAR_TEST_CASE(test_async_generator_drained_unbuffered) {
return test_async_generator_drained<std::optional>();
}

SEASTAR_TEST_CASE(test_move_async_generator_drained_buffered) {
return test_move_async_generator_drained<buffered_container>();
}

SEASTAR_TEST_CASE(test_move_async_generator_drained_unbuffered) {
return test_move_async_generator_drained<std::optional>();
}

SEASTAR_TEST_CASE(test_swap_async_generator_drained_buffered) {
return test_swap_async_generator_drained<buffered_container>();
}

SEASTAR_TEST_CASE(test_swap_async_generator_drained_unbuffered) {
return test_swap_async_generator_drained<std::optional>();
}

template<template<typename> class Container>
seastar::future<coroutine::experimental::generator<int, Container>> test_async_generator_drained_incrementally(coroutine::experimental::generator<int, Container> fib, std::optional<int> expected_value) {
auto actual_fib = co_await fib();
if (expected_value) {
BOOST_REQUIRE(actual_fib.has_value());
BOOST_REQUIRE_EQUAL(actual_fib.value(), *expected_value);
} else {
BOOST_REQUIRE(!actual_fib.has_value());
}
co_return fib;
}

template<template<typename> class Container>
seastar::future<> test_async_generator_drained_incrementally() {
unsigned count = 17;
auto expected_fibs = gen_expected_fibs(count);
auto fib = fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2}, count);
for (auto it = expected_fibs.begin(); it != expected_fibs.end(); ++it) {
fib = co_await test_async_generator_drained_incrementally(std::move(fib), *it);
}
fib = co_await test_async_generator_drained_incrementally(std::move(fib), std::nullopt);
// once drained generator return std::nullopt
fib = co_await test_async_generator_drained_incrementally(std::move(fib), std::nullopt);
}

SEASTAR_TEST_CASE(test_async_generator_drained_incrementally_buffered) {
return test_async_generator_drained_incrementally<buffered_container>();
}

SEASTAR_TEST_CASE(test_async_generator_drained_incrementally_unbuffered) {
return test_async_generator_drained_incrementally<std::optional>();
}

template<template<typename> class Container>
seastar::future<> test_async_generator_not_drained() {
auto fib = fibonacci_sequence<Container>(coroutine::experimental::buffer_size_t{2},
Expand Down

0 comments on commit 1d7a876

Please sign in to comment.