
Recursive interpolation #102

Open. Wants to merge 23 commits into base: main.

Commits (23)
a095651
Fix a minor type error and add a const version.
BrendanKKrueger Sep 30, 2024
f51c410
First draft of recursive implementation.
BrendanKKrueger Sep 30, 2024
8ac909c
I was querying the grids in the wrong order, so my indices and weight…
BrendanKKrueger Oct 1, 2024
1cf992c
format
BrendanKKrueger Oct 1, 2024
45156e1
Rename for now.
BrendanKKrueger Oct 1, 2024
ff9a916
Added a customization point where the argument _could_ be a different…
BrendanKKrueger Oct 1, 2024
5b7e13b
More generalizing.
BrendanKKrueger Oct 1, 2024
74cdfbc
Rearrange.
BrendanKKrueger Oct 1, 2024
4d5ceef
Change how the arguments are iterated over to populate the indexweigh…
BrendanKKrueger Oct 1, 2024
f093ae4
Add int
BrendanKKrueger Oct 1, 2024
b21e940
number template is now redundant
BrendanKKrueger Oct 1, 2024
c885dcf
Use interpolate_alt for remaining interpToReal.
BrendanKKrueger Oct 1, 2024
2acc1fe
A little cleanup
BrendanKKrueger Oct 1, 2024
fc2462a
Cleanup
BrendanKKrueger Oct 1, 2024
f881e5b
format
BrendanKKrueger Oct 1, 2024
94a0014
consistency
BrendanKKrueger Oct 2, 2024
939e602
Better handling of passing the indices to the dataView_ lookup (plus …
BrendanKKrueger Oct 3, 2024
52192b6
Replace separate index and weights_t with combined index_and_weights_t.
BrendanKKrueger Oct 4, 2024
4d997b5
a bit of clarifying comments
BrendanKKrueger Oct 4, 2024
1f53801
Streamline iwlist iteration.
BrendanKKrueger Oct 8, 2024
55197b2
Minor tweaks
BrendanKKrueger Oct 22, 2024
99cf011
Delete TODO comment.
BrendanKKrueger Oct 23, 2024
7b22602
Rename "interpolate" to "interpToScalar", and add a note about the fu…
BrendanKKrueger Oct 30, 2024
9 changes: 9 additions & 0 deletions doc/sphinx/src/databox.rst
@@ -479,6 +479,15 @@ so on. These interpolation routines are hand-tuned for performance.
or try to interpolate on indices that are not interpolatable.
This is checked with an ``assert`` statement.

.. warning::
The ``DataBox::interpToReal`` method is deprecated and will be replaced by
the ``DataBox::interpToScalar`` method. ``DataBox::interpToScalar`` is
already available, so we recommend switching your code to it now to
future-proof it against the upcoming removal of ``DataBox::interpToReal``.
The two methods have identical semantics, but the change to
``DataBox::interpToScalar`` will enable new features and improve the
maintainability of Spiner.
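
For example, the migration is a rename (a minimal sketch; ``db`` here is a
hypothetical rank-2 ``DataBox<double>``):

.. code-block:: cpp

   // Deprecated spelling
   double y_old = db.interpToReal(x2, x1);
   // New spelling; identical semantics
   double y_new = db.interpToScalar(x2, x1);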

Mixed interpolation and indexing
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

302 changes: 116 additions & 186 deletions spiner/databox.hpp
@@ -208,6 +208,18 @@ class DataBox {
PORTABLE_FORCEINLINE_FUNCTION T interpToReal(const T x4, const T x3,
const T x2, const int idx,
const T x1) const noexcept;
// Entry point for recursive interpolation. See interp_core for actual
// recursion. There are multiple versions of interp_core to handle the
// different inputs allowed:
// * int: an index telling the exact data point to use for that axis
// * T : a coordinate to interpolate to that point on that axis
template <typename... Coords>
PORTABLE_FORCEINLINE_FUNCTION T
interpToScalar(const Coords... coords) const noexcept;
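
// Illustrative usage (a hypothetical sketch; `box`, `x1`..`x3`, and `j2` are
// not part of this interface): for a rank-3 DataBox,
//   T v = box.interpToScalar(x3, x2, x1); // interpolate along all three axes
//   T w = box.interpToScalar(x3, j2, x1); // pin the middle axis to index j2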

// TODO: In principle, the logic for interpToScalar and interp_core could be
// extended to work on these routines. I've not looked at how easy it
// would be, so it may be more work than it's worth?
// Interpolates SLOWEST indices of databox to a new
// DataBox, interpolated at that slowest index.
// WARNING: requires memory to be pre-allocated.
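// (Hypothetical usage sketch, not part of this diff: `dst` must already be
// allocated with the shape of `src` minus its slowest dimension.)
//   dst.interpFromDB(src, x_slowest);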
@@ -436,6 +448,43 @@ class DataBox {
status_ = DataStatus::AllocatedHost;
}
}

// Recursive interpolation: coordinate to interpolate to along this axis
template <std::size_t N, typename... Args>
PORTABLE_FORCEINLINE_FUNCTION T
interp_core(const index_and_weights_t<T> *iwlist, const T coordinate, Args... other_args) const noexcept;
// Recursive interpolation: index for exact data point to use along this axis
template <std::size_t N, typename... Args>
PORTABLE_FORCEINLINE_FUNCTION T
interp_core(const index_and_weights_t<T> *iwlist, const int index, Args... other_args) const noexcept;

template <typename... Args>
static PORTABLE_INLINE_FUNCTION void
append_index_and_weights(index_and_weights_t<T> *iwlist, const Grid_t *grid, const T x,
Args... other_args) {
// Leading argument is a coordinate: Need to compute index and weights.
grid->weights(x, iwlist[0]);
// Note: grids are in reverse order relative to arguments
append_index_and_weights(iwlist + 1, grid - 1, other_args...);
Comment on lines +467 to +468

Collaborator:
inverting the order here is a little frightening... we might want some kind of (debug-enabled) sanity check that the grid array is the same size as the arg list when we call this for the first time.

Collaborator Author:
Unfortunately, something has to be inverted:

  • The order of the arguments is chosen to reflect the order of the interpolation recursion
  • The order of grids_ is reversed relative to the order of the arguments
  • The append_index_and_weights method could recurse in the opposite order from the interpolation, which would mean that the loop over grids wouldn't be backwards. But then either we're filling iwlist back to front or we're reading iwlist back to front during the interpolation recursion.

I was assuming that assert(canInterpToReal_(N)); would do the kind of sanity checking that you're looking for. If not, what sort of check(s) would you consider sufficient here?

Collaborator Author:
Possibly related: When filling iwlist, I skip axes where the argument is already an index, because we don't need to store another copy of the same index and the weights are implicitly 1 and 0, so iwlist is redundant for those axes. However, because I declare iwlist as having length N, the memory is already allocated. After thinking about it, I think that this is just one extra thing to remember and we're better off just using iwlist consistently regardless of whether the argument is an index or a coordinate. This may be relevant if we change how iwlist is filled (e.g., if we fill it back-to-front). So I'm changing that behavior to just always use iwlist consistently.
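
For reference, a minimal sketch of the ordering for a hypothetical rank-3 call
(derived from the interpToScalar definition below, not code in this diff):

// interpToScalar(x3, x2, x1) starts at grids_[N - 1] and walks backwards
// through the grids while filling iwlist front to back:
//   grids_[2].weights(x3, iwlist[0]);   // slowest axis, first argument
//   grids_[1].weights(x2, iwlist[1]);
//   grids_[0].weights(x1, iwlist[2]);   // fastest axis, last argument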

}
template <typename... Args>
static PORTABLE_INLINE_FUNCTION void
append_index_and_weights(index_and_weights_t<T> *iwlist, const Grid_t *grid,
const int index, Args... other_args) {
// Leading argument is an index: We know the answer so we don't actually
// need to store this information for the recursion, but it keeps the
// bookkeeping cleaner and allows for some debugging checks to be added in
// more easily if necessary.
iwlist->index = index;
iwlist->w0 = 1;
iwlist->w1 = 0;
// Note: grids are in reverse order relative to arguments
append_index_and_weights(iwlist + 1, grid - 1, other_args...);
}
template <typename... Args>
static PORTABLE_INLINE_FUNCTION void
append_index_and_weights(index_and_weights_t<T> *iwlist, const Grid_t *grid) {
} // terminate recursion
};

// Read an array, shallow
@@ -446,198 +495,81 @@ inline void DataBox<T, Grid_t, Concept>::setArray(PortableMDArray<T> &A) {
setAllIndexed_();
}

template <typename T, typename Grid_t, typename Concept>
template <typename... Coords>
PORTABLE_INLINE_FUNCTION T DataBox<T, Grid_t, Concept>::interpToScalar(
const Coords... coords) const noexcept {
constexpr std::size_t N = sizeof...(Coords);
assert(canInterpToReal_(N));
index_and_weights_t<T> iwlist[N];
append_index_and_weights(iwlist, &(grids_[N - 1]), coords...);
return interp_core<N>(iwlist, coords...);
}
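
// For orientation, a sketch (not part of this diff) of how the recursion
// unrolls for a hypothetical rank-2 call interpToScalar(x2, x1), where
// iwlist[0] = {i2, w0, w1} comes from grids_[1] and iwlist[1] = {i1, w0, w1}
// comes from grids_[0]:
//   interp_core<2>(iwlist, x2, x1)
//     = iwlist[0].w0 * interp_core<1>(iwlist + 1, x1, i2)
//     + iwlist[0].w1 * interp_core<1>(iwlist + 1, x1, i2 + 1)
// and each inner call expands to
//     iwlist[1].w0 * dataView_(i2, i1) + iwlist[1].w1 * dataView_(i2, i1 + 1)
// (with i2 or i2 + 1 in the first slot), reproducing the bilinear form of the
// hand-written interpToReal(x2, x1) that this replaces.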

template <typename T, typename Grid_t, typename Concept>
template <std::size_t N, typename... Args>
PORTABLE_FORCEINLINE_FUNCTION T DataBox<T, Grid_t, Concept>::interp_core(
const index_and_weights_t<T> *iwlist, const T coordinate, Args... other_args) const noexcept {
const auto &current = iwlist[0];
static_assert(N > 0, "interp_core<0> must have already converted all coordinates to indices");
// recursive case
const T v0 = interp_core<N-1>(iwlist + 1, other_args..., current.index);
const T v1 = interp_core<N-1>(iwlist + 1, other_args..., current.index + 1);
return current.w0 * v0 + current.w1 * v1;
}

template <typename T, typename Grid_t, typename Concept>
template <std::size_t N, typename... Args>
PORTABLE_FORCEINLINE_FUNCTION T DataBox<T, Grid_t, Concept>::interp_core(
const index_and_weights_t<T> *iwlist, const int index, Args... other_args) const noexcept {
if constexpr (N == 0) {
// base case
return dataView_(index, other_args...);
} else {
// recursive case
// -- Note: We don't actually need to use iwlist[0], but for bookkeeping
// purposes we have to advance to the next entry
return interp_core<N-1>(iwlist + 1, other_args..., index);
}
}
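
// A further sketch (hypothetical arguments, not part of this diff): an index
// argument simply rides along. Following the lower-corner branch of
// interpToScalar(x3, j2, x1):
//   interp_core<3>(iwlist, x3, j2, x1)          // blends the i3 / i3 + 1 branches
//     -> interp_core<2>(iwlist + 1, j2, x1, i3) // int overload: j2 passes through
//     -> interp_core<1>(iwlist + 2, x1, i3, j2) // blends the i1 / i1 + 1 branches
//     -> interp_core<0>(iwlist + 3, i3, j2, i1) // base case
//     -> dataView_(i3, j2, i1)
// so the fixed index j2 lands in its original (middle) slot of dataView_.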

template <typename T, typename Grid_t, typename Concept>
PORTABLE_INLINE_FUNCTION T
DataBox<T, Grid_t, Concept>::interpToReal(const T x) const noexcept {
assert(canInterpToReal_(1));
int ix;
weights_t<T> w;
grids_[0].weights(x, ix, w);
return w[0] * dataView_(ix) + w[1] * dataView_(ix + 1);
return interpToScalar(x);
}

template <typename T, typename Grid_t, typename Concept>
PORTABLE_FORCEINLINE_FUNCTION T DataBox<T, Grid_t, Concept>::interpToReal(
const T x2, const T x1) const noexcept {
assert(canInterpToReal_(2));
int ix1, ix2;
weights_t<T> w1, w2;
grids_[0].weights(x1, ix1, w1);
grids_[1].weights(x2, ix2, w2);
// TODO: prefectch corners for speed?
// TODO: re-order access pattern?
return (w2[0] *
(w1[0] * dataView_(ix2, ix1) + w1[1] * dataView_(ix2, ix1 + 1)) +
w2[1] * (w1[0] * dataView_(ix2 + 1, ix1) +
w1[1] * dataView_(ix2 + 1, ix1 + 1)));
return interpToScalar(x2, x1);
}

template <typename T, typename Grid_t, typename Concept>
PORTABLE_FORCEINLINE_FUNCTION T DataBox<T, Grid_t, Concept>::interpToReal(
const T x3, const T x2, const T x1) const noexcept {
assert(canInterpToReal_(3));
int ix[3];
weights_t<T> w[3];
grids_[0].weights(x1, ix[0], w[0]);
grids_[1].weights(x2, ix[1], w[1]);
grids_[2].weights(x3, ix[2], w[2]);
// TODO: prefect corners for speed?
// TODO: re-order access pattern?
return (
w[2][0] * (w[1][0] * (w[0][0] * dataView_(ix[2], ix[1], ix[0]) +
w[0][1] * dataView_(ix[2], ix[1], ix[0] + 1)) +
w[1][1] * (w[0][0] * dataView_(ix[2], ix[1] + 1, ix[0]) +
w[0][1] * dataView_(ix[2], ix[1] + 1, ix[0] + 1))) +
w[2][1] *
(w[1][0] * (w[0][0] * dataView_(ix[2] + 1, ix[1], ix[0]) +
w[0][1] * dataView_(ix[2] + 1, ix[1], ix[0] + 1)) +
w[1][1] * (w[0][0] * dataView_(ix[2] + 1, ix[1] + 1, ix[0]) +
w[0][1] * dataView_(ix[2] + 1, ix[1] + 1, ix[0] + 1))));
return interpToScalar(x3, x2, x1);
}

template <typename T, typename Grid_t, typename Concept>
PORTABLE_FORCEINLINE_FUNCTION T DataBox<T, Grid_t, Concept>::interpToReal(
const T x3, const T x2, const T x1, const int idx) const noexcept {
assert(rank_ == 4);
for (int r = 1; r < rank_; ++r) {
assert(indices_[r] == IndexType::Interpolated);
assert(grids_[r].isWellFormed());
}
int ix[3];
weights_t<T> w[3];
grids_[1].weights(x1, ix[0], w[0]);
grids_[2].weights(x2, ix[1], w[1]);
grids_[3].weights(x3, ix[2], w[2]);
// TODO: prefect corners for speed?
// TODO: re-order access pattern?
return (
w[2][0] *
(w[1][0] * (w[0][0] * dataView_(ix[2], ix[1], ix[0], idx) +
w[0][1] * dataView_(ix[2], ix[1], ix[0] + 1, idx)) +
w[1][1] * (w[0][0] * dataView_(ix[2], ix[1] + 1, ix[0], idx) +
w[0][1] * dataView_(ix[2], ix[1] + 1, ix[0] + 1, idx))) +
w[2][1] *
(w[1][0] * (w[0][0] * dataView_(ix[2] + 1, ix[1], ix[0], idx) +
w[0][1] * dataView_(ix[2] + 1, ix[1], ix[0] + 1, idx)) +
w[1][1] *
(w[0][0] * dataView_(ix[2] + 1, ix[1] + 1, ix[0], idx) +
w[0][1] * dataView_(ix[2] + 1, ix[1] + 1, ix[0] + 1, idx))));
return interpToScalar(x3, x2, x1, idx);
}

// DH: this is a large function to force an inline, perhaps just make it a
// suggestion to the compiler?
template <typename T, typename Grid_t, typename Concept>
PORTABLE_FORCEINLINE_FUNCTION T DataBox<T, Grid_t, Concept>::interpToReal(
const T x4, const T x3, const T x2, const T x1) const noexcept {
assert(canInterpToReal_(4));
T x[] = {x1, x2, x3, x4};
int ix[4];
weights_t<T> w[4];
for (int i = 0; i < 4; ++i) {
grids_[i].weights(x[i], ix[i], w[i]);
}
// TODO(JMM): This is getty pretty gross. Should we automate?
// Hand-written is probably faster, though.
// Breaking line-limit to make this easier to read
return (
w[3][0] *
(w[2][0] *
(w[1][0] *
(w[0][0] * dataView_(ix[3], ix[2], ix[1], ix[0]) +
w[0][1] * dataView_(ix[3], ix[2], ix[1], ix[0] + 1)) +
w[1][1] *
(w[0][0] * dataView_(ix[3], ix[2], ix[1] + 1, ix[0]) +
w[0][1] * dataView_(ix[3], ix[2], ix[1] + 1, ix[0] + 1))) +
w[2][1] *
(w[1][0] *
(w[0][0] * dataView_(ix[3], ix[2] + 1, ix[1], ix[0]) +
w[0][1] * dataView_(ix[3], ix[2] + 1, ix[1], ix[0] + 1)) +
w[1][1] *
(w[0][0] * dataView_(ix[3], ix[2] + 1, ix[1] + 1, ix[0]) +
w[0][1] *
dataView_(ix[3], ix[2] + 1, ix[1] + 1, ix[0] + 1)))) +
w[3][1] *
(w[2][0] *
(w[1][0] *
(w[0][0] * dataView_(ix[3] + 1, ix[2], ix[1], ix[0]) +
w[0][1] * dataView_(ix[3] + 1, ix[2], ix[1], ix[0] + 1)) +
w[1][1] *
(w[0][0] * dataView_(ix[3] + 1, ix[2], ix[1] + 1, ix[0]) +
w[0][1] *
dataView_(ix[3] + 1, ix[2], ix[1] + 1, ix[0] + 1))) +
w[2][1] * (w[1][0] * (w[0][0] * dataView_(ix[3] + 1, ix[2] + 1,
ix[1], ix[0]) +
w[0][1] * dataView_(ix[3] + 1, ix[2] + 1,
ix[1], ix[0] + 1)) +
w[1][1] * (w[0][0] * dataView_(ix[3] + 1, ix[2] + 1,
ix[1] + 1, ix[0]) +
w[0][1] * dataView_(ix[3] + 1, ix[2] + 1,
ix[1] + 1, ix[0] + 1))))

);
return interpToScalar(x4, x3, x2, x1);
}

template <typename T, typename Grid_t, typename Concept>
PORTABLE_FORCEINLINE_FUNCTION T DataBox<T, Grid_t, Concept>::interpToReal(
const T x4, const T x3, const T x2, const int idx,
const T x1) const noexcept {
assert(rank_ == 5);
assert(indices_[0] == IndexType::Interpolated);
assert(grids_[0].isWellFormed());
for (int i = 2; i < 5; ++i) {
assert(indices_[i] == IndexType::Interpolated);
assert(grids_[i].isWellFormed());
}
T x[] = {x1, x2, x3, x4};
int ix[4];
weights_t<T> w[4];
grids_[0].weights(x[0], ix[0], w[0]);
for (int i = 1; i < 4; ++i) {
grids_[i + 1].weights(x[i], ix[i], w[i]);
}
// TODO(JMM): This is getty pretty gross. Should we automate?
// Hand-written is probably faster, though.
// Breaking line-limit to make this easier to read
return (
w[3][0] *
(w[2][0] *
(w[1][0] *
(w[0][0] * dataView_(ix[3], ix[2], ix[1], idx, ix[0]) +
w[0][1] * dataView_(ix[3], ix[2], ix[1], idx, ix[0] + 1)) +
w[1][1] *
(w[0][0] * dataView_(ix[3], ix[2], ix[1] + 1, idx, ix[0]) +
w[0][1] *
dataView_(ix[3], ix[2], ix[1] + 1, idx, ix[0] + 1))) +
w[2][1] *
(w[1][0] *
(w[0][0] * dataView_(ix[3], ix[2] + 1, ix[1], idx, ix[0]) +
w[0][1] *
dataView_(ix[3], ix[2] + 1, ix[1], idx, ix[0] + 1)) +
w[1][1] * (w[0][0] * dataView_(ix[3], ix[2] + 1, ix[1] + 1, idx,
ix[0]) +
w[0][1] * dataView_(ix[3], ix[2] + 1, ix[1] + 1, idx,
ix[0] + 1)))) +
w[3][1] *
(w[2][0] *
(w[1][0] *
(w[0][0] * dataView_(ix[3] + 1, ix[2], ix[1], idx, ix[0]) +
w[0][1] *
dataView_(ix[3] + 1, ix[2], ix[1], idx, ix[0] + 1)) +
w[1][1] * (w[0][0] * dataView_(ix[3] + 1, ix[2], ix[1] + 1, idx,
ix[0]) +
w[0][1] * dataView_(ix[3] + 1, ix[2], ix[1] + 1, idx,
ix[0] + 1))) +
w[2][1] *
(w[1][0] * (w[0][0] * dataView_(ix[3] + 1, ix[2] + 1, ix[1], idx,
ix[0]) +
w[0][1] * dataView_(ix[3] + 1, ix[2] + 1, ix[1], idx,
ix[0] + 1)) +
w[1][1] * (w[0][0] * dataView_(ix[3] + 1, ix[2] + 1, ix[1] + 1,
idx, ix[0]) +
w[0][1] * dataView_(ix[3] + 1, ix[2] + 1, ix[1] + 1,
idx, ix[0] + 1))))

);
return interpToScalar(x4, x3, x2, idx, x1);
}

template <typename T, typename Grid_t, typename Concept>
@@ -648,16 +580,15 @@ DataBox<T, Grid_t, Concept>::interpFromDB(const DataBox<T, Grid_t, Concept> &db,
assert(db.grids_[db.rank_ - 1].isWellFormed());
assert(size() == (db.size() / db.dim(db.rank_)));

int ix;
weights_t<T> w;
index_and_weights_t<T> iw;
copyShape(db, 1);

db.grids_[db.rank_ - 1].weights(x, ix, w);
DataBox<T, Grid_t, Concept> lower(db.slice(ix)), upper(db.slice(ix + 1));
// lower = db.slice(ix);
// upper = db.slice(ix+1);
db.grids_[db.rank_ - 1].weights(x, iw);
DataBox<T, Grid_t, Concept> lower(db.slice(iw.index)), upper(db.slice(iw.index + 1));
// lower = db.slice(iw.index);
// upper = db.slice(iw.index+1);
for (int i = 0; i < size(); i++) {
dataView_(i) = w[0] * lower(i) + w[1] * upper(i);
dataView_(i) = iw.w0 * lower(i) + iw.w1 * upper(i);
}
}

Expand All @@ -672,35 +603,34 @@ DataBox<T, Grid_t, Concept>::interpFromDB(const DataBox<T, Grid_t, Concept> &db,
assert(db.grids_[db.rank_ - 2].isWellFormed());
assert(size() == (db.size() / (db.dim(db.rank_) * db.dim(db.rank_ - 1))));

int ix2, ix1;
weights_t<T> w2, w1;
index_and_weights_t<T> iw2, iw1;
copyShape(db, 2);

db.grids_[db.rank_ - 2].weights(x1, ix1, w1);
db.grids_[db.rank_ - 1].weights(x2, ix2, w2);
db.grids_[db.rank_ - 2].weights(x1, iw1);
db.grids_[db.rank_ - 1].weights(x2, iw2);
DataBox<T, Grid_t, Concept> corners[2][2]{
{db.slice(ix2, ix1), db.slice(ix2 + 1, ix1)},
{db.slice(ix2, ix1 + 1), db.slice(ix2 + 1, ix1 + 1)}};
{db.slice(iw2.index, iw1.index), db.slice(iw2.index + 1, iw1.index)},
{db.slice(iw2.index, iw1.index + 1), db.slice(iw2.index + 1, iw1.index + 1)}};
// copyShape(db,2);
//
// db.grids_[db.rank_-2].weights(x1, ix1, w1);
// db.grids_[db.rank_-1].weights(x2, ix2, w2);
// corners[0][0] = db.slice(ix2, ix1 );
// corners[1][0] = db.slice(ix2, ix1+1 );
// corners[0][1] = db.slice(ix2+1, ix1 );
// corners[1][1] = db.slice(ix2+1, ix1+1 );
// db.grids_[db.rank_-2].weights(x1, iw1);
// db.grids_[db.rank_-1].weights(x2, iw2);
// corners[0][0] = db.slice(iw2.index, iw1.index );
// corners[1][0] = db.slice(iw2.index, iw1.index+1 );
// corners[0][1] = db.slice(iw2.index+1, iw1.index );
// corners[1][1] = db.slice(iw2.index+1, iw1.index+1 );
/*
for (int i = 0; i < size(); i++) {
dataView_(i) = ( w2[0]*w1[0]*corners[0][0](i)
+ w2[0]*w1[1]*corners[1][0](i)
+ w2[1]*w1[0]*corners[0][1](i)
+ w2[1]*w1[1]*corners[1][1](i));
dataView_(i) = ( iw2.w0*iw1.w0*corners[0][0](i)
+ iw2.w0*iw1.w1*corners[1][0](i)
+ iw2.w1*iw1.w0*corners[0][1](i)
+ iw2.w1*iw1.w1*corners[1][1](i));
}
*/
for (int i = 0; i < size(); i++) {
dataView_(i) =
(w2[0] * (w1[0] * corners[0][0](i) + w1[1] * corners[1][0](i)) +
w2[1] * (w1[0] * corners[0][1](i) + w1[1] * corners[1][1](i)));
(iw2.w0 * (iw1.w0 * corners[0][0](i) + iw1.w1 * corners[1][0](i)) +
iw2.w1 * (iw1.w0 * corners[0][1](i) + iw1.w1 * corners[1][1](i)));
}
}
