Skip to content

Commit

Permalink
Add comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 16, 2024
1 parent ea37332 commit 0f7a912
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,31 @@ for (jlop, hloop, RT) in ((:(Base.min), :minimum, :ElType),(:(Base.max), :maximu
end
end

for (jlop, hloop, hlocomp, RT) in (
(:(Base.:(==)), :compare, "EQ", :ElType),
(:(Base.:(!=)), :compare, "NE", :ElType),
(:(Base.:(>=)), :compare, "GE", :ElType),
(:(Base.:(>)), :compare, "GT", :ElType),
(:(Base.:(<=)), :compare, "LE", :ElType),
(:(Base.:(<)), :compare, "LT", :ElType),
)
@eval begin
function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data;
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), $hlocomp)), 1))
end

function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), $hlocomp)), 1))
end

function elem_apply(::typeof($jlop), lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data; comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), $hlocomp)), 1))
end
end
end

function elem_apply(::typeof(identity), lhs)
return lhs
Expand Down

0 comments on commit 0f7a912

Please sign in to comment.