diff --git a/ttg/ttg/tt.h b/ttg/ttg/tt.h index 6b3972984..1ca33fe9e 100644 --- a/ttg/ttg/tt.h +++ b/ttg/ttg/tt.h @@ -178,23 +178,28 @@ namespace ttg { #ifndef TTG_PROCESS_TT_OP_RETURN #ifdef TTG_HAVE_COROUTINE -#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) \ - { \ - using return_type = decltype(invoke); \ - if constexpr (std::is_same_v) { \ - invoke; \ - id = ttg::TaskCoroutineID::Invalid; \ - } else { \ - auto coro_return = invoke; \ - if constexpr (std::is_same_v>) \ - id = ttg::TaskCoroutineID::ResumableTask; \ - else if constexpr (std::is_same_v>) \ - id = ttg::TaskCoroutineID::DeviceTask; \ - else \ - std::abort(); \ - result = coro_return.address(); \ - } \ +#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) \ + { \ + using return_type = decltype(invoke); \ + if constexpr (std::is_same_v) { \ + invoke; \ + id = ttg::TaskCoroutineID::Invalid; \ + } else { \ + auto coro_return = invoke; \ + static_assert(std::is_same_v || \ + std::is_base_of_v, decltype(coro_return)>|| \ + std::is_base_of_v, \ + decltype(coro_return)>); \ + if constexpr (std::is_base_of_v, decltype(coro_return)>) \ + id = ttg::TaskCoroutineID::ResumableTask; \ + else if constexpr (std::is_base_of_v< \ + ttg::coroutine_handle, \ + decltype(coro_return)>) \ + id = ttg::TaskCoroutineID::DeviceTask; \ + else \ + std::abort(); \ + result = coro_return.address(); \ + } \ } #else #define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) invoke