From 0f7a912ca9cfd2ce1a96491052a16eab899cc9a7 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 16 May 2024 16:28:53 -0400 Subject: [PATCH] Add comparisons --- src/overloads.jl | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/overloads.jl b/src/overloads.jl index 159503621..bd7a3a8ea 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -397,6 +397,31 @@ for (jlop, hloop, RT) in ((:(Base.min), :minimum, :ElType),(:(Base.max), :maximu end end +for (jlop, hloop, hlocomp, RT) in ( + (:(Base.:(==)), :compare, "EQ", :ElType), + (:(Base.:(!=)), :compare, "NE", :ElType), + (:(Base.:(>=)), :compare, "GE", :ElType), + (:(Base.:(>)), :compare, "GT", :ElType), + (:(Base.:(<=)), :compare, "LE", :ElType), + (:(Base.:(<)), :compare, "LT", :ElType), + ) + @eval begin + function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} + return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; + comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), $hlocomp)), 1)) + end + + function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N} + rhs = promote_to(lhs, rhs) + return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), $hlocomp)), 1)) + end + + function elem_apply(::typeof($jlop), lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} + lhs = promote_to(rhs, lhs) + return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), $hlocomp)), 1)) + end + end +end function elem_apply(::typeof(identity), lhs) return lhs