diff --git a/Project.toml b/Project.toml index ba91dec15..d47c765a0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy "] -version = "0.1.0" +version = "0.1.1" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -12,16 +12,19 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" [weakdeps] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" [extensions] +ReactantArrayInterfaceExt = "ArrayInterface" ReactantNNlibExt = "NNlib" [compat] -Cassette = "0.3" +ArrayInterface = "7.10" CEnum = "0.4, 0.5" +Cassette = "0.3" Enzyme = "0.11, 0.12" -Reactant_jll = "0.0.6" +NNlib = "0.9" Preferences = "1.4" +Reactant_jll = "0.0.6" julia = "1" - diff --git a/ext/ReactantArrayInterfaceExt.jl b/ext/ReactantArrayInterfaceExt.jl new file mode 100644 index 000000000..3fc66ce0f --- /dev/null +++ b/ext/ReactantArrayInterfaceExt.jl @@ -0,0 +1,9 @@ +module ReactantArrayInterfaceExt + +using ArrayInterface: ArrayInterface +using Reactant: RArray + +ArrayInterface.can_setindex(::Type{<:RArray}) = false +ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false + +end diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 605915b94..7f9351416 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -3,20 +3,18 @@ module ReactantNNlibExt using NNlib using Reactant -function __init__() - for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh)) - @eval begin - if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast - function Reactant.elem_apply( - ::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N} - ) where {ElType,Shape,N} - return Reactant.TracedRArray{ElType,Shape,N}( - (), - Reactant.MLIR.IR.result( - Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1 - ), - ) - end +for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh)) + @eval begin + if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast + function Reactant.elem_apply( + ::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N} + ) where {ElType,Shape,N} + return Reactant.TracedRArray{ElType,Shape,N}( + (), + Reactant.MLIR.IR.result( + Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1 + ), + ) end end end @@ -36,4 +34,5 @@ function NNlib.softmax!( tmp = dims isa Colon ? sum(out) : sum!(max_, out) return out ./= tmp end + end diff --git a/src/Reactant.jl b/src/Reactant.jl index 2389f083c..d21aaf0cf 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -8,6 +8,7 @@ abstract type RArray{ElType,Shape,N} <: AbstractArray{ElType,N} end @inline Base.eltype(::RArray{ElType,Shape}) where {ElType,Shape} = ElType @inline Base.size(::RArray{ElType,Shape}) where {ElType,Shape} = Shape +@inline Base.size(::Type{<:RArray{ElType,Shape}}) where {ElType,Shape} = Shape @inline Base.ndims(::RArray{ElType,Shape,N}) where {ElType,Shape,N} = N @inline Base.ndims(::Type{<:RArray{ElType,Shape,N}}) where {ElType,Shape,N} = N @@ -162,6 +163,12 @@ end return res end + if RT <: TracedRArray + res = broadcast_to_size(eltype(RT)(0), size(prev)) + seen[prev] = res + return res + end + attr = fill(MLIR.IR.Attribute(eltype(RT)(0)), mlir_type(prev)) cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) res = RT((), cst) @@ -189,6 +196,7 @@ include("overloads.jl") using Enzyme @inline val_value(::Val{T}) where {T} = T +@inline val_value(::Type{Val{T}}) where {T} = T @enum TraceMode begin ConcreteToTraced = 1 @@ -649,7 +657,7 @@ function generate_jlfunc( end res = Symbol("arg_$(path[2])") for p in path[3:end] - res = :(Base.getfield($res, $p)) + res = :(Base.getfield($res, $(Meta.quot(p)))) end sym = Symbol("sbuf_$i") sbuf = :($sym = XLA.synced_buffer($res.data)) @@ -819,7 +827,7 @@ function generate_jlfunc( return nothing end - return error("canot copy $T") + return error("cannot copy $T") end create_result(concrete_result, :result, ()) diff --git a/src/overloads.jl b/src/overloads.jl index 643c69a2d..4518d5f9e 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -749,6 +749,13 @@ Cassette.overdub(context::TraceCtx, f::typeof(Base.materialize!), args...) = f(a @inline Base.copyto!(dest::TracedRArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict +@inline function Base.copyto!( + dest::TracedRArray{ElType,Shape,N}, src::TracedRArray{ElType,Shape,N} +) where {ElType,Shape,N} + dest.mlir_data = src.mlir_data + return dest +end + @inline function broadcast_to_size(arg::AbstractArray, rsize) attr = MLIR.IR.DenseElementsAttribute(arg) len = ndims(arg) diff --git a/src/utils.jl b/src/utils.jl index 2dfd97567..d2bd02000 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,7 @@ function transpose_val(val) return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) end -function apply(f, args...; kwargs) +function apply(f, args...; kwargs...) return f(args...; kwargs...) end