Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: patch more enzyme issues
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 17, 2024
1 parent dd76953 commit ad7790b
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 12 deletions.
2 changes: 1 addition & 1 deletion ext/LuxLibLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end

# batched matmul
function LuxLib.Impl.batched_matmul_loopvec_impl!(
z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
::True, z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}, α::Number=true, β::Number=false) where {zT, xT, yT}
if size(x, 3) == size(y, 3)
@batch for L in axes(z, 3)
Expand Down
15 changes: 11 additions & 4 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,21 @@ function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp,
return
end

function batched_matmul_cpu!(z::AbstractArray{zT, 3},
x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT}
function batched_matmul_cpu!(
z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3},
α::Number=true, β::Number=false) where {zT, xT, yT}
if can_loopvec_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) &&
!unsafe_known(explicit_blas_loaded())
batched_matmul_loopvec_impl!(z, x, y)
batched_matmul_loopvec_impl!is_extension_loaded(
Val(:LoopVectorization), z, x, y, α, β)
return
end
NNlib.batched_mul!(z, x, y)
NNlib.batched_mul!(z, x, y, α, β)
return
end

function batched_matmul_loopvec_impl!(_, z, x, y, α, β)
NNlib.batched_mul!(z, x, y, α, β)
return
end

Expand Down
7 changes: 7 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ const KA = KernelAbstractions

is_extension_loaded(::Val) = False()

CRC.@non_differentiable is_extension_loaded(::Any...)
EnzymeRules.inactive_noinl(::typeof(is_extension_loaded), ::Any...) = nothing

# Simple Operations -- no rrules needed
ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x
function ofeltype_array(
Expand Down Expand Up @@ -328,4 +331,8 @@ end

@inline can_loopvec_args_check(::False, args...) = false

CRC.@non_differentiable can_loopvec_args_check(::Any...)

EnzymeRules.inactive_noinl(::typeof(can_loopvec_args_check), ::Any...) = nothing

end
2 changes: 1 addition & 1 deletion test/common_ops/activation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
@jet apply_act_fast2(f, x)

@test @inferred(Zygote.gradient(apply_act, f, x)) isa Any
if f !== lisht || (f === lisht && T == Float32 && !ongpu)
if f !== lisht
@test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any
end
@test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any
Expand Down
11 changes: 5 additions & 6 deletions test/common_ops/bias_act_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,11 @@
@jet bias_act_loss2(act, x, b)
@jet bias_act_loss3(act, x, b)

if (act !== lisht || (act === lisht && T == Float32 && !ongpu)) && T != Float16
@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
elseif T != Float16
@test_broken @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
@test_broken @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
if act !== lisht
@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any broken=(T !=
Float16)
@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any broken=(T !=
Float16)
end

@test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol,
Expand Down

0 comments on commit ad7790b

Please sign in to comment.