Skip to content

Commit

Permalink
Implement _zeros and _ones utils
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeingold committed Nov 25, 2024
1 parent 1bccd58 commit 4c2d777
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function _integral(
# Create a wrapper that returns only the value component in those units
uintegrand(ts) = Unitful.ustrip.(integrandunits, integrand(ts))
# Integrate only the unitless values
value = HCubature.hcubature(uintegrand, zeros(FP, N), ones(FP, N); rule.kwargs...)[1]
value = HCubature.hcubature(uintegrand, _zeros(FP, N), _ones(FP, N); rule.kwargs...)[1]

# Reapply units
return value .* integrandunits
Expand Down
2 changes: 1 addition & 1 deletion src/specializations/BezierCurve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function integral(
# Create a wrapper that returns only the value component in those units
uintegrand(ts) = Unitful.ustrip.(integrandunits, integrand(ts))
# Integrate only the unitless values
value = HCubature.hcubature(uintegrand, zeros(FP, 1), ones(FP, 1); rule.kwargs...)[1]
value = HCubature.hcubature(uintegrand, _zeros(FP, 1), _ones(FP, 1); rule.kwargs...)[1]

# Reapply units
return value .* integrandunits
Expand Down
4 changes: 2 additions & 2 deletions src/specializations/Line.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ function integral(

# HCubature doesn't support functions that output Unitful Quantity types
# Establish the units that are output by f
testpoint_parametriccoord = FP[0.5]
testpoint_parametriccoord = (FP(0.5),)
integrandunits = Unitful.unit.(integrand(testpoint_parametriccoord))
# Create a wrapper that returns only the value component in those units
uintegrand(uv) = Unitful.ustrip.(integrandunits, integrand(uv))
# Integrate only the unitless values
value = HCubature.hcubature(uintegrand, (-one(FP),), (one(FP),); rule.kwargs...)[1]
value = HCubature.hcubature(uintegrand, -_ones(FP, 1), _ones(FP, 1); rule.kwargs...)[1]

# Reapply units
return value .* integrandunits
Expand Down
2 changes: 1 addition & 1 deletion src/specializations/Plane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ function integral(
# Create a wrapper that returns only the value component in those units
uintegrand(uv) = Unitful.ustrip.(integrandunits, integrand(uv))
# Integrate only the unitless values
value = HCubature.hcubature(uintegrand, -ones(FP, 2), ones(FP, 2); rule.kwargs...)[1]
value = HCubature.hcubature(uintegrand, -_ones(FP, 2), _ones(FP, 2); rule.kwargs...)[1]

# Reapply units
return value .* integrandunits
Expand Down
2 changes: 1 addition & 1 deletion src/specializations/Ray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function integral(
# Create a wrapper that returns only the value component in those units
uintegrand(uv) = Unitful.ustrip.(integrandunits, integrand(uv))
# Integrate only the unitless values
value = HCubature.hcubature(uintegrand, zeros(FP, 1), ones(FP, 1); rule.kwargs...)[1]
value = HCubature.hcubature(uintegrand, _zeros(FP, 1), _ones(FP, 1); rule.kwargs...)[1]

# Reapply units
return value .* integrandunits
Expand Down
2 changes: 1 addition & 1 deletion src/specializations/Triangle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function integral(
v = R * (1 - b / (a + b))
return f(triangle(u, v)) * R / (a + b)^2
end
= HCubature.hcubature(integrand, zeros(FP, 2), FP[1, π / 2], rule.kwargs...)[1]
= HCubature.hcubature(integrand, _zeros(FP, 2), (FP(1), FP(π / 2)), rule.kwargs...)[1]

# Apply a linear domain-correction factor 0.5 ↦ area(triangle)
return 2 * Meshes.area(triangle) .*
Expand Down
8 changes: 8 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ function _error_unsupported_combination(geometry, rule)
throw(ArgumentError(msg))
end

# Return an NTuple{N, T} of zeros; same interface as Base.zeros() but faster
_zeros(T::DataType, N::Int64) = ntuple(_ -> zero(T), N)
_zeros(N::Int) = _zeros(Float64, N)

# Return an NTuple{N, T} of ones; same interface as Base.ones() but faster
_ones(T::DataType, N::Int64) = ntuple(_ -> one(T), N)
_ones(N::Int) = _ones(Float64, N)

################################################################################
# DifferentiationMethod
################################################################################
Expand Down

0 comments on commit 4c2d777

Please sign in to comment.