Skip to content

Commit

Permalink
Merge pull request #3522 from pleroy/EfficientDownsampling
Browse files Browse the repository at this point in the history
A more efficient L-Infinity check
  • Loading branch information
pleroy authored Feb 7, 2023
2 parents fefdd4f + d34eb68 commit 455de8e
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 22 deletions.
32 changes: 17 additions & 15 deletions numerics/fit_hermite_spline_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,28 @@ absl::StatusOr<std::list<typename Samples::const_iterator>> FitHermiteSpline(
typename Hilbert<Difference<Value>>::NormType const& tolerance) {
using Iterator = typename Samples::const_iterator;

auto interpolation_error = [get_argument, get_derivative, get_value](
Iterator begin, Iterator last) {
return Hermite3<Argument, Value>(
{get_argument(*begin), get_argument(*last)},
{get_value(*begin), get_value(*last)},
{get_derivative(*begin), get_derivative(*last)})
.LInfinityError(Range(begin, last + 1), get_argument, get_value);
};
auto interpolation_error_is_within_tolerance =
[get_argument, get_derivative, get_value, tolerance](
Iterator const begin, Iterator const last) {
return Hermite3<Argument, Value>(
{get_argument(*begin), get_argument(*last)},
{get_value(*begin), get_value(*last)},
{get_derivative(*begin), get_derivative(*last)})
.LInfinityErrorIsWithin(
Range(begin, last + 1), get_argument, get_value, tolerance);
};

std::list<Iterator> tail;
std::list<Iterator> fit;
if (samples.size() < 3) {
// With 0 or 1 points there is nothing to interpolate, with 2 we cannot
// estimate the error.
return tail;
return fit;
}

Iterator begin = samples.begin();
Iterator const last = samples.end() - 1;
while (last - begin + 1 >= 3 &&
interpolation_error(begin, last) >= tolerance) {
!interpolation_error_is_within_tolerance(begin, last)) {
// Look for a cubic that fits the beginning within |tolerance| and
// such the cubic fitting one more sample would not fit the samples within
// |tolerance|.
Expand All @@ -68,13 +70,13 @@ absl::StatusOr<std::list<typename Samples::const_iterator>> FitHermiteSpline(
if (middle == lower) {
break;
}
if (interpolation_error(begin, middle) < tolerance) {
if (interpolation_error_is_within_tolerance(begin, middle)) {
lower = middle;
} else {
upper = middle;
}
}
tail.push_back(lower);
fit.push_back(lower);

begin = lower;
}
Expand All @@ -83,9 +85,9 @@ absl::StatusOr<std::list<typename Samples::const_iterator>> FitHermiteSpline(
// point, except at the end where we give up because we don't have enough
// points left.
#if PRINCIPIA_MUST_ALWAYS_DOWNSAMPLE
CHECK_LT(tail.size(), samples.size() - 2);
CHECK_LT(fit.size(), samples.size() - 2);
#endif
return tail;
return fit;
}

} // namespace internal_fit_hermite_spline
Expand Down
15 changes: 14 additions & 1 deletion numerics/hermite3.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ using geometry::Hilbert;
// TODO(phl): Invert the two template arguments for consistency with Derivative.
template<typename Argument, typename Value>
class Hermite3 final {
using NormType = typename Hilbert<Difference<Value>>::NormType;

public:
using Derivative1 = Derivative<Value, Argument>;

Expand All @@ -43,13 +45,24 @@ class Hermite3 final {
// Returns the largest error (in the given |norm|) between this polynomial and
// the given |samples|.
template<typename Samples>
typename Hilbert<Difference<Value>>::NormType LInfinityError(
NormType LInfinityError(
Samples const& samples,
std::function<Argument const&(typename Samples::value_type const&)> const&
get_argument,
std::function<Value const&(typename Samples::value_type const&)> const&
get_value) const;

// Returns true if the |LInfinityError| is less than |tolerance|. More
// efficient than the above function in the case where it returns false.
template<typename Samples>
bool LInfinityErrorIsWithin(
Samples const& samples,
std::function<Argument const&(typename Samples::value_type const&)> const&
get_argument,
std::function<Value const&(typename Samples::value_type const&)> const&
get_value,
NormType const& tolerance) const;

private:
using Derivative2 = Derivative<Derivative1, Argument>;
using Derivative3 = Derivative<Derivative2, Argument>;
Expand Down
29 changes: 23 additions & 6 deletions numerics/hermite3_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,13 @@ BoundedArray<Argument, 2> Hermite3<Argument, Value>::FindExtrema() const {

template<typename Argument, typename Value>
template<typename Samples>
typename Hilbert<Difference<Value>>::NormType
Hermite3<Argument, Value>::LInfinityError(
auto Hermite3<Argument, Value>::LInfinityError(
Samples const& samples,
std::function<Argument const&(
typename Samples::value_type const&)> const& get_argument,
std::function<Argument const&(typename Samples::value_type const&)> const&
get_argument,
std::function<Value const&(typename Samples::value_type const&)> const&
get_value) const {
typename Hilbert<Difference<Value>>::NormType result{};
get_value) const -> NormType {
NormType result{};
for (const auto& sample : samples) {
result = std::max(result,
Hilbert<Difference<Value>>::Norm(
Expand All @@ -81,6 +80,24 @@ Hermite3<Argument, Value>::LInfinityError(
return result;
}

template<typename Argument, typename Value>
template<typename Samples>
bool Hermite3<Argument, Value>::LInfinityErrorIsWithin(
Samples const& samples,
std::function<Argument const&(typename Samples::value_type const&)> const&
get_argument,
std::function<Value const&(typename Samples::value_type const&)> const&
get_value,
NormType const& tolerance) const {
for (const auto& sample : samples) {
if (Hilbert<Difference<Value>>::Norm(Evaluate(get_argument(sample)) -
get_value(sample)) >= tolerance) {
return false;
}
}
return true;
}

} // namespace internal_hermite3
} // namespace numerics
} // namespace principia
22 changes: 22 additions & 0 deletions numerics/hermite3_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ TEST_F(Hermite3Test, OneDimensionalInterpolationError) {
/*get_argument=*/[](auto&& pair) -> auto&& { return pair.first; },
/*get_value=*/[](auto&& pair) -> auto&& { return pair.second; }),
Eq(1 / 16.0));

EXPECT_TRUE(not_a_quartic.LInfinityErrorIsWithin(
samples,
/*get_argument=*/[](auto&& pair) -> auto&& { return pair.first; },
/*get_value=*/[](auto&& pair) -> auto&& { return pair.second; },
/*tolerance=*/0.1));
EXPECT_FALSE(not_a_quartic.LInfinityErrorIsWithin(
samples,
/*get_argument=*/[](auto&& pair) -> auto&& { return pair.first; },
/*get_value=*/[](auto&& pair) -> auto&& { return pair.second; },
/*tolerance=*/0.05));
}

TEST_F(Hermite3Test, ThreeDimensionalInterpolationError) {
Expand Down Expand Up @@ -138,6 +149,17 @@ TEST_F(Hermite3Test, ThreeDimensionalInterpolationError) {
/*get_argument=*/[](auto&& pair) -> auto&& { return pair.first; },
/*get_value=*/[](auto&& pair) -> auto&& { return pair.second; }),
IsNear(1.5_(1) * Centi(Metre)));

EXPECT_TRUE(not_a_circle.LInfinityErrorIsWithin(
samples,
/*get_argument=*/[](auto&& pair) -> auto&& { return pair.first; },
/*get_value=*/[](auto&& pair) -> auto&& { return pair.second; },
/*tolerance=*/2 * Centi(Metre)));
EXPECT_FALSE(not_a_circle.LInfinityErrorIsWithin(
samples,
/*get_argument=*/[](auto&& pair) -> auto&& { return pair.first; },
/*get_value=*/[](auto&& pair) -> auto&& { return pair.second; },
/*tolerance=*/1 * Centi(Metre)));
}

} // namespace numerics
Expand Down

0 comments on commit 455de8e

Please sign in to comment.