Skip to content

Commit

Permalink
Fix usage of ttg::persistent()
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Aug 14, 2024
1 parent f31fedb commit 1e34827
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
19 changes: 11 additions & 8 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,14 @@ static constexpr ttg::ExecutionSpace space = ttg::ExecutionSpace::Host;
#if defined(BLOCK_SPARSE_GEMM) && defined(BTAS_IS_USABLE)
using scalar_t = double;

#if HAVE_SPMM_DEVICE
using blk_t = DeviceTensor<scalar_t, btas::DEFAULT::range,
btas::mohndle<btas::varray<scalar_t, TiledArray::device_pinned_allocator<scalar_t>>,
btas::Handle::shared_ptr>>;
btas::mohndle<btas::varray<scalar_t
#if HAVE_SPMM_DEVICE
, TiledArray::device_pinned_allocator<scalar_t>
#else // HAVE_SPMM_DEVICE
using blk_t = btas::Tensor<scalar_t, btas::DEFAULT::range, btas::mohndle<btas::varray<scalar_t>, btas::Handle::shared_ptr>>;
#endif // HAVE_SPMM_DEVICE
>,
#endif // HAVE_SPMM_DEVICE
btas::Handle::shared_ptr>>;
//#include <atomic>
//static std::atomic<uint64_t> reduce_count = 0;

Expand Down Expand Up @@ -234,7 +235,9 @@ class Read_SpMatrix : public TT<Key<3>,
const auto i = it.row();
// IF the receiver uses the same keymap, these sends are local
if (rank == this->ij_keymap_(Key<2>(std::initializer_list<long>({i, j})))) {
::send<0>(Key<2>(std::initializer_list<long>({i, j})), it.value(), out);
::send<0>(Key<2>(std::initializer_list<long>({i, j})),
ttg::persistent(it.value()),
out);
}
}
}
Expand All @@ -258,14 +261,14 @@ class Write_SpMatrix : public TT<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, ttg:
, write_back_(write_back)
{ }

void op(const Key<2> &key, typename baseT::input_values_tuple_type &&elem, std::tuple<> &) {
void op(const Key<2> &key, typename baseT::input_refs_tuple_type &&elem, std::tuple<> &) {
if (write_back_) {
std::lock_guard<std::mutex> lock(mtx_);
ttg::trace("rank =", default_execution_context().rank(),
"/ thread_id =", reinterpret_cast<std::uintptr_t>(pthread_self()), "spmm.cc Write_SpMatrix wrote {",
key[0], ",", key[1], "} = ", baseT::template get<0>(elem), " in ", static_cast<void *>(&matrix_),
" with mutex @", static_cast<void *>(&mtx_), " for object @", static_cast<void *>(this));
values_.emplace_back(key[0], key[1], baseT::template get<0>(elem));
values_.emplace_back(key[0], key[1], std::move(baseT::template get<0>(elem)));
}
}

Expand Down
2 changes: 1 addition & 1 deletion ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -4577,7 +4577,7 @@ struct ttg::detail::value_copy_handler<ttg::Runtime::PaRSEC> {
bool inserted = ttg_parsec::detail::add_copy_to_task(copy, caller);
assert(inserted);
copy_to_remove = copy; // we want to remove the copy from the task once done sending
do_release = false; // we don't release the copy since we didn't allocate it
do_release = true; // we don't release the copy since we didn't allocate it
copy->add_ref(); // add a reference so that TTG does not attempt to delete this object
}
return vref.value_ref;
Expand Down

0 comments on commit 1e34827

Please sign in to comment.