Skip to content

Commit

Permalink
Merge pull request #191 from lwJi/derivs2_2d
Browse files Browse the repository at this point in the history
Correct vectorized case for derivs2_2d
  • Loading branch information
eschnett authored Aug 2, 2023
2 parents 7a2b30a + eefecce commit bf0cf9a
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions Derivs/src/derivs.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,16 @@ deriv2_2d(const TS var, const T dx, const T dy) {
// We need fewer ndyvars than without vectorizon: Instead of `(2 *
// deriv_order + 1) * vsize` scalars, we only need to calculate
// `(2 * deriv_order + 1) + (vsize - 1)` scalars
constexpr std::ptrdiff_t ndyvars =
div_ceil(2 * deriv_order + 1 + vsize - 1, vsize);
constexpr std::ptrdiff_t maxnpoints = deriv_order + 1 + vsize - 1;
constexpr std::ptrdiff_t ndyvars = div_ceil(maxnpoints, vsize);
std::array<R, ndyvars> dyvar;

for (std::ptrdiff_t n = 0; n < ndyvars; ++n) {
const std::ptrdiff_t di = n - deriv_order;
for (std::ptrdiff_t n = 0; n < maxnpoints; n += vsize) {
const std::ptrdiff_t di = n - deriv_order / 2;
// Skip the unused central point, but only if there is no vectorization
if (vsize == 1 && di == 0)
continue;
dyvar[n] = deriv1d<deriv_order>(
dyvar[div_floor(n, vsize)] = deriv1d<deriv_order>(
[&](int dj) CCTK_ATTRIBUTE_ALWAYS_INLINE {
#ifdef CCTK_DEBUG
assert(di >= -deriv_order / 2);
Expand All @@ -303,23 +303,23 @@ deriv2_2d(const TS var, const T dx, const T dy) {
assert(di <= +deriv_order / 2);
#endif
if constexpr (vsize == 1)
return scalar_dyvar[deriv_order + di];
return scalar_dyvar[deriv_order / 2 + di];
else
return loadu<R>(&scalar_dyvar[deriv_order + di]);
return loadu<R>(&scalar_dyvar[deriv_order / 2 + di]);
},
dx);

} else {
// Calculate y-derivative first
constexpr std::ptrdiff_t ndyvars = 2 * deriv_order + 1;
constexpr std::ptrdiff_t ndyvars = deriv_order + 1;
std::array<R, ndyvars> dyvar;
#ifdef CCTK_DEBUG
for (std::ptrdiff_t n = 0; n < ndyvars; ++n)
dyvar[n] = Arith::nan<T>()();
#endif

for (std::ptrdiff_t n = 0; n < ndyvars; ++n) {
const std::ptrdiff_t di = n - deriv_order;
const std::ptrdiff_t di = n - deriv_order / 2;
// Skip the unused central point
if (di == 0)
continue;
Expand All @@ -343,7 +343,7 @@ deriv2_2d(const TS var, const T dx, const T dy) {
assert(di >= -deriv_order / 2);
assert(di <= +deriv_order / 2);
#endif
return dyvar[deriv_order + di];
return dyvar[deriv_order / 2 + di];
},
dx);
}
Expand Down

0 comments on commit bf0cf9a

Please sign in to comment.