Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove converter, replace by macro pragma to generate overloads #43

Merged
merged 7 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
nim: [ '1.4.x', 'stable', 'devel' ]
nim: [ '1.6.x', 'stable', 'devel' ]
# Steps represent a sequence of tasks that will be executed as part of the job
name: Nim ${{ matrix.nim }} sample
steps:
Expand Down
44 changes: 30 additions & 14 deletions src/numericalnim/integrate.nim
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ import arraymancer

from ./interpolate import InterpolatorType, newHermiteSpline

# to annotate procedures with `{.genInterp.}` to generate `InterpolatorType` overloads
import private/macro_utils

## # Integration
## This module implements various integration routines.
## It provides:
##
##
## ## Integrate discrete data:
## - `trapz`, `simpson`: works for any spacing between points.
## - `romberg`: requires equally spaced points and the number of points must be of the form 2^k + 1 ie 3, 5, 9, 17, 33, 65, 129 etc.
Expand All @@ -27,7 +30,7 @@ runnableExamples:
## It also handles infinite integration limits.
## - `gaussQuad`: Fixed step size Gaussian quadrature.
## - `romberg`: Adaptive method based on Richardson Extrapolation.
## - `adaptiveSimpson`: Adaptive step size.
## - `adaptiveSimpson`: Adaptive step size.
## - `simpson`: Fixed step size.
## - `trapz`: Fixed step size.

Expand All @@ -36,7 +39,7 @@ runnableExamples:

proc f(x: float, ctx: NumContext[float, float]): float =
exp(x)

let a = 0.0
let b = Inf
let integral = adaptiveGauss(f, a, b)
Expand Down Expand Up @@ -74,10 +77,9 @@ type
IntervalList[T; U; V] = object
list: seq[IntervalType[T, U, V]] # contains all the intervals sorted from smallest to largest error


# N: #intervals
proc trapz*[T](f: NumContextProc[T, float], xStart, xEnd: float,
N = 500, ctx: NumContext[T, float] = nil): T =
N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using the trapezoidal rule.
##
## Input:
Expand Down Expand Up @@ -172,9 +174,8 @@ proc cumtrapz*[T](f: NumContextProc[T, float], X: openArray[float],
t += dx
result = hermiteInterpolate(X, times, y, dy)


proc simpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
N = 500, ctx: NumContext[T, float] = nil): T =
N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using Simpson's rule.
##
## Input:
Expand Down Expand Up @@ -252,7 +253,7 @@ proc simpson*[T](Y: openArray[T], X: openArray[float]): T =
result += alpha * ySorted[2*i + 2] + beta * ySorted[2*i + 1] + eta * ySorted[2*i]

proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
tol = 1e-8, ctx: NumContext[T, float] = nil): T =
tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using an adaptive Simpson's rule.
##
## Input:
Expand Down Expand Up @@ -284,7 +285,7 @@ proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
return left + right

proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: float,
tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T =
tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T {.genInterp.} =
let zero = reused_points[0] - reused_points[0]
let dx1 = (xEnd - xStart) / 2
let dx2 = (xEnd - xStart) / 4
Expand All @@ -302,7 +303,7 @@ proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: floa
return left + right

