Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SharedFuture #183

Merged
merged 7 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bazelrc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --incompatible_java_common_parameters=false --define=android_dexmerger_tool=d8_dexmerger --define=android_incremental_dexing_tool=d8_dexbuilder --nouse_workers_with_dexbuilder
build --cxxopt=-std=c++17 --cxxopt=-fcoroutines-ts --host_cxxopt=-std=c++17 --host_cxxopt=-fcoroutines-ts --incompatible_java_common_parameters=false --define=android_dexmerger_tool=d8_dexmerger --define=android_incremental_dexing_tool=d8_dexbuilder --nouse_workers_with_dexbuilder
9 changes: 5 additions & 4 deletions support-lib/cpp/Future.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ class Future {
return true;
}

template<typename ConcretePromise>
struct PromiseTypeBase {
Promise<T> _promise;
std::optional<djinni::expected<T, std::exception_ptr>> _result{};
Expand All @@ -379,7 +378,9 @@ class Future {
constexpr bool await_ready() const noexcept {
return false;
}
bool await_suspend(detail::CoroutineHandle<ConcretePromise> finished) const noexcept {
template <typename P>
bool await_suspend(detail::CoroutineHandle<P> finished) const noexcept {
li-feng-sc marked this conversation as resolved.
Show resolved Hide resolved
static_assert(std::is_convertible_v<P*, PromiseTypeBase*>);
auto& promise_type = finished.promise();
if (*promise_type._result) {
if constexpr (std::is_void_v<T>) {
Expand All @@ -406,7 +407,7 @@ class Future {
}
};

struct PromiseType: PromiseTypeBase<PromiseType>{
struct PromiseType: PromiseTypeBase {
template <typename V, typename = std::enable_if_t<std::is_convertible_v<V, T>>>
void return_value(V&& value) {
this->_result.emplace(std::forward<V>(value));
Expand All @@ -424,7 +425,7 @@ class Future {

#if defined(DJINNI_FUTURE_HAS_COROUTINE_SUPPORT)
template<>
struct Future<void>::PromiseType : PromiseTypeBase<PromiseType> {
struct Future<void>::PromiseType : PromiseTypeBase {
void return_void() {
_result.emplace();
}
Expand Down
162 changes: 162 additions & 0 deletions support-lib/cpp/SharedFuture.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/**
* Copyright 2021 Snap, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "Future.hpp"

#if defined(DJINNI_FUTURE_HAS_COROUTINE_SUPPORT)

#include <memory>
#include <optional>
#include <type_traits>
#include <variant>
#include <vector>

namespace djinni {

// SharedFuture is a wrapper around djinni::Future to allow multiple consumers (i.e. like std::shared_future)
// The API is designed to be similar to djinni::Future.
template<typename T>
class SharedFuture {
public:
// Create SharedFuture from Future. Runtime error if the future is already consumed.
explicit SharedFuture(Future<T>&& future);

// Transform into Future<T>.
Future<T> toFuture() const {
if (await_ready()) {
co_return await_resume(); // return stored value directly
li-feng-sc marked this conversation as resolved.
Show resolved Hide resolved
}
co_return co_await SharedFuture(*this); // retain copy during coroutine suspension
}

void wait() const {
waitIgnoringExceptions().wait();
}

decltype(auto) get() const {
wait();
return await_resume();
}

template <typename Func>
using ResultT = std::remove_cv_t<std::remove_reference_t<std::invoke_result_t<Func, const SharedFuture<T>&>>>;

// Transform the result of this future into a new future. The behavior is same as Future::then except that
// it doesn't consume the future, and can be called multiple times.
template<typename Func>
Future<ResultT<Func>> then(Func transform) const {
auto cpy = SharedFuture(*this); // retain copy during coroutine suspension
co_await cpy.waitIgnoringExceptions();
co_return transform(cpy);
}

// Same as above but returns SharedFuture.
template<typename Func>
SharedFuture<ResultT<Func>> thenShared(Func transform) const {
return SharedFuture<ResultT<Func>>(then(std::move(transform)));
}

// -- coroutine support implementation only; not intended externally --

This comment was marked as resolved.


bool await_ready() const {
std::scoped_lock lock(_sharedStates->mutex);
return _sharedStates->storedValue.has_value();
}

decltype(auto) await_resume() const {
if (!*_sharedStates->storedValue) {
std::rethrow_exception(_sharedStates->storedValue->error());
}
if constexpr (!std::is_void_v<T>) {
return const_cast<const T &>(_sharedStates->storedValue->value());
}
}

bool await_suspend(detail::CoroutineHandle<> h) const;

struct Promise : public Future<T>::promise_type {
SharedFuture<T> get_return_object() noexcept {
return SharedFuture(Future<T>::promise_type::get_return_object());
}
};
using promise_type = Promise;

private:
Future<void> waitIgnoringExceptions() const {
try {
co_await *this;
} catch (...) {
// Ignore exceptions.
}
}

struct SharedStates {
std::recursive_mutex mutex;
std::optional<djinni::expected<T, std::exception_ptr>> storedValue = std::nullopt;
std::vector<detail::CoroutineHandle<>> coroutineHandles;
};
// Use a shared_ptr to allow copying SharedFuture.
std::shared_ptr<SharedStates> _sharedStates = std::make_shared<SharedStates>();
};

// CTAD deduction guide to construct from Future directly.
template<typename T>
SharedFuture(Future<T>&&) -> SharedFuture<T>;

// ------------------ Implementation ------------------

template<typename T>
SharedFuture<T>::SharedFuture(Future<T>&& future) {
// `future` will invoke all continuations when it is ready.
future.then([sharedStates = _sharedStates](auto futureResult) {
std::vector toCall = [&] {
std::scoped_lock lock(sharedStates->mutex);
try {
if constexpr (std::is_void_v<T>) {
futureResult.get();
sharedStates->storedValue.emplace();
} else {
sharedStates->storedValue = futureResult.get();
}
} catch (...) {
sharedStates->storedValue = make_unexpected(std::current_exception());
}
return std::move(sharedStates->coroutineHandles);
}();
for (auto& handle : toCall) {
handle();
}
});
}

template<typename T>
bool SharedFuture<T>::await_suspend(detail::CoroutineHandle<> h) const {
{
std::unique_lock lock(_sharedStates->mutex);
if (!_sharedStates->storedValue) {
_sharedStates->coroutineHandles.push_back(std::move(h));
return true;
}
}
h();
return true;
}

} // namespace djinni

#endif
1 change: 1 addition & 0 deletions test-suite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ objc_library(
copts = [
"-ObjC++",
"-std=c++17",
"-fcoroutines-ts"
],
srcs = glob([
"generated-src/objc/**/*.mm",
Expand Down
91 changes: 91 additions & 0 deletions test-suite/handwritten-src/objc/tests/DBSharedFutureTest.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#import <Foundation/Foundation.h>
#import <XCTest/XCTest.h>

#include "../../../support-lib/cpp/SharedFuture.hpp"

@interface DBSharedFutureTest : XCTestCase
@end

@implementation DBSharedFutureTest

#ifdef DJINNI_FUTURE_HAS_COROUTINE_SUPPORT

- (void)setUp
{
[super setUp];
}

- (void)tearDown
{
[super tearDown];
}

- (void)testCreateFuture
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LiFengSC These are translated from the existing cpp tests

{
djinni::SharedFuture<int> resolvedInt(djinni::Promise<int>::resolve(42));
XCTAssertEqual(resolvedInt.get(), 42);

djinni::Promise<NSString*> strPromise;
djinni::SharedFuture futureString(strPromise.getFuture());

strPromise.setValue(@"foo");
XCTAssertEqualObjects(futureString.get(), @"foo");
}

- (void)testThen
{
djinni::Promise<int> intPromise;
djinni::SharedFuture<int> futureInt(intPromise.getFuture());

auto transformedInt = futureInt.thenShared([](const auto& resolved) { return 2 * resolved.get(); });

intPromise.setValue(42);
XCTAssertEqual(transformedInt.get(), 84);

// Also verify multiple consumers and chaining.
auto transformedString = futureInt.thenShared([](const auto& resolved) { return std::to_string(resolved.get()); });
auto futurePlusOneTimesTwo = futureInt.then([](auto resolved) { return resolved.get() + 1; }).then([](auto resolved) {
return 2 * resolved.get();
});
auto futureStringLen = transformedString.then([](auto resolved) { return resolved.get().length(); });

XCTAssertEqual(transformedString.get(), std::string("42"));
XCTAssertEqual(futurePlusOneTimesTwo.get(), (42 + 1) * 2);
XCTAssertEqual(futureStringLen.get(), 2);

XCTAssertEqual(futureInt.get(), 42);

auto voidFuture = transformedString.thenShared([](auto) {});
voidFuture.wait();

auto intFuture2 = voidFuture.thenShared([](auto) { return 43; });
XCTAssertEqual(intFuture2.get(), 43);
}

- (void)testException
{
// Also verify exception handling.
djinni::Promise<int> intPromise;
djinni::SharedFuture<int> futureInt(intPromise.getFuture());

intPromise.setException(std::runtime_error("mocked"));

XCTAssertThrows(futureInt.get());

auto thenResult = futureInt.then([](auto resolved) { return resolved.get(); });
XCTAssertThrows(thenResult.get());

auto withExceptionHandling = futureInt.thenShared([](const auto& resolved) {
try {
return resolved.get();
} catch (...) {
return -1;
}
});
withExceptionHandling.wait();
XCTAssertEqual(withExceptionHandling.get(), -1);
}

#endif

@end
Loading