Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flexible device inputs #304

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion tests/unit/device_coro.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,26 @@ TEST_CASE("Device", "coro") {
#endif // TTG_HAVE_CUDA
}
};

auto tt = ttg::make_tt<ttg::ExecutionSpace::CUDA>(fn, ttg::edges(edge), ttg::edges(edge),
"device_task", {"edge_in"}, {"edge_out"});
ttg::make_graph_executable(tt);
if (ttg::default_execution_context().rank() == 0) tt->invoke(0, value_t{});
ttg::ttg_fence(ttg::default_execution_context());
}

/* Checks that a device task can co_await ttg::device::select() without passing
 * any buffer or scratch object. The task is fed through a single void/void
 * input edge and has no output terminals. */
SECTION("empty-select") {
  ttg::Edge<void, void> edge;
  auto fn = []() -> ttg::device::Task {
    co_await ttg::device::select();
    /* nothing else to do */
  };
  /* ttg::edges() yields no output terminals, so the list of output terminal
   * names must be empty as well (one name per terminal) */
  auto tt = ttg::make_tt<ttg::ExecutionSpace::CUDA>(fn, ttg::edges(edge), ttg::edges(),
                                                    "device_task", {"edge_in"}, {});
  ttg::make_graph_executable(tt);
  tt->invoke();
  ttg::ttg_fence(ttg::default_execution_context());
}

}

