Skip to content

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed May 28, 2024
1 parent a4f685c commit e108c7a
Show file tree
Hide file tree
Showing 26 changed files with 2,550 additions and 1,654 deletions.
21 changes: 12 additions & 9 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,25 @@ function __init__()
for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh))
@eval begin
if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
function Reactant.elem_apply(::typeof($jlop),
lhs::Reactant.TracedRArray{ElType,Shape,N}) where {ElType,
Shape,
N}
return Reactant.TracedRArray{ElType,Shape,N}((),
Reactant.MLIR.IR.result(Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data),
1))
function Reactant.elem_apply(
::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
return Reactant.TracedRArray{ElType,Shape,N}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1
),
)
end
end
end
end
end

# TODO handle non finite cases
function NNlib.softmax!(out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray;
dims=1) where {T,Shape,N}
function NNlib.softmax!(
out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray; dims=1
) where {T,Shape,N}
max_ = NNlib.fast_maximum(x; dims)
#if all(isfinite, max_)
@fastmath out .= exp.(x .- max_)
Expand Down
Loading

0 comments on commit e108c7a

Please sign in to comment.