Skip to content

Commit

Permalink
broadcast to size fix reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 17, 2024
1 parent 5f03906 commit 97822fe
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ function Base.mapreduce(
)

res = MLIR.IR.block!(fnbody) do
tmp = broadcast_to_size(op(args...), (1,)).mlir_data
tmp = broadcast_to_size(op(args...), ()).mlir_data
MLIR.Dialects.stablehlo.return_(MLIR.IR.Value[tmp])
return tmp
end
Expand Down
18 changes: 18 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@ fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-I

using InteractiveUtils

@testset "2D sum" begin
r_res = sum(ones(2, 10))

a = Reactant.ConcreteRArray(ones(2, 10))

c_res = sum(a)
@test c_res r_res

f = Reactant.compile(sum, (a,))

@show @code_typed f(a)
@show @code_llvm f(a)

f_res = f(a)

@test f_res r_res
end

@testset "Basic reduce max" begin
r_res = fastmax(ones(2, 10))

Expand Down

0 comments on commit 97822fe

Please sign in to comment.