diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index 544dfe649..c0d9e0f03 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -14,13 +14,14 @@ for (jlop, hloop) in (
             Reactant.MLIR.IR.result(
                 Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
             ),
+            (),
         )
     end
 end
 
-NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T} = max(x, zero(T))
+NNlib.relu(x::Reactant.TracedRArray{T,0}) where {T} = max(x, zero(T))
 
-NNlib.gelu(x::Reactant.TracedRArray{T,(),0}) where {T} = x * sigmoid(T(1.702) * x)
+NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T} = x * sigmoid(T(1.702) * x)
 
 # TODO handle non finite cases
 function NNlib.softmax!(
diff --git a/src/Reactant.jl b/src/Reactant.jl
index 006667fa8..29c7d3a02 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -17,6 +17,11 @@ function mlir_type(x::RArray{T,N}) where {T,N}
     return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T))
 end
 
+function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N}
+    @assert length(shape) == N
+    return MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
+end
+
 function Enzyme.make_zero(
     ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
 )::RT where {copy_if_inactive,RT<:RArray}
diff --git a/src/Tracing.jl b/src/Tracing.jl
index 89adce461..4119eaeb7 100644
--- a/src/Tracing.jl
+++ b/src/Tracing.jl
@@ -92,20 +92,21 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
             (),
             MLIR.IR.result(
                 MLIR.Dialects.stablehlo.convert(
-                    rhs.mlir_data; result=mlir_type(TracedRArray{T,N})
+                    rhs.mlir_data; result=mlir_type(TracedRArray{T,N}, size(rhs))
                 ),
                 1,
             ),
+            size(rhs),
         )
     end
     if isa(rhs, Number)
-        attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRArray{T,N}))
+        attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRArray{T,N}, size(rhs)))
         ta = TracedRArray{T,N}(
-            (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
+            (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
         )
         return ta
     end
-    attr = MLIR.IR.DenseElementsAttribute(mlir_type(TracedRArray{T,N}), rhs)
+    attr = MLIR.IR.DenseElementsAttribute(mlir_type(TracedRArray{T,N}, size(rhs)), rhs)
     return TracedRArray{T,N}(
         (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
     )
@@ -115,11 +116,11 @@ function promote_to(lhs::TracedRArray{T,N}, rhs) where {T,N}
     return promote_to(TracedRArray{T,N}, rhs)
 end
 
-for (jlop, hloop) in (
-    (:(Base.min), :minimum),
-    (:(Base.max), :maximum),
-    (:(Base.:+), :add),
-    (:(Base.:-), :subtract),
+for (jlop, hloop, RT) in (
+    (:(Base.min), :minimum, :T),
+    (:(Base.max), :maximum, :T),
+    (:(Base.:+), :add, :T),
+    (:(Base.:-), :subtract, :T),
 )
     @eval begin
         function $jlop(lhs::TracedRArray{T,N}, rhs::TracedRArray{T2,N}) where {T,T2,N}
@@ -136,7 +137,7 @@ for (jlop, hloop) in (
         end
 
         function $jlop(lhs::TracedRArray{T,N}, rhs::TracedRArray{T,N}) where {T,N}
-            return TracedRArray{T,N}(
+            return TracedRArray{$RT,N}(
                 (),
                 MLIR.IR.result(
                     MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
@@ -144,32 +145,37 @@ for (jlop, hloop) in (
                 size(lhs),
             )
         end
+    end
 
-        function $jlop(lhs::TracedRArray{T,N}, rhs) where {T,N}
-            rhs = promote_to(lhs, rhs)
-            return TracedRArray{T,N}(
-                (),
-                MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
-                ),
-                size(lhs),
-            )
-        end
+    for otherType in (Number, Any) #=TracedRArray{S,0} where {S}=#
+        @eval begin
+            function $jlop(lhs::TracedRArray{T,N}, rhs::$otherType) where {T,N}
+                rhs = promote_to(lhs, rhs)
+                return TracedRArray{$RT,N}(
+                    (),
+                    MLIR.IR.result(
+                        MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                    ),
+                    size(lhs),
+                )
+            end
 
-        function $jlop(lhs, rhs::TracedRArray{T,N}) where {T,N}
-            lhs = promote_to(rhs, lhs)
-            return TracedRArray{T,N}(
-                (),
-                MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
-                ),
-                size(lhs),
-            )
+            function $jlop(lhs::$otherType, rhs::TracedRArray{T,N}) where {T,N}
+                lhs = promote_to(rhs, lhs)
+                return TracedRArray{$RT,N}(
+                    (),
+                    MLIR.IR.result(
+                        MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                    ),
+                    size(lhs),
+                )
+            end
         end
     end
 end
 
-for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^), :power))
+for (jlop, hloop, RT) in
+    ((:(Base.:*), :multiply, :T), (:(Base.:/), :divide, :T), (:(Base.:^), :power, :T))
     @eval begin
         function $jlop(lhs::TracedRArray{T,0}, rhs::TracedRArray{T2,0}) where {T,T2}
             commonTy = TracedRArray{Base.promote_type(T, T2),0}
@@ -185,7 +191,7 @@ for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^
         end
 
         function $jlop(lhs::TracedRArray{T,0}, rhs::TracedRArray{T,0}) where {T}
-            return TracedRArray{T,0}(
+            return TracedRArray{$RT,0}(
                 (),
                 MLIR.IR.result(
                     MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
@@ -196,7 +202,7 @@ for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^
 
         function $jlop(lhs::TracedRArray{T,0}, rhs) where {T}
             rhs = promote_to(lhs, rhs)
-            return TracedRArray{T,0}(
+            return TracedRArray{$RT,0}(
                 (),
                 MLIR.IR.result(
                     MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
@@ -207,7 +213,30 @@ for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^
 
         function $jlop(lhs, rhs::TracedRArray{T,0}) where {T}
             lhs = promote_to(rhs, lhs)
-            return TracedRArray{T,0}(
+            return TracedRArray{$RT,0}(
+                (),
+                MLIR.IR.result(
+                    MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                ),
+                (),
+            )
+        end
+
+        # Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
+        function $jlop(lhs::TracedRArray{T,0}, rhs::Number) where {T}
+            rhs = promote_to(lhs, rhs)
+            return TracedRArray{$RT,0}(
+                (),
+                MLIR.IR.result(
+                    MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                ),
+                (),
+            )
+        end
+
+        function $jlop(lhs::Number, rhs::TracedRArray{T,0}) where {T}
+            lhs = promote_to(rhs, lhs)
+            return TracedRArray{$RT,0}(
                 (),
                 MLIR.IR.result(
                     MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
@@ -218,6 +247,20 @@ for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^
     end
 end
 
+function Base.ifelse(
+    pred::TracedRArray{Bool,0}, x::TracedRArray{T1,0}, y::TracedRArray{T2,0}
+) where {T1,T2}
+    return TracedRArray{promote_type(T1, T2),0}(
+        (),
+        MLIR.IR.result(
+            MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
+        ),
+        size(pred),
+    )
+end
+
+Base.abs2(x::Reactant.TracedRArray{T,0}) where {T} = x * conj(x)
+
 function Base.literal_pow(
     ::Base.RefValue{typeof(^)}, x::TracedRArray{T,0}, ::Base.RefValue{Val{P}}
 ) where {T,P}
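A minimal usage sketch of what the new scalar dispatch paths enable (not part of the patch). It assumes the `Reactant.ConcreteRArray` constructor and the `Reactant.compile(f, args)` entry point available in the package at this point in its history; the function name `half_square` is made up for illustration.

    using Reactant

    # `x / 2` dispatches to the TracedRArray{T,0} / Number method added above
    # (sidestepping the Base AbstractArray/Number ambiguity), and `abs2` to the
    # new Base.abs2 overload, which for a real scalar lowers to a stablehlo multiply.
    half_square(x) = abs2(x / 2)

    a = Reactant.ConcreteRArray(fill(3.0))   # 0-dimensional input array
    f = Reactant.compile(half_square, (a,))  # traces and compiles via MLIR
    f(a)                                     # (3 / 2)^2 == 2.25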