Skip to content

Commit

Permalink
localize all uses of TTG_CXX_COROUTINE_NAMESPACE into util/coroutine.h
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Feb 26, 2024
1 parent 0c65e11 commit a4fbc20
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 45 deletions.
33 changes: 17 additions & 16 deletions ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ namespace ttg::device {
constexpr bool await_ready() const noexcept { return false; }

/* always suspend */
constexpr void await_suspend( TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle<> ) const noexcept {}
template <typename Promise>
constexpr void await_suspend( ttg::coroutine_handle<Promise> ) const noexcept {}

void await_resume() noexcept {
if constexpr (sizeof...(Ts) > 0) {
Expand Down Expand Up @@ -86,7 +87,7 @@ namespace ttg::device {
namespace detail {
struct send_coro_promise_type;

using send_coro_handle_type = TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle<send_coro_promise_type>;
using send_coro_handle_type = ttg::coroutine_handle<send_coro_promise_type>;

/// a coroutine for sending data from the device
struct send_coro_state : public send_coro_handle_type {
Expand Down Expand Up @@ -118,21 +119,21 @@ namespace ttg::device {
/* do not suspend the coroutine on first invocation, we want to run
* the coroutine immediately and suspend only once.
*/
TTG_CXX_COROUTINE_NAMESPACE::suspend_never initial_suspend() {
ttg::suspend_never initial_suspend() {
return {};
}

/* we don't suspend the coroutine at the end.
* it can be destroyed once the send/broadcast is done
*/
TTG_CXX_COROUTINE_NAMESPACE::suspend_never final_suspend() noexcept {
ttg::suspend_never final_suspend() noexcept {
return {};
}

send_coro_state get_return_object() { return send_coro_state{send_coro_handle_type::from_promise(*this)}; }

/* the send coros only have an empty co_await */
TTG_CXX_COROUTINE_NAMESPACE::suspend_always await_transform(ttg::Void) {
ttg::suspend_always await_transform(ttg::Void) {
return {};
}

Expand Down Expand Up @@ -413,7 +414,7 @@ namespace ttg::device {
// fwd-decl
struct device_task_promise_type;
// base type for ttg::device::Task
using device_task_handle_type = TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle<device_task_promise_type>;
using device_task_handle_type = ttg::coroutine_handle<device_task_promise_type>;
} // namespace detail

/// A device::Task is a coroutine (a callable that can be suspended and resumed).
Expand Down Expand Up @@ -458,7 +459,7 @@ namespace ttg::device {
/* do not suspend the coroutine on first invocation, we want to run
* the coroutine immediately and suspend when we get the device transfers.
*/
TTG_CXX_COROUTINE_NAMESPACE::suspend_never initial_suspend() {
ttg::suspend_never initial_suspend() {
m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
return {};
}
Expand All @@ -467,19 +468,19 @@ namespace ttg::device {
* so we can access the promise.
* TODO: necessary? maybe we can save one suspend here
*/
TTG_CXX_COROUTINE_NAMESPACE::suspend_always final_suspend() noexcept {
ttg::suspend_always final_suspend() noexcept {
m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
return {};
}

/* Allow co_await on a tuple */
template<typename... Views>
TTG_CXX_COROUTINE_NAMESPACE::suspend_always await_transform(std::tuple<Views&...> &views) {
ttg::suspend_always await_transform(std::tuple<Views&...> &views) {
return yield_value(views);
}

template<typename... Ts>
TTG_CXX_COROUTINE_NAMESPACE::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
/* TODO: are we allowed to not suspend here and launch the kernel directly? */
m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
Expand All @@ -496,13 +497,13 @@ namespace ttg::device {
return a;
}

TTG_CXX_COROUTINE_NAMESPACE::suspend_always await_transform(std::vector<device::detail::send_t>&& v) {
ttg::suspend_always await_transform(std::vector<device::detail::send_t>&& v) {
m_sends = std::forward<std::vector<device::detail::send_t>>(v);
m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
return {};
}

TTG_CXX_COROUTINE_NAMESPACE::suspend_always await_transform(device::detail::send_t&& v) {
ttg::suspend_always await_transform(device::detail::send_t&& v) {
m_sends.clear();
m_sends.push_back(std::forward<device::detail::send_t>(v));
m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
Expand Down Expand Up @@ -560,7 +561,7 @@ namespace ttg::device {

struct device_reducer_promise_type;

using device_reducer_handle_type = TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle<device_reducer_promise_type>;
using device_reducer_handle_type = ttg::coroutine_handle<device_reducer_promise_type>;

/// task that can be resumed after some events occur
struct device_reducer : public device_reducer_handle_type {
Expand Down Expand Up @@ -596,7 +597,7 @@ namespace ttg::device {
/* do not suspend the coroutine on first invocation, we want to run
* the coroutine immediately and suspend when we get the device transfers.
*/
TTG_CXX_COROUTINE_NAMESPACE::suspend_never initial_suspend() {
ttg::suspend_never initial_suspend() {
m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
return {};
}
Expand All @@ -605,13 +606,13 @@ namespace ttg::device {
* so we can access the promise.
* TODO: necessary? maybe we can save one suspend here
*/
TTG_CXX_COROUTINE_NAMESPACE::suspend_always final_suspend() noexcept {
ttg::suspend_always final_suspend() noexcept {
m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
return {};
}

template<typename... Ts>
TTG_CXX_COROUTINE_NAMESPACE::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
/* TODO: are we allowed to not suspend here and launch the kernel directly? */
m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
Expand Down
8 changes: 4 additions & 4 deletions ttg/ttg/madness/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ namespace ttg_madness {
ttg::abort();
} else { // resume suspended coroutine
#ifdef TTG_HAS_COROUTINE
auto ret = static_cast<ttg::resumable_task>(ttg::coroutine_handle<>::from_address(suspended_task_address));
auto ret = static_cast<ttg::resumable_task>(ttg::coroutine_handle<ttg::resumable_task_state>::from_address(suspended_task_address));
assert(ret.ready());
ret.resume();
if (ret.completed()) {
Expand Down Expand Up @@ -388,14 +388,14 @@ namespace ttg_madness {
// so mark the events finished manually, parsec will rerun this task again and it should complete the second
// time
auto events =
static_cast<ttg::resumable_task>(ttg::coroutine_handle<>::from_address(suspended_task_address)).events();
static_cast<ttg::resumable_task>(ttg::coroutine_handle<ttg::resumable_task_state>::from_address(suspended_task_address)).events();
for (auto &event_ptr : events) {
event_ptr->finish();
}
assert(ttg::coroutine_handle<>::from_address(suspended_task_address).promise().ready());
assert(ttg::coroutine_handle<ttg::resumable_task_state>::from_address(suspended_task_address).promise().ready());

// resume the coroutine
auto ret = static_cast<ttg::resumable_task>(ttg::coroutine_handle<>::from_address(suspended_task_address));
auto ret = static_cast<ttg::resumable_task>(ttg::coroutine_handle<ttg::resumable_task_state>::from_address(suspended_task_address));
assert(ret.ready());
ret.resume();
if (ret.completed()) {
Expand Down
4 changes: 2 additions & 2 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class CallableWrapTTArgs
using op_return_type =
#ifdef TTG_HAS_COROUTINE
std::conditional_t<std::is_same_v<returnT, ttg::resumable_task>,
ttg::coroutine_handle<>,
ttg::coroutine_handle<ttg::resumable_task_state>,
#ifdef TTG_HAVE_DEVICE
std::conditional_t<std::is_same_v<returnT, ttg::device::Task>,
ttg::device::Task::base_type,
Expand All @@ -178,7 +178,7 @@ class CallableWrapTTArgs
if constexpr (!std::is_void_v<returnT>) { // protect from compiling for void returnT
#ifdef TTG_HAS_COROUTINE
if constexpr (std::is_same_v<returnT, ttg::resumable_task>) {
ttg::coroutine_handle<> coro_handle;
ttg::coroutine_handle<ttg::resumable_task_state> coro_handle;
// if task completed destroy it
if (ret.completed()) {
ret.destroy();
Expand Down
62 changes: 39 additions & 23 deletions ttg/ttg/util/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,29 @@

namespace ttg {

struct resumable_task_state;
// import std coroutine API into ttg namespace

// import coroutine_handle, with default promise redefined to resumable_task_state
template <typename Promise = resumable_task_state>
using suspend_always = TTG_CXX_COROUTINE_NAMESPACE::suspend_always;
using suspend_never = TTG_CXX_COROUTINE_NAMESPACE::suspend_never;
template <typename Promise>
using coroutine_handle = TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle<Promise>;

/// @defgroup resumable_task resumable_task coroutine

/// resumable_task is the original prototype TTG coroutine that awaits on generic events.
/// There is no proper support for it by TTG runtimes, but it can be useful for understanding how
/// coroutines work with TTG and potentially in the future as a model for universal resumable tasks

/// @{

// fwd-declares

struct resumable_task_state;

template <std::size_t N>
struct resumable_task_events;

/// represents a one-time event
/// represents a generic one-time event
struct event {
void finish() { finished_ = true; }

Expand All @@ -34,15 +47,15 @@ namespace ttg {
};

/// task that can be resumed after some events occur
struct resumable_task : public ttg::coroutine_handle<> {
using base_type = ttg::coroutine_handle<>;
struct resumable_task : public ttg::coroutine_handle<resumable_task_state> {
using base_type = ttg::coroutine_handle<resumable_task_state>;

/// these are members mandated by the promise_type concept
///@{
/// @name members mandated by the promise_type concept
/// @{

using promise_type = struct resumable_task_state;

///@}
/// @}

resumable_task(base_type base) : base_type(std::move(base)) {}

Expand All @@ -69,28 +82,28 @@ namespace ttg {
resumable_task_state& operator=(resumable_task_state&&) = delete;

constexpr static inline std::size_t MaxNumEvents = 20;
using handle_type = coroutine_handle<>;
using handle_type = coroutine_handle<resumable_task_state>;

/// these are members mandated by the promise_type concept
///@{
/// @name members mandated by the promise_type concept
/// @{

resumable_task get_return_object() { return resumable_task{handle_type::from_promise(*this)}; }

/// @note start task eagerly
TTG_CXX_COROUTINE_NAMESPACE::suspend_never initial_suspend() noexcept { return {}; }
suspend_never initial_suspend() noexcept { return {}; }

/// @note suspend task before destroying it so the runtime can know that the task is completed
TTG_CXX_COROUTINE_NAMESPACE::suspend_always final_suspend() noexcept {
suspend_always final_suspend() noexcept {
completed_ = true;
return {};
}
void return_void() {}
void unhandled_exception() {}

///@}
/// @}

/// these are optional members of the promise_type concept
///@{
/// @name optional members of the promise_type concept
/// @{

// these can be used to use optional storage provided by the runtime (e.g. part of the runtime's task data struct)
// N.B. the existing buffer must be passed to operator new via TLS
Expand All @@ -105,7 +118,7 @@ namespace ttg {
// ::operator delete(ptr, size);
// }

///@}
/// @}

/// @return true if ready to resume
constexpr bool ready() const {
Expand Down Expand Up @@ -160,12 +173,12 @@ namespace ttg {
template <typename... Events>
constexpr resumable_task_events(Events&&... events) : events_{(&events)...} {}

/// these are members mandated by the Awaiter concept
///@{
/// @name members mandated by the Awaiter concept
/// @{

constexpr bool await_ready() const { return await_ready(std::make_index_sequence<N>{}); }

void await_suspend(coroutine_handle<> pending_task) {
void await_suspend(coroutine_handle<resumable_task_state> pending_task) {
pending_task_ = pending_task;
pending_task_.promise().set_events(events_);
}
Expand All @@ -177,18 +190,21 @@ namespace ttg {
}
}

///@}
/// @}

private:
std::array<event*, N> events_;
coroutine_handle<> pending_task_;
coroutine_handle<resumable_task_state> pending_task_;
}; // resumable_task_events

// deduce the number of events properly
template <typename... Events>
resumable_task_events(Events&&...) -> resumable_task_events<sizeof...(Events)>;

static_assert(resumable_task_events<0>{}.await_ready() == true);

/// @}

} // namespace ttg

#endif // TTG_UTIL_COROUTINE_H

0 comments on commit a4fbc20

Please sign in to comment.