proc adaptiveSimpson2*[T](f: NumContextProc[T, float], xStart, xEnd: float,
tol = 1e-8, ctx: NumContext[T, float] = nil): T =
tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using an adaptive Simpson's rule.
##
## Input:
Expand Down Expand Up @@ -399,7 +400,7 @@ proc cumsimpson*[T](f: NumContextProc[T, float], X: openArray[float],
result = hermiteInterpolate(X, t, ys, dy)

proc romberg*[T](f: NumContextProc[T, float], xStart, xEnd: float,
depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T =
depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using Romberg Integration.
##
## Input:
Expand Down Expand Up @@ -594,7 +595,7 @@ proc getGaussLegendreWeights(nPoints: int): tuple[nodes: seq[float], weights: se
return gaussWeights[nPoints]

proc gaussQuad*[T](f: NumContextProc[T, float], xStart, xEnd: float,
N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T =
N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using Gaussian Quadrature.
## Has 20 different sets of weights, ranging from 1 to 20 function evaluations per subinterval.
##
Expand Down Expand Up @@ -654,7 +655,7 @@ proc calcGaussKronrod[T; U](f: NumContextProc[T, U], xStart, xEnd: U, ctx: NumCo


proc adaptiveGaussLocal*[T](f: NumContextProc[T, float],
xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T =
xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using an locally adaptive Gauss-Kronrod Quadrature.
##
## Input:
Expand Down Expand Up @@ -872,6 +873,18 @@ proc adaptiveGauss*[T; U](f_in: NumContextProc[T, U],
adaptiveGaussImpl()
return totalValue

proc adaptiveGauss*[T](f_in: InterpolatorType[T]; xStart_in, xEnd_in: T;
tol = 1e-8; initialPoints: openArray[T] = @[];
maxintervals: int = 10000; ctx: NumContext[T, T] = nil): T =
## NOTE: On Nim 2.0.8 we cannot use `{.genInterp.}` on the above proc, because of
## of the double generic it has `[T; U]`. It fails. So this is just a manual version
## of the generated code for the time being.
mixin eval
mixin InterpolatorType
mixin toNumContextProc
let ncp = toNumContextProc(f_in)
adaptiveGauss(ncp, xStart_in, xEnd_in, tol, initialPoints, maxintervals, ctx)

proc cumGaussSpline*[T; U](f_in: NumContextProc[T, U],
xStart_in, xEnd_in: U, tol = 1e-8, initialPoints: openArray[U] = @[], maxintervals: int = 10000, ctx: NumContext[T, U] = nil): InterpolatorType[T] =
## Calculate the cumulative integral of f using an globally adaptive Gauss-Kronrod Quadrature. Inf and -Inf can be used as integration limits.
Expand Down Expand Up @@ -909,7 +922,10 @@ proc cumGaussSpline*[T; U](f_in: NumContextProc[T, U],
result = newHermiteSpline[T](xs, ys)

proc cumGauss*[T](f_in: NumContextProc[T, float],
X: openArray[float], tol = 1e-8, initialPoints: openArray[float] = @[], maxintervals: int = 10000, ctx: NumContext[T, float] = nil): seq[T] =
X: openArray[float], tol = 1e-8,
initialPoints: openArray[float] = @[],
maxintervals: int = 10000,
ctx: NumContext[T, float] = nil): seq[T] {.genInterp.} =
## Calculate the cumulative integral of f using an globally adaptive Gauss-Kronrod Quadrature.
## Returns a sequence of values which is the cumulative integral of f at the points defined in X.
## Important: because of the much higher order of the Gauss-Kronrod quadrature (order 21) compared to the interpolating Hermite spline (order 3) you have to give it a large amount of initialPoints.
Expand Down
32 changes: 16 additions & 16 deletions src/numericalnim/interpolate.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ export rbf
## This module implements various interpolation routines.
## See also:
## - `rbf module<rbf.html>`_ for RBF interpolation of scattered data in arbitrary dimensions.
##
##
## ## 1D interpolation
## - Hermite spline (recommended): cubic spline that works with many types of values. Accepts derivatives if available.
## - Cubic spline: cubic spline that only works with `float`s.
## - Linear spline: Linear spline that works with many types of values.
##
##
## ### Extrapolation
## Extrapolation is supported for all 1D interpolators by passing the type of extrapolation as an argument of `eval`.
## The default is to use the interpolator's native method to extrapolate. This means that Linear does linear extrapolation,
Expand All @@ -26,7 +26,7 @@ export rbf

runnableExamples:
import numericalnim, std/[math, sequtils]

let x = linspace(0.0, 1.0, 10)
let y = x.mapIt(sin(it))

Expand Down Expand Up @@ -173,7 +173,7 @@ proc derivEval_cubicspline*[T](spline: InterpolatorType[T], x: float): T =

proc newCubicSpline*[T: SomeFloat](X: openArray[float], Y: openArray[
T]): InterpolatorType[T] =
## Returns a cubic spline.
## Returns a cubic spline.
let (xSorted, ySorted) = sortAndTrimDataset(@X, @Y)
let coeffs = constructCubicSpline(xSorted, ySorted)
result = InterpolatorType[T](X: xSorted, Y: ySorted, coeffs_T: coeffs, high: xSorted.high,
Expand Down Expand Up @@ -241,7 +241,7 @@ proc newHermiteSpline*[T](X: openArray[float], Y, dY: openArray[

proc newHermiteSpline*[T](X: openArray[float], Y: openArray[
T]): InterpolatorType[T] =
## Constructs a cubic Hermite spline by approximating the derivatives.
## Constructs a cubic Hermite spline by approximating the derivatives.
# if only (x, y) is given, use three-point difference to calculate dY.
let (xSorted, ySorted) = sortAndTrimDataset(@X, @Y)
var dySorted = newSeq[T](ySorted.len)
Expand Down Expand Up @@ -304,16 +304,16 @@ proc eval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extrapolat
## - `Edge`: Use the value of the left/right edge.
## - `Linear`: Uses linear extrapolation using the two points closest to the edge.
## - `Native` (default): Uses the native method of the interpolator to extrapolate. For Linear1D it will be a linear extrapolation, and for Cubic and Hermite splines it will be cubic extrapolation.
## - `Error`: Raises an `ValueError` if `x` is outside the range.
## - `Error`: Raises an `ValueError` if `x` is outside the range.
## - `extrapValue`: The extrapolation value to use when `extrap = Constant`.
##
##
## > Beware: `Native` extrapolation for the cubic splines can very quickly diverge if the extrapolation is too far away from the interpolation points.
when U is Missing:
assert extrap != Constant, "When using `extrap = Constant`, a value `extrapValue` must be supplied!"
else:
when not T is U:
{.error: &"Type of `extrap` ({U}) is not the same as the type of the interpolator ({T})!".}

let xLeft = x < interpolator.X[0]
let xRight = x > interpolator.X[^1]
if xLeft or xRight:
Expand All @@ -330,7 +330,7 @@ proc eval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extrapolat
if xLeft: interpolator.Y[0]
else: interpolator.Y[^1]
of Linear:
let (xs, ys) =
let (xs, ys) =
if xLeft:
((interpolator.X[0], interpolator.X[1]), (interpolator.Y[0], interpolator.Y[1]))
else:
Expand All @@ -341,7 +341,7 @@ proc eval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extrapolat
raise newException(ValueError, &"x = {x} isn't in the interval [{interpolator.X[0]}, {interpolator.X[^1]}]")

result = interpolator.eval_handler(interpolator, x)


proc derivEval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: ExtrapolateKind = Native, extrapValue: U = missing()): T =
## Evaluates the derivative of an interpolator.
Expand All @@ -351,9 +351,9 @@ proc derivEval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extra
## - `Edge`: Use the value of the left/right edge.
## - `Linear`: Uses linear extrapolation using the two points closest to the edge.
## - `Native` (default): Uses the native method of the interpolator to extrapolate. For Linear1D it will be a linear extrapolation, and for Cubic and Hermite splines it will be cubic extrapolation.
## - `Error`: Raises an `ValueError` if `x` is outside the range.
## - `Error`: Raises an `ValueError` if `x` is outside the range.
## - `extrapValue`: The extrapolation value to use when `extrap = Constant`.
##
##
## > Beware: `Native` extrapolation for the cubic splines can very quickly diverge if the extrapolation is too far away from the interpolation points.
when U is Missing:
assert extrap != Constant, "When using `extrap = Constant`, a value `extrapValue` must be supplied!"
Expand Down Expand Up @@ -390,7 +390,7 @@ proc derivEval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extra
result = interpolator.deriveval_handler(interpolator, x)

proc eval*[T; U](spline: InterpolatorType[T], x: openArray[float], extrap: ExtrapolateKind = Native, extrapValue: U = missing()): seq[T] =
## Evaluates an interpolator at all points in `x`.
## Evaluates an interpolator at all points in `x`.
result = newSeq[T](x.len)
for i, xi in x:
result[i] = eval(spline, xi, extrap, extrapValue)
Expand All @@ -399,7 +399,7 @@ proc toProc*[T](spline: InterpolatorType[T]): InterpolatorProc[T] =
## Returns a proc to evaluate the interpolator.
result = proc(x: float): T = eval(spline, x)

converter toNumContextProc*[T](spline: InterpolatorType[T]): NumContextProc[T, float] =
proc toNumContextProc*[T](spline: InterpolatorType[T]): NumContextProc[T, float] =
## Convert interpolator to `NumContextProc`.
result = proc(x: float, ctx: NumContext[T, float]): T = eval(spline, x)

Expand Down Expand Up @@ -655,11 +655,11 @@ proc eval_barycentric2d*[T, U](self: InterpolatorUnstructured2DType[T, U]; x, y:

proc newBarycentric2D*[T: SomeFloat, U](points: Tensor[T], values: Tensor[U]): InterpolatorUnstructured2DType[T, U] =
## Barycentric interpolation of scattered points in 2D.
##
##
## Inputs:
## - points: Tensor of shape (nPoints, 2) with the coordinates of all points.
## - values: Tensor of shape (nPoints) with the function values.
##
##
## Returns:
## - Interpolator object that can be evaluated using `interp.eval(x, y`.
assert points.rank == 2 and points.shape[1] == 2
Expand Down
95 changes: 95 additions & 0 deletions src/numericalnim/private/macro_utils.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import std / macros
proc checkArgNumContext(fn: NimNode) =
## Checks the first argument of the given proc is indeed a `NumContextProc` argument.
let params = fn.params
# FormalParams <- `.params`
# Ident "T"
# IdentDefs <- `params[1]`
# Sym "f"
# BracketExpr <- `params[1][1]`
# Sym "NumContextProc" <- `params[1][1][0]`
# Ident "T"
# Sym "float"
# Empty
expectKind params, nnkFormalParams
expectKind params[1], nnkIdentDefs
expectKind params[1][1], nnkBracketExpr
expectKind params[1][1][0], {nnkSym, nnkIdent}
if params[1][1][0].strVal != "NumContextProc":
error("The function annotated with `{.genInterp.}` does not take a `NumContextProc` as the firs argument.")

proc replaceNumCtxArg(fn: NimNode): NimNode =
## Checks the first argument of the given proc is indeed a `NumContextProc` argument.
## MUST run `checkArgNumContext` on `fn` first.
##
## It returns the identifier of the first argument.
var params = fn.params # see `checkArgNNumContext`
expectKind params[1][0], {nnkSym, nnkIdent}
result = ident(params[1][0].strVal)
params[1] = nnkIdentDefs.newTree(
result,
nnkBracketExpr.newTree(
ident"InterpolatorType",
ident"T"
),
newEmptyNode()
)
fn.params = params

proc untype(n: NimNode): NimNode =
case n.kind
of nnkSym: result = ident(n.strVal)
of nnkIdent: result = n
else:
error("Cannot untype the argument: " & $n.treerepr)

proc genOriginalCall(fn: NimNode, ncp: NimNode): NimNode =
## Generates a call to the original procedure `fn` with `ncp`
## as the first argument
let fnName = fn.name
let params = fn.params
# extract all arguments we need to pass from `params`
var p = newSeq[NimNode]()
p.add ncp
for i in 2 ..< params.len: # first param is return type, second is parameter we replace
expectKind params[i], nnkIdentDefs
if params[i].len in 0 .. 2:
error("Invalid parameter: " & $params[i].treerepr)
else: # one or more arg of this type
# IdentDefs <- Example with 2 arguments of the same type
# Ident "xStart" <- index `0`
# Ident "xEnd" <- index `len - 3 = 4 - 3 = 1`
# Ident "float"
# Empty
for j in 0 .. params[i].len - 3:
p.add untype(params[i][j])
# generate the call
result = nnkCall.newTree(fnName)
for el in p:
result.add el

macro genInterp*(fn: untyped): untyped =
## Takes a `proc` with a `NumContextProc` parameter as the first argument
## and returns two procedures:
## 1. The original proc
## 2. An overload, which converts an `InterpolatorType[T]` argument to a
## `NumContextProc[T, float]` using the conversion proc.
doAssert fn.kind in {nnkProcDef, nnkFuncDef}
result = newStmtList(fn)
# 1. check arg
checkArgNumContext(fn)
# 2. generate overload
var new = fn.copyNimTree()
# 2a. replace first argument by `InterpolatorType[T]`
let arg = new.replaceNumCtxArg()
# 2b. add body with NumContextProc
let ncpIdent = ident"ncp"
new.body = quote do:
mixin eval # defined in `interpolate`, but macro used in `integrate`
mixin InterpolatorType
mixin toNumContextProc
let `ncpIdent` = toNumContextProc(`arg`)
# 2c. add call to original proc
new.body.add genOriginalCall(fn, ncpIdent)
# 3. finalize
result.add new
Loading