#endif // TTG_IMPL_DEVICE_SUPPORT
1 change: 1 addition & 0 deletions ttg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ if (TARGET MADworld)
set(ttg-mad-headers
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/buffer.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/device.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/devicefunc.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/fwd.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/import.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/ttg.h
Expand Down
12 changes: 6 additions & 6 deletions ttg/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
#include "ttg/config.h"
#include "ttg/fwd.h"

#if defined(TTG_USE_PARSEC)
#include "ttg/parsec/ttg.h"
#elif defined(TTG_USE_MADNESS)
#include "ttg/madness/ttg.h"
#endif // TTG_USE_{PARSEC|MADNESS}

#include "ttg/runtimes.h"
#include "ttg/util/demangle.h"
#include "ttg/util/hash.h"
Expand Down Expand Up @@ -37,12 +43,6 @@
#include "ttg/device/device.h"
#include "ttg/device/task.h"

#if defined(TTG_USE_PARSEC)
#include "ttg/parsec/ttg.h"
#elif defined(TTG_USE_MADNESS)
#include "ttg/madness/ttg.h"
#endif // TTG_USE_{PARSEC|MADNESS}

// these headers use the default backend
#include "ttg/run.h"

Expand Down
1 change: 1 addition & 0 deletions ttg/ttg/base/tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "ttg/base/terminal.h"
#include "ttg/util/demangle.h"
#include "ttg/util/trace.h"

namespace ttg {

Expand Down
20 changes: 20 additions & 0 deletions ttg/ttg/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,32 @@
#include <memory>

#include "ttg/fwd.h"
#include "ttg/util/meta.h"

namespace ttg {

/// Alias for the buffer type provided by the selected backend (TTG_IMPL_NS::Buffer).
/// \tparam T element type of the buffer
/// \tparam Allocator allocator type handed to the backend buffer; defaults to
///         std::allocator instantiated for std::decay_t<T>
template<typename T, typename Allocator = std::allocator<std::decay_t<T>>>
using Buffer = TTG_IMPL_NS::Buffer<T, Allocator>;

namespace meta {

/* Specialize some traits */

/* a (non-const) ttg::Buffer is a buffer */
template<typename T, typename A>
struct is_buffer<ttg::Buffer<T, A>> : std::true_type
{ };

/* a const-qualified ttg::Buffer is a buffer as well */
template<typename T, typename A>
struct is_buffer<const ttg::Buffer<T, A>> : std::true_type
{ };

/* buffers are const if their value types are const */
/* i.e., the result for ttg::Buffer<T, A> is std::is_const<T> */
template<typename T, typename A>
struct is_const<ttg::Buffer<T, A>> : std::is_const<T>
{ };

} // namespace meta

} // namespace ttg

#endif // TTG_buffer_H
4 changes: 4 additions & 0 deletions ttg/ttg/device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ namespace ttg::device {
/* Tells whether the execution space of this device is ttg::ExecutionSpace::Invalid. */
bool is_invalid() const {
  const bool invalid_space = (ttg::ExecutionSpace::Invalid == m_space);
  return invalid_space;
}

/* Returns a default-constructed Device.
 * NOTE(review): presumably the default-constructed Device denotes the host -- confirm
 * against the Device constructors. */
static Device host() {
  return Device{};
}
};
} // namespace ttg::device

Expand Down
79 changes: 76 additions & 3 deletions ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,80 @@
#include <type_traits>
#include <span>


#include "ttg/fwd.h"
#include "ttg/impl_selector.h"
#include "ttg/ptr.h"
#include "ttg/devicescope.h"

#ifdef TTG_HAVE_COROUTINE

namespace ttg::device {

namespace detail {

struct device_input_data_t {
using impl_data_t = decltype(TTG_IMPL_NS::buffer_data(std::declval<ttg::Buffer<int>>()));

device_input_data_t(impl_data_t data, ttg::scope scope, bool isconst, bool isscratch)
: impl_data(data), scope(scope), is_const(isconst), is_scratch(isscratch)
{ }
impl_data_t impl_data;
ttg::scope scope;
bool is_const;
bool is_scratch;
};

/* Object returned by the variadic select(): holds lvalue references to the
 * objects passed to select() so that they can be inspected when the result
 * is co_await'ed. */
template <typename... Ts>
struct to_device_t {
/* one lvalue reference per selected object, in argument order */
std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
};

/* extract buffer information from to_device_t */
/* Convert the references tied in a to_device_t into an array holding one
 * device_input_data_t per tied object (same order as the template arguments).
 * For each object the backend handle (buffer_data()), its scope(), and the
 * is_const / is_devicescratch traits of its declared type are recorded. */
template<typename... Ts, std::size_t... Is>
auto extract_buffer_data(detail::to_device_t<Ts...>& a, std::index_sequence<Is...>) {
  using tied_types = std::tuple<Ts...>;
  using result_type = std::array<device_input_data_t, sizeof...(Ts)>;
  [[maybe_unused]] auto& tied = a.ties;
  return result_type{
      device_input_data_t{TTG_IMPL_NS::buffer_data(std::get<Is>(tied)),
                          std::get<Is>(tied).scope(),
                          ttg::meta::is_const_v<std::tuple_element_t<Is, tied_types>>,
                          ttg::meta::is_devicescratch_v<std::tuple_element_t<Is, tied_types>>}...};
}
} // namespace detail

struct Input {
private:
std::vector<detail::device_input_data_t> m_data;

public:
Input() { }
template<typename... Args>
Input(Args&&... args)
: m_data{{TTG_IMPL_NS::buffer_data(args), args.scope(),
std::is_const_v<std::remove_reference_t<Args>>,
ttg::meta::is_devicescratch_v<std::decay_t<Args>>}...}
{ }

template<typename T>
void add(T&& v) {
using type = std::remove_reference_t<T>;
m_data.emplace_back(TTG_IMPL_NS::buffer_data(v), v.scope(), std::is_const_v<type>,
ttg::meta::is_devicescratch_v<type>);
}

ttg::span<detail::device_input_data_t> span() {
return ttg::span(m_data);
}
};

namespace detail {
// overload for Input
// (full specialization of to_device_t that stores a reference to the Input
//  object passed to select() instead of a tuple of references)
template <>
struct to_device_t<Input> {
Input& input;
};
} // namespace detail

/**
* Select a device to execute on based on the provided buffer and scratchspace objects.
* Returns an object that should be awaited on using \c co_await.
Expand All @@ -33,6 +92,11 @@ namespace ttg::device {
return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
}

/**
 * Overload of select() taking an Input collection.
 * The returned object only stores a reference to \c input and should be
 * awaited on using \c co_await.
 */
[[nodiscard]]
inline auto select(Input& input) {
  using awaitable_type = detail::to_device_t<Input>;
  return awaitable_type{input};
}

namespace detail {

enum ttg_device_coro_state {
Expand Down Expand Up @@ -448,8 +512,9 @@ namespace ttg::device {
ttg::Runtime Runtime = ttg::ttg_runtime>
inline detail::send_t broadcast(rangeT &&keylist, valueT &&value) {
ttg::detail::value_copy_handler<Runtime> copy_handler;
return detail::send_t{broadcast_coro<i>(std::tie(keylist), copy_handler(std::forward<valueT>(value)),
std::move(copy_handler))};
return detail::send_t{detail::broadcast_coro<i>(std::tie(keylist),
copy_handler(std::forward<valueT>(value)),
std::move(copy_handler))};
}

/* overload with explicit terminals and keylist passed by const reference */
Expand Down Expand Up @@ -556,7 +621,15 @@ namespace ttg::device {

/* Invoked when the coroutine co_awaits the object returned by the variadic
 * select(): registers the selected objects with the backend, marks the
 * coroutine as waiting for transfers, and always suspends.
 * NOTE(review): this diff view shows both the previous registration call
 * (on a.ties) and its replacement (on the span over the extracted array);
 * confirm that only the replacement remains in the merged file. */
template<typename... Ts>
ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
/* one device_input_data_t per selected object */
auto arr = detail::extract_buffer_data(a, std::make_index_sequence<sizeof...(Ts)>{});
/* need_transfer is not read anywhere in the visible part of this function */
bool need_transfer = !(TTG_IMPL_NS::register_device_memory(ttg::span(arr)));
/* TODO: are we allowed to not suspend here and launch the kernel directly? */
m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
return {};
}

ttg::suspend_always await_transform(detail::to_device_t<Input>&& a) {
bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.input.span()));
/* TODO: are we allowed to not suspend here and launch the kernel directly? */
m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
return {};
Expand Down
19 changes: 19 additions & 0 deletions ttg/ttg/devicescratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "ttg/devicescope.h"
#include "ttg/fwd.h"
#include "ttg/util/meta.h"

namespace ttg {

Expand All @@ -14,6 +15,24 @@ auto make_scratch(T* val, ttg::scope scope, std::size_t count = 1) {
return devicescratch<T>(val, scope, count);
}

namespace meta {

/* Specialize some traits */

/* a (non-const) ttg::devicescratch is a devicescratch */
template<typename T>
struct is_devicescratch<ttg::devicescratch<T>> : std::true_type
{ };

/* a const-qualified ttg::devicescratch is a devicescratch as well */
template<typename T>
struct is_devicescratch<const ttg::devicescratch<T>> : std::true_type
{ };

/* a devicescratch is const if its value type is const */
template<typename T>
struct is_const<ttg::devicescratch<T>> : std::is_const<T>
{ };

} // namespace meta

} // namespace ttg

#endif // TTG_DEVICESCRATCH_H
12 changes: 12 additions & 0 deletions ttg/ttg/madness/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include "ttg/serialization/traits.h"

#include "ttg/device/device.h"

namespace ttg_madness {

/// A runtime-managed buffer mirrored between host and device memory
Expand Down Expand Up @@ -110,6 +112,12 @@ struct Buffer : private Allocator {
/* no-op */
}


bool is_current_on(ttg::device::Device dev) const {
assert(is_valid());
return true;
}

/* Get the owner device ID, i.e., the last updated
* device buffer. */
ttg::device::Device get_owner_device() const {
Expand Down Expand Up @@ -178,6 +186,10 @@ struct Buffer : private Allocator {
throw std::runtime_error("not implemented yet");
}

/* True if no host data is attached to this buffer, i.e., m_host_data is null. */
bool empty() const {
  if (m_host_data == nullptr) {
    return true;
  }
  return false;
}

/* TODO: can we do this automatically?
* Pin the memory on all devices we currently track.
* Pinned memory won't be released by PaRSEC and can be used
Expand Down
30 changes: 30 additions & 0 deletions ttg/ttg/madness/devicefunc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef TTG_MAD_DEVICEFUNC_H
#define TTG_MAD_DEVICEFUNC_H

#include "ttg/madness/buffer.h"

namespace ttg_madness {

/* Returns the handle describing the data of \c buffer.
 * For now this is the internal pointer reported by current_device_ptr();
 * should be adapted if ever relevant for madness. */
template<typename T, typename A>
auto buffer_data(const Buffer<T, A>& buffer) {
  auto device_ptr = buffer.current_device_ptr();
  return device_ptr;
}

/* Registration of a tuple of views: there is nothing to do here, so the
 * views are left untouched and true is returned unconditionally. */
template<typename... Views>
inline bool register_device_memory(std::tuple<Views&...> &views)
{
  static_cast<void>(views);  // intentionally unused
  constexpr bool registered = true;
  return registered;
}

/* Registration of a span of inputs: there is nothing to do here, so the
 * span is not inspected and true is returned unconditionally. */
template<typename T, std::size_t N>
inline bool register_device_memory(const ttg::span<T, N>& span)
{
  static_cast<void>(span);  // intentionally unused
  constexpr bool registered = true;
  return registered;
}

} // namespace ttg_madness

#endif // TTG_MAD_DEVICEFUNC_H
4 changes: 4 additions & 0 deletions ttg/ttg/madness/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "ttg/fwd.h"
#include "ttg/util/typelist.h"
#include "ttg/util/span.h"

#include <future>

Expand Down Expand Up @@ -71,6 +72,9 @@ namespace ttg_madness {
/* forward declaration: register a tuple of references to views */
template<typename... Views>
inline bool register_device_memory(std::tuple<Views&...> &views);

/* forward declaration: register the elements of a ttg::span */
template<typename T, std::size_t N>
inline bool register_device_memory(const ttg::span<T, N>& span);

template<typename... Buffer>
inline void post_device_out(std::tuple<Buffer&...> &b);

Expand Down
10 changes: 8 additions & 2 deletions ttg/ttg/madness/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@

#include "ttg/impl_selector.h"

/* include ttg header to make symbols available in case this header is included directly */
#include "../../ttg.h"
#include "ttg/base/keymap.h"
#include "ttg/base/tt.h"
#include "ttg/func.h"
#include "ttg/madness/device.h"
#include "ttg/madness/devicefunc.h"

/* needed for make_tt */
#include "ttg/device/task.h"

#include "ttg/runtimes.h"
#include "ttg/tt.h"
#include "ttg/util/bug.h"
Expand All @@ -27,6 +30,9 @@
#include "ttg/world.h"
#include "ttg/coroutine.h"

/* include ttg header to make symbols available in case this header is included directly */
#include "../../ttg.h"

#include <array>
#include <cassert>
#include <functional>
Expand Down
4 changes: 2 additions & 2 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ class CallableWrapTT

/* Forwards args to invoke_func_handle_ret; if the wrapped function receives
 * its inputs as a tuple, an empty tuple is passed as the first argument.
 * NOTE(review): this diff view shows each call twice (without and with
 * `return`); confirm that only the returning form remains in the merged file. */
auto invoke_func_empty_tuple = [&](auto&&... args){
if constexpr(funcT_receives_input_tuple) {
invoke_func_handle_ret(std::tuple<>{}, std::forward<decltype(args)>(args)...);
return invoke_func_handle_ret(std::tuple<>{}, std::forward<decltype(args)>(args)...);
} else {
invoke_func_handle_ret(std::forward<decltype(args)>(args)...);
return invoke_func_handle_ret(std::forward<decltype(args)>(args)...);
}
};

Expand Down
Loading
Loading