Skip to content

Commit

Permalink
supports_autoenzyme always false unless ext is loaded
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeingold committed Dec 28, 2024
1 parent d4a7bde commit ec03b1d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
8 changes: 8 additions & 0 deletions ext/MeshIntegralsEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,12 @@ function MeshIntegrals.jacobian(
return Meshes.to.(Enzyme.jacobian(Enzyme.Forward, geometry, ts...))
end

# Supports all geometries except for a few
# See GitHub Issue #154 for more information.
supports_autoenzyme(::Type{<:Meshes.Geometry}) = true
supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false
supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false
supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false
supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false

end
23 changes: 15 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ end
Return whether a geometry (or geometry type) has a parametric function that can be
differentiated with Enzyme. See GitHub Issue #154 for more information.
"""
supports_autoenzyme(::Type{<:Meshes.Geometry}) = true
supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false
supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false
supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false
supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false
function supports_autoenzyme end

# Returns false on all geometries when Enzyme extension not loaded
supports_autoenzyme(::Any) = false
supports_autoenzyme(::G) where {G <: Geometry} = supports_autoenzyme(G)

"""
Expand All @@ -44,16 +43,24 @@ Return an instance of the default DifferentiationMethod for a particular geometr
(or geometry type) and floating point type.
"""
function _default_diff_method(
g::Type{G}, FP::Type{T}
::Type{G},
FP::Type{T}
) where {G <: Geometry, T <: AbstractFloat}
if supports_autoenzyme(g) && FP <: Union{Float32, Float64}
# Enzyme only works with these FP types
EnzymeSupportedFPs = Union{Float32, Float64}

if supports_autoenzyme(G) && (FP <: EnzymeSupportedFPs)
AutoEnzyme()
else
FiniteDifference()
end
end

function _default_diff_method(::G, ::Type{T}) where {G <: Geometry, T <: AbstractFloat}
# Re-run with the Geometry type as first argument
function _default_diff_method(
::G,
::Type{T}
) where {G <: Geometry, T <: AbstractFloat}
_default_diff_method(G, T)
end

Expand Down

0 comments on commit ec03b1d

Please sign in to comment.