From 97822fe16bfc618c3626dadc0ba9c6c2dacf9504 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 17 Jun 2024 19:24:57 -0400 Subject: [PATCH] broadcast to size fix reduce --- src/overloads.jl | 2 +- test/basic.jl | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/overloads.jl b/src/overloads.jl index 4518d5f9e..03a1cbf01 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -881,7 +881,7 @@ function Base.mapreduce( ) res = MLIR.IR.block!(fnbody) do - tmp = broadcast_to_size(op(args...), (1,)).mlir_data + tmp = broadcast_to_size(op(args...), ()).mlir_data MLIR.Dialects.stablehlo.return_(MLIR.IR.Value[tmp]) return tmp end diff --git a/test/basic.jl b/test/basic.jl index b221cb40c..9ead3d021 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -8,6 +8,24 @@ fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-I using InteractiveUtils +@testset "2D sum" begin + r_res = sum(ones(2, 10)) + + a = Reactant.ConcreteRArray(ones(2, 10)) + + c_res = sum(a) + @test c_res ≈ r_res + + f = Reactant.compile(sum, (a,)) + + @show @code_typed f(a) + @show @code_llvm f(a) + + f_res = f(a) + + @test f_res ≈ r_res +end + @testset "Basic reduce max" begin r_res = fastmax(ones(2, 10))