From 242cfdfcded0ca8369d051364efe67d7cb43b74f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 19 Dec 2024 11:19:18 +0530 Subject: [PATCH] fix: missing promote_to in NNlibExt --- ext/ReactantNNlibExt.jl | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index f85bd1d84..e90c00076 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -4,7 +4,8 @@ using NNlib using GPUArraysCore: @allowscalar using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber -using Reactant.TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data! +using Reactant.TracedUtils: + TracedUtils, materialize_traced_array, get_mlir_data, set_mlir_data! using ReactantCore: @trace using LinearAlgebra: LinearAlgebra, triu @@ -20,8 +21,8 @@ end function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N} max_ = NNlib.fast_maximum(x; dims) # XXX: Once reverse mode of if is properly supported, we can make it @trace - # zero_num = Reactant.promote_to(TracedRNumber{T}, 0) - # one_num = Reactant.promote_to(TracedRNumber{T}, 1) + # zero_num = TracedUtils.promote_to(TracedRNumber{T}, 0) + # one_num = TracedUtils.promote_to(TracedRNumber{T}, 1) # @trace if all(isfinite, max_) @. out = exp(x - max_) # else @@ -37,8 +38,8 @@ end function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T} max_ = NNlib.fast_maximum(x; dims) # XXX: Once reverse mode of if is properly supported, we can make it @trace - # inf_num = Reactant.promote_to(TracedRNumber{T}, Inf) - # zero_num = Reactant.promote_to(TracedRNumber{T}, 0) + # inf_num = TracedUtils.promote_to(TracedRNumber{T}, Inf) + # zero_num = TracedUtils.promote_to(TracedRNumber{T}, 0) # @trace if all(isfinite, max_) @. out = x - max_ # else @@ -232,9 +233,9 @@ function NNlib.batched_mul!( if size(x, 3) != size(y, 3) B = max(size(x, 3), size(y, 3)) if size(x, 3) == 1 - x = Reactant.TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B)) + x = TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B)) elseif size(y, 3) == 1 - y = Reactant.TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B)) + y = TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B)) end end @@ -244,9 +245,9 @@ function NNlib.batched_mul!( if size(x, 1) != size(y, 1) B = max(size(x, 1), size(y, 1)) if size(x, 1) == 1 - x = Reactant.TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) + x = TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) elseif size(y, 1) == 1 - y = Reactant.TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3))) + y = TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3))) end end @@ -264,7 +265,7 @@ end function NNlib.pad_constant( x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value ) where {T,N} - value = Reactant.TracedUtils.promote_to(TracedRNumber{T}, value) + value = TracedUtils.promote_to(TracedRNumber{T}, value) low = [i[1] for i in pad] high = [i[2] for i in pad] interior = [0 for i in pad] @@ -287,8 +288,10 @@ function NNlib.gather!( ) where {T1,T2} dims = NNlib.scatter_dims(src, dst, idxs) @assert dims == 1 # scatter_dims lets us do some size checks so we call that function - idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1) - slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1])) + idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,1}, idxs) .- 1) + slice_sizes = get_mlir_data( + TracedUtils.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]) + ) #! format: off dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( @@ -323,8 +326,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr start_sizes = ntuple(i -> size(src, i), dims) results = map(CartesianIndices(idxs)) do k res = @allowscalar src[colons..., Tuple(idxs[k])...] - res isa TracedRNumber && - (res = Reactant.TracedUtils.broadcast_to_size(res, (1,))) + res isa TracedRNumber && (res = TracedUtils.broadcast_to_size(res, (1,))) return reshape(res, start_sizes..., :) end res = reshape(cat(results...; dims=(dims + 1)), size(dst))