From 5fa0a71f56f1f345b5e93375fb69ae96135dc2f1 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Mon, 2 Sep 2024 17:03:36 +0200 Subject: [PATCH] Add 1d indexing --- include/clad/Differentiator/KokkosBuiltins.h | 26 ++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index c205cf9c1..e050a0fd1 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -48,6 +48,19 @@ constructor_reverse_forw( ::Kokkos::View( "_diff_" + name, idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7)}; } +template +void constructor_pullback(::Kokkos::View* v, + const ::std::string& name, const size_t& idx0, + const size_t& idx1, const size_t& idx2, + const size_t& idx3, const size_t& idx4, + const size_t& idx5, const size_t& idx6, + const size_t& idx7, + ::Kokkos::View* d_v, + const ::std::string* /*d_name*/, + const size_t& /*d_idx0*/, const size_t* /*d_idx1*/, + const size_t* /*d_idx2*/, const size_t* /*d_idx3*/, + const size_t* /*d_idx4*/, const size_t* /*d_idx5*/, + const size_t* /*d_idx6*/, const size_t* /*d_idx7*/) {} /// View indexing template @@ -125,6 +138,19 @@ operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, return {(*v)(i0, i1, i2, i3, i4, i5, i6, i7), (*d_v)(i0, i1, i2, i3, i4, i5, i6, i7)}; } +template +clad::ValueAndAdjoint operator_call_reverse_forw( + const ::Kokkos::View* v, Idx i0, + const ::Kokkos::View* d_v, Idx /*d_i0*/) { + return {(*v)(i0), (*d_v)(i0)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx i0, Diff d_y, + ::Kokkos::View* d_v, + Idx* /*d_i0*/) { + (*d_v)(i0) += d_y; +} } // namespace class_functions /// Kokkos functions (view utils)