Skip to content

Commit

Permalink
fix: missing promote_to in NNlibExt
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 19, 2024
1 parent edf2d71 commit 242cfdf
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ using NNlib
using GPUArraysCore: @allowscalar
using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber

using Reactant.TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data!
using Reactant.TracedUtils:
TracedUtils, materialize_traced_array, get_mlir_data, set_mlir_data!

using ReactantCore: @trace
using LinearAlgebra: LinearAlgebra, triu
Expand All @@ -20,8 +21,8 @@ end
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
max_ = NNlib.fast_maximum(x; dims)
# XXX: Once reverse mode of if is properly supported, we can make it @trace
# zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
# one_num = Reactant.promote_to(TracedRNumber{T}, 1)
# zero_num = TracedUtils.promote_to(TracedRNumber{T}, 0)
# one_num = TracedUtils.promote_to(TracedRNumber{T}, 1)
# @trace if all(isfinite, max_)
@. out = exp(x - max_)
# else
Expand All @@ -37,8 +38,8 @@ end
function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
max_ = NNlib.fast_maximum(x; dims)
# XXX: Once reverse mode of if is properly supported, we can make it @trace
# inf_num = Reactant.promote_to(TracedRNumber{T}, Inf)
# zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
# inf_num = TracedUtils.promote_to(TracedRNumber{T}, Inf)
# zero_num = TracedUtils.promote_to(TracedRNumber{T}, 0)
# @trace if all(isfinite, max_)
@. out = x - max_
# else
Expand Down Expand Up @@ -232,9 +233,9 @@ function NNlib.batched_mul!(
if size(x, 3) != size(y, 3)
B = max(size(x, 3), size(y, 3))
if size(x, 3) == 1
x = Reactant.TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
x = TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
elseif size(y, 3) == 1
y = Reactant.TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
y = TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
end
end

Expand All @@ -244,9 +245,9 @@ function NNlib.batched_mul!(
if size(x, 1) != size(y, 1)
B = max(size(x, 1), size(y, 1))
if size(x, 1) == 1
x = Reactant.TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
x = TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
elseif size(y, 1) == 1
y = Reactant.TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
y = TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
end
end

Expand All @@ -264,7 +265,7 @@ end
function NNlib.pad_constant(
x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
) where {T,N}
value = Reactant.TracedUtils.promote_to(TracedRNumber{T}, value)
value = TracedUtils.promote_to(TracedRNumber{T}, value)
low = [i[1] for i in pad]
high = [i[2] for i in pad]
interior = [0 for i in pad]
Expand All @@ -287,8 +288,10 @@ function NNlib.gather!(
) where {T1,T2}
dims = NNlib.scatter_dims(src, dst, idxs)
@assert dims == 1 # scatter_dims lets us do some size checks so we call that function
idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1)
slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]))
idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,1}, idxs) .- 1)
slice_sizes = get_mlir_data(
TracedUtils.promote_to(TracedRArray{Int,1}, [size(src, 1), 1])
)

#! format: off
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
Expand Down Expand Up @@ -323,8 +326,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
start_sizes = ntuple(i -> size(src, i), dims)
results = map(CartesianIndices(idxs)) do k
res = @allowscalar src[colons..., Tuple(idxs[k])...]
res isa TracedRNumber &&
(res = Reactant.TracedUtils.broadcast_to_size(res, (1,)))
res isa TracedRNumber && (res = TracedUtils.broadcast_to_size(res, (1,)))
return reshape(res, start_sizes..., :)
end
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
Expand Down

0 comments on commit 242cfdf

Please sign in to comment.