Skip to content

Commit

Permalink
Merge pull request #304 from devreal/flexible-device-inputs
Browse files Browse the repository at this point in the history
Flexible device inputs
  • Loading branch information
therault authored Dec 12, 2024
2 parents 59e7452 + 6426016 commit c12d41d
Show file tree
Hide file tree
Showing 21 changed files with 397 additions and 86 deletions.
15 changes: 14 additions & 1 deletion tests/unit/device_coro.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,26 @@ TEST_CASE("Device", "coro") {
#endif // TTG_HAVE_CUDA
}
};

auto tt = ttg::make_tt<ttg::ExecutionSpace::CUDA>(fn, ttg::edges(edge), ttg::edges(edge),
"device_task", {"edge_in"}, {"edge_out"});
ttg::make_graph_executable(tt);
if (ttg::default_execution_context().rank() == 0) tt->invoke(0, value_t{});
ttg::ttg_fence(ttg::default_execution_context());
}

SECTION("empty-select") {
ttg::Edge<void, void> edge;
auto fn = []() -> ttg::device::Task {
co_await ttg::device::select();
/* nothing else to do */
};
auto tt = ttg::make_tt<ttg::ExecutionSpace::CUDA>(fn, ttg::edges(edge), ttg::edges(),
"device_task", {"edge_in"}, {"edge_out"});
ttg::make_graph_executable(tt);
tt->invoke();
ttg::ttg_fence(ttg::default_execution_context());
};

}

