From 43a0f129b97ed705ef419ac565821f1ac4c533c4 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Tue, 4 Jun 2024 16:12:04 -0400 Subject: [PATCH] Add static assert for task return and check for derived of ttg::coroutine_handle TTs may return ttg::coroutine_handle (make_tt) or ttg::device::Task (custom TTs) Signed-off-by: Joseph Schuchart --- ttg/ttg/tt.h | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/ttg/ttg/tt.h b/ttg/ttg/tt.h index 6b3972984..1ca33fe9e 100644 --- a/ttg/ttg/tt.h +++ b/ttg/ttg/tt.h @@ -178,23 +178,28 @@ namespace ttg { #ifndef TTG_PROCESS_TT_OP_RETURN #ifdef TTG_HAVE_COROUTINE -#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) \ - { \ - using return_type = decltype(invoke); \ - if constexpr (std::is_same_v) { \ - invoke; \ - id = ttg::TaskCoroutineID::Invalid; \ - } else { \ - auto coro_return = invoke; \ - if constexpr (std::is_same_v>) \ - id = ttg::TaskCoroutineID::ResumableTask; \ - else if constexpr (std::is_same_v>) \ - id = ttg::TaskCoroutineID::DeviceTask; \ - else \ - std::abort(); \ - result = coro_return.address(); \ - } \ +#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) \ + { \ + using return_type = decltype(invoke); \ + if constexpr (std::is_same_v) { \ + invoke; \ + id = ttg::TaskCoroutineID::Invalid; \ + } else { \ + auto coro_return = invoke; \ + static_assert(std::is_same_v || \ + std::is_base_of_v, decltype(coro_return)>|| \ + std::is_base_of_v, \ + decltype(coro_return)>); \ + if constexpr (std::is_base_of_v, decltype(coro_return)>) \ + id = ttg::TaskCoroutineID::ResumableTask; \ + else if constexpr (std::is_base_of_v< \ + ttg::coroutine_handle, \ + decltype(coro_return)>) \ + id = ttg::TaskCoroutineID::DeviceTask; \ + else \ + std::abort(); \ + result = coro_return.address(); \ + } \ } #else #define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) invoke