diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index 393f1160..9e56531f 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -55,33 +55,39 @@ function (interp::AbstractInterpolation)(u::AbstractVector, t::AbstractVector) u end -const EXTRAPOLATION_ERROR = "Cannot extrapolate as `extrapolate` keyword passed was `false`" -struct ExtrapolationError <: Exception end -function Base.showerror(io::IO, e::ExtrapolationError) - print(io, EXTRAPOLATION_ERROR) +const DOWN_EXTRAPOLATION_ERROR = "Cannot extrapolate down as `extrapolation_down` keyword passed was `:none`" +struct DownExtrapolationError <: Exception end +function Base.showerror(io::IO, ::DownExtrapolationError) + print(io, DOWN_EXTRAPOLATION_ERROR) +end + +const UP_EXTRAPOLATION_ERROR = "Cannot extrapolate up as `extrapolation_up` keyword passed was `:none`" +struct UpExtrapolationError <: Exception end +function Base.showerror(io::IO, ::UpExtrapolationError) + print(io, UP_EXTRAPOLATION_ERROR) end const INTEGRAL_NOT_FOUND_ERROR = "Cannot integrate it analytically. Please use Numerical Integration methods." struct IntegralNotFoundError <: Exception end -function Base.showerror(io::IO, e::IntegralNotFoundError) +function Base.showerror(io::IO, ::IntegralNotFoundError) print(io, INTEGRAL_NOT_FOUND_ERROR) end const DERIVATIVE_NOT_FOUND_ERROR = "Derivatives greater than second order is not supported." struct DerivativeNotFoundError <: Exception end -function Base.showerror(io::IO, e::DerivativeNotFoundError) +function Base.showerror(io::IO, ::DerivativeNotFoundError) print(io, DERIVATIVE_NOT_FOUND_ERROR) end const INTEGRAL_INVERSE_NOT_FOUND_ERROR = "Cannot invert the integral analytically. Please use Numerical methods." struct IntegralInverseNotFoundError <: Exception end -function Base.showerror(io::IO, e::IntegralInverseNotFoundError) +function Base.showerror(io::IO, ::IntegralInverseNotFoundError) print(io, INTEGRAL_INVERSE_NOT_FOUND_ERROR) end const INTEGRAL_NOT_INVERTIBLE_ERROR = "The Interpolation is not positive everywhere so its integral is not invertible." struct IntegralNotInvertibleError <: Exception end -function Base.showerror(io::IO, e::IntegralNotInvertibleError) +function Base.showerror(io::IO, ::IntegralNotInvertibleError) print(io, INTEGRAL_NOT_INVERTIBLE_ERROR) end @@ -90,6 +96,8 @@ export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, BSplineInterpolation, BSplineApprox, CubicHermiteSpline, PCHIPInterpolation, QuinticHermiteSpline, LinearInterpolationIntInv, ConstantInterpolationIntInv +const extrapolation_types::Vector{Symbol} = [:none, :constant, :linear, :extension] + # added for RegularizationSmooth, JJS 11/27/21 ### Regularization data smoothing and interpolation struct RegularizationSmooth{uType, tType, T, T2, N, ITP <: AbstractInterpolation{T, N}} <: diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index d0bfe26e..f8f65755 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -314,14 +314,18 @@ struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} < k::kType # knot vector c::cType # B-spline control points sc::scType # Spline coefficients (preallocated memory) - extrapolate::Bool + extrapolation_down::Symbol + extrapolation_up::Symbol iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool function QuadraticSpline( - u, t, I, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t) + u, t, I, p, k, c, sc, extrapolation_down, + extrapolation_up, cache_parameters, assume_linear_t) linear_lookup = seems_linear(assume_linear_t, t) N = get_output_dim(u) + validate_extrapolation(extrapolation_down) + validate_extrapolation(extrapolation_up) new{typeof(u), typeof(t), typeof(I), typeof(p.α), typeof(k), typeof(c), typeof(sc), eltype(u), N}(u, t, @@ -330,7 +334,8 @@ struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} < k, c, sc, - extrapolate, + extrapolation_down, + extrapolation_up, Guesser(t), cache_parameters, linear_lookup @@ -339,7 +344,7 @@ struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} < end function QuadraticSpline( - u::uType, t; extrapolate = false, + u::uType, t; extrapolation_down::Symbol = :none, extrapolation_up::Symbol = :none, cache_parameters = false, assume_linear_t = 1e-2) where {uType <: AbstractVector{<:Number}} u, t = munge_data(u, t) @@ -352,9 +357,9 @@ function QuadraticSpline( p = QuadraticSplineParameterCache(u, t, k, c, sc, cache_parameters) A = QuadraticSpline( - u, t, nothing, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t) + u, t, nothing, p, k, c, sc, extrapolation_down, extrapolation_up, cache_parameters, assume_linear_t) I = cumulative_integral(A, cache_parameters) - QuadraticSpline(u, t, I, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t) + QuadraticSpline(u, t, I, p, k, c, sc, extrapolation_down, extrapolation_up, cache_parameters, assume_linear_t) end function QuadraticSpline( diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index f47f9e3c..8dbf0f75 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -1,7 +1,39 @@ function _interpolate(A, t) - ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && - throw(ExtrapolationError()) - return _interpolate(A, t, A.iguesser) + if t < first(A.t) + _extrapolate_down(A, t) + elseif t > last(A.t) + _extrapolate_up(A, t) + else + _interpolate(A, t, A.iguesser) + end +end + +function _extrapolate_down(A, t) + (; extrapolation_down) = A + if extrapolation_down == :none + throw(ExtrapolationError(DOWN_EXTRAPOLATION_ERROR)) + elseif extrapolation_down == :constant + first(A.u) + elseif extrapolation_down == :linear + slope = derivative(A, first(A.t)) + first(A.u) + slope * (t - first(A.t)) + elseif extrapolation_down == :extension + _interpolate(A, t, A.iguesser) + end +end + +function _extrapolate_up(A, t) + (; extrapolation_up) = A + if extrapolation_up == :none + throw(ExtrapolationError(DOWN_EXTRAPOLATION_ERROR)) + elseif extrapolation_up == :constant + last(A.u) + elseif extrapolation_up == :linear + slope = derivative(A, last(A.t)) + last(A.u) + slope * (t - last(A.t)) + elseif extrapolation_up == :extension + _interpolate(A, t, A.iguesser) + end end # Linear Interpolation diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 2d1c6432..77470423 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -248,6 +248,12 @@ function get_parameters(A::QuinticHermiteSpline, idx) end end +function validate_extrapolation(method::Symbol) + if method ∉ extrapolation_types + error("Invalid extrapolation method `$method` supplied, use one of $extrapolation_types.") + end +end + function du_PCHIP(u, t) h = diff(u) δ = h ./ diff(t)