diff --git a/Project.toml b/Project.toml
index 74decf6c0..acc3130cb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -57,7 +57,7 @@ Preferences = "1.4"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.3"
-Reactant_jll = "0.0.26"
+Reactant_jll = "0.0.27"
 Scratch = "1.2"
 Statistics = "1.10"
 YaoBlocks = "0.13"
diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index f85bd1d84..e90c00076 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -4,7 +4,8 @@ using NNlib
 using GPUArraysCore: @allowscalar

 using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
-using Reactant.TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data!
+using Reactant.TracedUtils:
+    TracedUtils, materialize_traced_array, get_mlir_data, set_mlir_data!
 using ReactantCore: @trace

 using LinearAlgebra: LinearAlgebra, triu
@@ -20,8 +21,8 @@ end
 function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
     max_ = NNlib.fast_maximum(x; dims)
     # XXX: Once reverse mode of if is properly supported, we can make it @trace
-    # zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
-    # one_num = Reactant.promote_to(TracedRNumber{T}, 1)
+    # zero_num = TracedUtils.promote_to(TracedRNumber{T}, 0)
+    # one_num = TracedUtils.promote_to(TracedRNumber{T}, 1)
     # @trace if all(isfinite, max_)
     @. out = exp(x - max_)
     # else
@@ -37,8 +38,8 @@ end
 function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
     max_ = NNlib.fast_maximum(x; dims)
     # XXX: Once reverse mode of if is properly supported, we can make it @trace
-    # inf_num = Reactant.promote_to(TracedRNumber{T}, Inf)
-    # zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
+    # inf_num = TracedUtils.promote_to(TracedRNumber{T}, Inf)
+    # zero_num = TracedUtils.promote_to(TracedRNumber{T}, 0)
     # @trace if all(isfinite, max_)
     @. out = x - max_
     # else
@@ -232,9 +233,9 @@ function NNlib.batched_mul!(
     if size(x, 3) != size(y, 3)
         B = max(size(x, 3), size(y, 3))
         if size(x, 3) == 1
-            x = Reactant.TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
+            x = TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
         elseif size(y, 3) == 1
-            y = Reactant.TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
+            y = TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
         end
     end

@@ -244,9 +245,9 @@ function NNlib.batched_mul!(
     if size(x, 1) != size(y, 1)
         B = max(size(x, 1), size(y, 1))
         if size(x, 1) == 1
-            x = Reactant.TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
+            x = TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
         elseif size(y, 1) == 1
-            y = Reactant.TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
+            y = TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
         end
     end

@@ -264,7 +265,7 @@ end
 function NNlib.pad_constant(
     x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
 ) where {T,N}
-    value = Reactant.TracedUtils.promote_to(TracedRNumber{T}, value)
+    value = TracedUtils.promote_to(TracedRNumber{T}, value)
     low = [i[1] for i in pad]
     high = [i[2] for i in pad]
     interior = [0 for i in pad]
@@ -287,8 +288,10 @@ function NNlib.gather!(
 ) where {T1,T2}
     dims = NNlib.scatter_dims(src, dst, idxs)
     @assert dims == 1 # scatter_dims lets us do some size checks so we call that function
-    idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1)
-    slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]))
+    idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,1}, idxs) .- 1)
+    slice_sizes = get_mlir_data(
+        TracedUtils.promote_to(TracedRArray{Int,1}, [size(src, 1), 1])
+    )

     #! format: off
     dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
@@ -323,8 +326,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
     start_sizes = ntuple(i -> size(src, i), dims)
     results = map(CartesianIndices(idxs)) do k
         res = @allowscalar src[colons..., Tuple(idxs[k])...]
-        res isa TracedRNumber &&
-            (res = Reactant.TracedUtils.broadcast_to_size(res, (1,)))
+        res isa TracedRNumber && (res = TracedUtils.broadcast_to_size(res, (1,)))
         return reshape(res, start_sizes..., :)
     end
     res = reshape(cat(results...; dims=(dims + 1)), size(dst))
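The `ext/ReactantNNlibExt.jl` hunks above are a namespace cleanup: the extension now imports the `TracedUtils` submodule once and replaces the fully qualified `Reactant.TracedUtils.*` / `Reactant.promote_to` spellings with `TracedUtils.*`. For orientation, here is a minimal, hypothetical usage sketch of the gather path those overloads serve; the concrete values are illustrative and not part of the patch:

```julia
using NNlib, Reactant

src = Reactant.ConcreteRArray(rand(Float32, 4, 6))
idxs = [1, 3, 2]  # plain Int indices; treated as constants during tracing

# Dispatches to the NNlib.gather! overloads patched above; the dims == 1
# fast path lowers through stablehlo's gather dimension numbers.
out = @jit NNlib.gather(src, idxs)  # columns 1, 3, and 2 of src
```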
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 074dabf90..deb124869 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -246,6 +246,7 @@ const opt_passes::String = join(
         "if_inline<1>",
         "if_to_select<1>",
         "dynamic_update_slice_const_prop",
+        "dynamic_gather_op_is_not_dynamic<16>",
     ],
     ';',
 ) *
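The `src/Compiler.jl` hunk registers one more pattern in the `opt_passes` pipeline string; as the surrounding context shows, the passes are plain names joined with `';'`, with the angle-bracket suffix parameterizing a pattern in the same style as `if_inline<1>`. A stripped-down sketch of that construction, under a hypothetical name and with only this hunk's entries:

```julia
# Only the four entries visible in the hunk; the real list is much longer.
opt_passes_sketch = join(
    [
        "if_inline<1>",
        "if_to_select<1>",
        "dynamic_update_slice_const_prop",
        "dynamic_gather_op_is_not_dynamic<16>",  # new in this patch
    ],
    ';',
)
# == "if_inline<1>;if_to_select<1>;dynamic_update_slice_const_prop;dynamic_gather_op_is_not_dynamic<16>"
```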
diff --git a/src/Interpreter.jl b/src/Interpreter.jl
index 4b71a1341..f5e4475a2 100644
--- a/src/Interpreter.jl
+++ b/src/Interpreter.jl
@@ -283,7 +283,7 @@ function overload_autodiff(
             end
         end
         for (i, act) in enumerate(activity)
-            if act == enzyme_out || (reverse && (act == enzyme_dup || act == enzyme_dupnoneed))
+            if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed
                 if width == 1
                     push!(outtys, in_tys[i])
                 else
diff --git a/test/autodiff.jl b/test/autodiff.jl
index 5cf1726d0..842050413 100644
--- a/test/autodiff.jl
+++ b/test/autodiff.jl
@@ -11,7 +11,7 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))

     res1 = @jit(
         fwd(
-            set_abi(Forward, Reactant.ReactantABI),
+            Forward,
             Duplicated,
             ConcreteRArray(ones(3, 2)),
             ConcreteRArray(3.1 * ones(3, 2)),
@@ -42,12 +42,7 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
     @test typeof(ores1) == Tuple{}

     res1 = @jit(
-        fwd(
-            set_abi(Forward, Reactant.ReactantABI),
-            Const,
-            ConcreteRArray(ones(3, 2)),
-            ConcreteRArray(3.1 * ones(3, 2)),
-        )
+        fwd(Forward, Const, ConcreteRArray(ones(3, 2)), ConcreteRArray(3.1 * ones(3, 2)))
     )
     @test typeof(res1) == Tuple{}
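The `src/Interpreter.jl` change makes `enzyme_dup` / `enzyme_dupnoneed` activities contribute result types in forward mode as well, where previously only `enzyme_out` did unless `reverse` was set; that is what lets the tests drop the explicit `set_abi(Forward, Reactant.ReactantABI)` wrapper. A sketch mirroring the updated test, assuming `square` is the test file's elementwise square (its definition sits outside this diff):

```julia
using Enzyme, Reactant

square(x) = x .* x  # assumed definition; the diff only shows calls to it
fwd(mode, RT, x, dx) = Enzyme.autodiff(mode, square, RT, Duplicated(x, dx))

x = Reactant.ConcreteRArray(ones(3, 2))
dx = Reactant.ConcreteRArray(3.1 * ones(3, 2))

# A plain Forward mode now compiles directly; before this patch the call
# had to be spelled fwd(set_abi(Forward, Reactant.ReactantABI), ...).
res = @jit fwd(Forward, Duplicated, x, dx)
```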