Skip to content

Commit

Permalink
Merge pull request #261 from devreal/spmm-input-refs-not-values
Browse files Browse the repository at this point in the history
SPMM: take input by reference, not by value
  • Loading branch information
evaleev committed Jul 12, 2023
2 parents d59730e + 27af10b commit 3f9087c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
24 changes: 12 additions & 12 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class Write_SpMatrix : public TT<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, ttg:
Write_SpMatrix(SpMatrix<Blk> &matrix, Edge<Key<2>, Blk> &in, Keymap2 &&ij_keymap)
: baseT(edges(in), edges(), "write_spmatrix", {"Cij"}, {}, ij_keymap), matrix_(matrix) {}

void op(const Key<2> &key, typename baseT::input_values_tuple_type &&elem, std::tuple<> &) {
void op(const Key<2> &key, typename baseT::input_refs_tuple_type &&elem, std::tuple<> &) {
std::lock_guard<std::mutex> lock(mtx_);
ttg::trace("rank =", default_execution_context().rank(),
"/ thread_id =", reinterpret_cast<std::uintptr_t>(pthread_self()), "spmm.cc Write_SpMatrix wrote {",
Expand Down Expand Up @@ -310,7 +310,7 @@ class SpMM25D {
, b_rowidx_to_colidx_(b_rowidx_to_colidx)
, ijk_keymap_(ijk_keymap) {}

void op(const Key<3> &ikp, typename baseT::input_values_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ijk) {
void op(const Key<3> &ikp, typename baseT::input_refs_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ijk) {
const auto i = ikp[0];
const auto k = ikp[1];
const auto p = ikp[2];
Expand All @@ -327,7 +327,7 @@ class SpMM25D {
ijk_keys.emplace_back(Key<3>({i, j, k}));
}
}
::broadcast<0>(ijk_keys, baseT::template get<0>(a_ik), a_ijk);
::broadcast<0>(ijk_keys, std::move(baseT::template get<0>(a_ik)), a_ijk);
}

private:
Expand All @@ -346,7 +346,7 @@ class SpMM25D {
, b_rowidx_to_colidx_(b_rowidx_to_colidx)
, ijk_keymap_(ijk_keymap) {}

void op(const Key<2> &ik, typename baseT::input_values_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ikp) {
void op(const Key<2> &ik, typename baseT::input_refs_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ikp) {
const auto i = ik[0];
const auto k = ik[1];
ttg::trace("BcastA(", i, ", ", k, ")");
Expand All @@ -364,7 +364,7 @@ class SpMM25D {
procmap[p] = true;
}
}
::broadcast<0>(ikp_keys, baseT::template get<0>(a_ik), a_ikp);
::broadcast<0>(ikp_keys, std::move(baseT::template get<0>(a_ik)), a_ikp);
}

private:
Expand All @@ -385,7 +385,7 @@ class SpMM25D {
, a_colidx_to_rowidx_(a_colidx_to_rowidx)
, ijk_keymap_(ijk_keymap) {}

void op(const Key<3> &kjp, typename baseT::input_values_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_ijk) {
void op(const Key<3> &kjp, typename baseT::input_refs_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_ijk) {
const auto k = kjp[0];
const auto j = kjp[1];
const auto p = kjp[2];
Expand All @@ -401,7 +401,7 @@ class SpMM25D {
ijk_keys.emplace_back(Key<3>({i, j, k}));
}
}
::broadcast<0>(ijk_keys, baseT::template get<0>(b_kj), b_ijk);
::broadcast<0>(ijk_keys, std::move(baseT::template get<0>(b_kj)), b_ijk);
}

private:
Expand All @@ -420,7 +420,7 @@ class SpMM25D {
, a_colidx_to_rowidx_(a_colidx_to_rowidx)
, ijk_keymap_(ijk_keymap) {}

void op(const Key<2> &kj, typename baseT::input_values_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_kjp) {
void op(const Key<2> &kj, typename baseT::input_refs_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_kjp) {
const auto k = kj[0];
const auto j = kj[1];
// broadcast b_kj to all processors which will contain at least one c_ij such that a_ik exists
Expand All @@ -437,7 +437,7 @@ class SpMM25D {
procmap[p] = true;
}
}
::broadcast<0>(kjp_keys, baseT::template get<0>(b_kj), b_kjp);
::broadcast<0>(kjp_keys, std::move(baseT::template get<0>(b_kj)), b_kjp);
}

private:
Expand Down Expand Up @@ -489,7 +489,7 @@ class SpMM25D {
}
}

void op(const Key<3> &ijk, typename baseT::input_values_tuple_type &&_ijk,
void op(const Key<3> &ijk, typename baseT::input_refs_tuple_type &&_ijk,
std::tuple<Out<Key<2>, Blk>, Out<Key<3>, Blk>> &result) {
const auto i = ijk[0];
const auto j = ijk[1];
Expand Down Expand Up @@ -643,9 +643,9 @@ class SpMM25D {
ReduceC(Edge<Key<2>, Blk> &c_ij_p, Edge<Key<2>, Blk> &c_ij, const Keymap2 &ij_keymap)
: baseT(edges(c_ij_p), edges(c_ij), "SpMM25D::reduce_c", {"c_ij(p)"}, {"c_ij"}, ij_keymap) {}

void op(const Key<2> &ij, typename baseT::input_values_tuple_type &&c_ij_p, std::tuple<Out<Key<2>, Blk>> &c_ij) {
void op(const Key<2> &ij, typename baseT::input_refs_tuple_type &&c_ij_p, std::tuple<Out<Key<2>, Blk>> &c_ij) {
ttg::trace("ReduceC(", ij[0], ", ", ij[1], ")");
::send<0>(ij, baseT::template get<0>(c_ij_p), c_ij);
::send<0>(ij, std::move(baseT::template get<0>(c_ij_p)), c_ij);
}
}; // class ReduceC

Expand Down
2 changes: 1 addition & 1 deletion ttg/ttg/terminal.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ namespace ttg {
std::enable_if_t<!meta::is_void_v<Value>, void> broadcast(const rangeT &keylist, const Value &value) {
if (broadcast_callback) {
if constexpr (ttg::meta::is_iterable_v<rangeT>) {
broadcast_callback(ttg::span(&(*std::begin(keylist)), std::distance(std::begin(keylist), std::end(keylist))),
broadcast_callback(ttg::span<const keyT>(&(*std::begin(keylist)), std::distance(std::begin(keylist), std::end(keylist))),
value);
} else {
/* got something we cannot iterate over (single element?) so put one element in the span */
Expand Down

0 comments on commit 3f9087c

Please sign in to comment.