Skip to content

Commit

Permalink
Don't eval in __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 27, 2024
1 parent c85e63e commit 4ccff4b
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ module ReactantNNlibExt
using NNlib
using Reactant

function __init__()
for (jlop, hloop) in ((:(NNlib.tanh), :tanh),(:(NNlib.tanh_fast), :tanh),)
for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh),)
@eval begin
if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
function Reactant.elem_apply(::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
Expand All @@ -13,10 +12,9 @@ for (jlop, hloop) in ((:(NNlib.tanh), :tanh),(:(NNlib.tanh_fast), :tanh),)
end
end
end
end

# TODO handle non finite cases
function NNlib.softmax!(out::Reactant.TracedRArray{T, Shape, N}, x::AbstractArray; dims = 1) where {T, Shape, N}
function NNlib.softmax!(out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray; dims=1) where {T,Shape,N}
max_ = NNlib.fast_maximum(x; dims)
#if all(isfinite, max_)
@fastmath out .= exp.(x .- max_)
Expand Down

0 comments on commit 4ccff4b

Please sign in to comment.