#endif // TTG_IMPL_DEVICE_SUPPORT
1 change: 1 addition & 0 deletions ttg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ if (TARGET MADworld)
set(ttg-mad-headers
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/buffer.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/device.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/devicefunc.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/fwd.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/import.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/ttg.h
Expand Down
12 changes: 6 additions & 6 deletions ttg/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
#include "ttg/config.h"
#include "ttg/fwd.h"

#if defined(TTG_USE_PARSEC)
#include "ttg/parsec/ttg.h"
#elif defined(TTG_USE_MADNESS)
#include "ttg/madness/ttg.h"
#endif // TTG_USE_{PARSEC|MADNESS}

#include "ttg/runtimes.h"
#include "ttg/util/demangle.h"
#include "ttg/util/hash.h"
Expand Down Expand Up @@ -37,12 +43,6 @@
#include "ttg/device/device.h"
#include "ttg/device/task.h"

#if defined(TTG_USE_PARSEC)
#include "ttg/parsec/ttg.h"
#elif defined(TTG_USE_MADNESS)
#include "ttg/madness/ttg.h"
#endif // TTG_USE_{PARSEC|MADNESS}

// these headers use the default backend
#include "ttg/run.h"

Expand Down
1 change: 1 addition & 0 deletions ttg/ttg/base/tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "ttg/base/terminal.h"
#include "ttg/util/demangle.h"
#include "ttg/util/trace.h"

namespace ttg {

Expand Down
20 changes: 20 additions & 0 deletions ttg/ttg/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,32 @@
#include <memory>

#include "ttg/fwd.h"
#include "ttg/util/meta.h"

namespace ttg {

template<typename T, typename Allocator = std::allocator<std::decay_t<T>>>
using Buffer = TTG_IMPL_NS::Buffer<T, Allocator>;

namespace meta {

/* Specialize some traits */

template<typename T, typename A>
struct is_buffer<ttg::Buffer<T, A>> : std::true_type
{ };

template<typename T, typename A>
struct is_buffer<const ttg::Buffer<T, A>> : std::true_type
{ };

/* buffers are const if their value types are const */
template<typename T, typename A>
struct is_const<ttg::Buffer<T, A>> : std::is_const<T>
{ };

} // namespace meta

} // namespace ttg

#endif // TTG_buffer_H
4 changes: 4 additions & 0 deletions ttg/ttg/device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ namespace ttg::device {
bool is_invalid() const {
return (m_space == ttg::ExecutionSpace::Invalid);
}

static Device host() {
return {};
}
};
} // namespace ttg::device

Expand Down
79 changes: 76 additions & 3 deletions ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,80 @@
#include <type_traits>
#include <span>


#include "ttg/fwd.h"
#include "ttg/impl_selector.h"
#include "ttg/ptr.h"
#include "ttg/devicescope.h"

#ifdef TTG_HAVE_COROUTINE

namespace ttg::device {

namespace detail {

struct device_input_data_t {
using impl_data_t = decltype(TTG_IMPL_NS::buffer_data(std::declval<ttg::Buffer<int>>()));

device_input_data_t(impl_data_t data, ttg::scope scope, bool isconst, bool isscratch)
: impl_data(data), scope(scope), is_const(isconst), is_scratch(isscratch)
{ }
impl_data_t impl_data;
ttg::scope scope;
bool is_const;
bool is_scratch;
};

template <typename... Ts>
struct to_device_t {
std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
};

/* extract buffer information from to_device_t */
template<typename... Ts, std::size_t... Is>
auto extract_buffer_data(detail::to_device_t<Ts...>& a, std::index_sequence<Is...>) {
using arg_types = std::tuple<Ts...>;
return std::array<device_input_data_t, sizeof...(Ts)>{
device_input_data_t{TTG_IMPL_NS::buffer_data(std::get<Is>(a.ties)),
std::get<Is>(a.ties).scope(),
ttg::meta::is_const_v<std::tuple_element_t<Is, arg_types>>,
ttg::meta::is_devicescratch_v<std::tuple_element_t<Is, arg_types>>}...};
}
} // namespace detail

struct Input {
private:
std::vector<detail::device_input_data_t> m_data;

public:
Input() { }
template<typename... Args>
Input(Args&&... args)
: m_data{{TTG_IMPL_NS::buffer_data(args), args.scope(),
std::is_const_v<std::remove_reference_t<Args>>,
ttg::meta::is_devicescratch_v<std::decay_t<Args>>}...}
{ }

template<typename T>
void add(T&& v) {
using type = std::remove_reference_t<T>;
m_data.emplace_back(TTG_IMPL_NS::buffer_data(v), v.scope(), std::is_const_v<type>,
ttg::meta::is_devicescratch_v<type>);
}

ttg::span<detail::device_input_data_t> span() {
return ttg::span(m_data);
}
};

namespace detail {
// overload for Input
template <>
struct to_device_t<Input> {
Input& input;
};
} // namespace detail

/**
* Select a device to execute on based on the provided buffer and scratchspace objects.
* Returns an object that should be awaited on using \c co_await.
Expand All @@ -33,6 +92,11 @@ namespace ttg::device {
return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
}

[[nodiscard]]
inline auto select(Input& input) {
return detail::to_device_t<Input>{input};
}

namespace detail {

enum ttg_device_coro_state {
Expand Down Expand Up @@ -446,8 +510,9 @@ namespace ttg::device {
ttg::Runtime Runtime = ttg::ttg_runtime>
inline detail::send_t broadcast(rangeT &&keylist, valueT &&value) {
ttg::detail::value_copy_handler<Runtime> copy_handler;
return detail::send_t{broadcast_coro<i>(std::tie(keylist), copy_handler(std::forward<valueT>(value)),
std::move(copy_handler))};
return detail::send_t{detail::broadcast_coro<i>(std::tie(keylist),
copy_handler(std::forward<valueT>(value)),
std::move(copy_handler))};
}

/* overload with explicit terminals and keylist passed by const reference */
Expand Down Expand Up @@ -554,7 +619,15 @@ namespace ttg::device {

template<typename... Ts>
ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
auto arr = detail::extract_buffer_data(a, std::make_index_sequence<sizeof...(Ts)>{});
bool need_transfer = !(TTG_IMPL_NS::register_device_memory(ttg::span(arr)));
/* TODO: are we allowed to not suspend here and launch the kernel directly? */
m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
return {};
}

ttg::suspend_always await_transform(detail::to_device_t<Input>&& a) {
bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.input.span()));
/* TODO: are we allowed to not suspend here and launch the kernel directly? */
m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
return {};
Expand Down
19 changes: 19 additions & 0 deletions ttg/ttg/devicescratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "ttg/devicescope.h"
#include "ttg/fwd.h"
#include "ttg/util/meta.h"

namespace ttg {

Expand All @@ -14,6 +15,24 @@ auto make_scratch(T* val, ttg::scope scope, std::size_t count = 1) {
return devicescratch<T>(val, scope, count);
}

namespace meta {

/* Specialize some traits */

template<typename T>
struct is_devicescratch<ttg::devicescratch<T>> : std::true_type
{ };

template<typename T>
struct is_devicescratch<const ttg::devicescratch<T>> : std::true_type
{ };

template<typename T>
struct is_const<ttg::devicescratch<T>> : std::is_const<T>
{ };

} // namespace meta

} // namespace ttg

#endif // TTG_DEVICESCRATCH_H
12 changes: 12 additions & 0 deletions ttg/ttg/madness/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include "ttg/serialization/traits.h"

#include "ttg/device/device.h"

namespace ttg_madness {

/// A runtime-managed buffer mirrored between host and device memory
Expand Down Expand Up @@ -110,6 +112,12 @@ struct Buffer : private Allocator {
/* no-op */
}


bool is_current_on(ttg::device::Device dev) const {
assert(is_valid());
return true;
}

/* Get the owner device ID, i.e., the last updated
* device buffer. */
ttg::device::Device get_owner_device() const {
Expand Down Expand Up @@ -178,6 +186,10 @@ struct Buffer : private Allocator {
throw std::runtime_error("not implemented yet");
}

bool empty() const {
return (m_host_data == nullptr);
}

/* TODO: can we do this automatically?
* Pin the memory on all devices we currently track.
* Pinned memory won't be released by PaRSEC and can be used
Expand Down
30 changes: 30 additions & 0 deletions ttg/ttg/madness/devicefunc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef TTG_MAD_DEVICEFUNC_H
#define TTG_MAD_DEVICEFUNC_H

#include "ttg/madness/buffer.h"

namespace ttg_madness {

template<typename T, typename A>
auto buffer_data(const Buffer<T, A>& buffer) {
/* for now return the internal pointer, should be adapted if ever relevant for madness */
return buffer.current_device_ptr();
}

template<typename... Views>
inline bool register_device_memory(std::tuple<Views&...> &views)
{
/* nothing to do here */
return true;
}

template<typename T, std::size_t N>
inline bool register_device_memory(const ttg::span<T, N>& span)
{
/* nothing to do here */
return true;
}

} // namespace ttg_madness

#endif // TTG_MAD_DEVICEFUNC_H
4 changes: 4 additions & 0 deletions ttg/ttg/madness/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "ttg/fwd.h"
#include "ttg/util/typelist.h"
#include "ttg/util/span.h"

#include <future>

Expand Down Expand Up @@ -71,6 +72,9 @@ namespace ttg_madness {
template<typename... Views>
inline bool register_device_memory(std::tuple<Views&...> &views);

template<typename T, std::size_t N>
inline bool register_device_memory(const ttg::span<T, N>& span);

template<typename... Buffer>
inline void post_device_out(std::tuple<Buffer&...> &b);

Expand Down
10 changes: 8 additions & 2 deletions ttg/ttg/madness/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@

#include "ttg/impl_selector.h"

/* include ttg header to make symbols available in case this header is included directly */
#include "../../ttg.h"
#include "ttg/base/keymap.h"
#include "ttg/base/tt.h"
#include "ttg/func.h"
#include "ttg/madness/device.h"
#include "ttg/madness/devicefunc.h"

/* needed for make_tt */
#include "ttg/device/task.h"

#include "ttg/runtimes.h"
#include "ttg/tt.h"
#include "ttg/util/bug.h"
Expand All @@ -27,6 +30,9 @@
#include "ttg/world.h"
#include "ttg/coroutine.h"

/* include ttg header to make symbols available in case this header is included directly */
#include "../../ttg.h"

#include <array>
#include <cassert>
#include <functional>
Expand Down
4 changes: 2 additions & 2 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ class CallableWrapTT

auto invoke_func_empty_tuple = [&](auto&&... args){
if constexpr(funcT_receives_input_tuple) {
invoke_func_handle_ret(std::tuple<>{}, std::forward<decltype(args)>(args)...);
return invoke_func_handle_ret(std::tuple<>{}, std::forward<decltype(args)>(args)...);
} else {
invoke_func_handle_ret(std::forward<decltype(args)>(args)...);
return invoke_func_handle_ret(std::forward<decltype(args)>(args)...);
}
};

Expand Down
Loading

0 comments on commit c12d41d

Please sign in to comment.