diff --git a/ext/MeshIntegralsEnzymeExt.jl b/ext/MeshIntegralsEnzymeExt.jl index 339b8b64..e4a676f6 100644 --- a/ext/MeshIntegralsEnzymeExt.jl +++ b/ext/MeshIntegralsEnzymeExt.jl @@ -16,4 +16,12 @@ function MeshIntegrals.jacobian( return Meshes.to.(Enzyme.jacobian(Enzyme.Forward, geometry, ts...)) end +# Supports all geometries except for a few +# See GitHub Issue #154 for more information. +supports_autoenzyme(::Type{<:Meshes.Geometry}) = true +supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false +supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false +supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false +supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false + end diff --git a/src/utils.jl b/src/utils.jl index 84ebfdc7..ad55784f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,11 +18,10 @@ end Return whether a geometry (or geometry type) has a parametric function that can be differentiated with Enzyme. See GitHub Issue #154 for more information. """ -supports_autoenzyme(::Type{<:Meshes.Geometry}) = true -supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false -supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false -supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false -supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false +function supports_autoenzyme end + +# Returns false on all geometries when Enzyme extension not loaded +supports_autoenzyme(::Any) = false supports_autoenzyme(::G) where {G <: Geometry} = supports_autoenzyme(G) """ @@ -44,16 +43,24 @@ Return an instance of the default DifferentiationMethod for a particular geometr (or geometry type) and floating point type. """ function _default_diff_method( - g::Type{G}, FP::Type{T} + ::Type{G}, + FP::Type{T} ) where {G <: Geometry, T <: AbstractFloat} - if supports_autoenzyme(g) && FP <: Union{Float32, Float64} + # Enzyme only works with these FP types + EnzymeSupportedFPs = Union{Float32, Float64} + + if supports_autoenzyme(G) && (FP <: EnzymeSupportedFPs) AutoEnzyme() else FiniteDifference() end end -function _default_diff_method(::G, ::Type{T}) where {G <: Geometry, T <: AbstractFloat} +# Re-run with the Geometry type as first argument +function _default_diff_method( + ::G, + ::Type{T} +) where {G <: Geometry, T <: AbstractFloat} _default_diff_method(G, T) end