Skip to content

Commit

Permalink
No Transpose emission for 0 rank tensor (#375)
Browse files Browse the repository at this point in the history
* `stablehlo.sort` Ops

* do not transpose rank 0 tensor

* move check

* format
  • Loading branch information
glou-nes authored Dec 20, 2024
1 parent 3afed78 commit 0a41c60
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
10 changes: 2 additions & 8 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,6 @@ function link(job, compiled)
return compiled
end

"""
    transpose_val(val)

Return `val` with its dimensions in reverse order.

Emits a `stablehlo.transpose` whose permutation is `[N-1, ..., 1, 0]`, where `N` is the
number of dimensions reported by `size(MLIR.IR.type(val))`, and returns the first result
of that operation. A rank-0 value has nothing to permute, so it is returned unchanged
and no operation is emitted.
"""
function transpose_val(val)
    dims = size(MLIR.IR.type(val))
    # Nothing to permute for a rank-0 tensor; skip emitting a no-op transpose.
    isempty(dims) && return val
    attr = MLIR.IR.DenseArrayAttribute(Int64[reverse(0:(length(dims) - 1))...])
    return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
end

Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
args...;
Expand All @@ -366,7 +360,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
push!(rarrays, ta)
arg = ta.mlir_data
arg = transpose_val(arg)
arg = Reactant.TracedUtils.transpose_val(arg)
push!(restys, MLIR.IR.type(arg))
push!(mlir_args, arg)
push!(
Expand Down Expand Up @@ -399,7 +393,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
)
# call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
for (i, res) in enumerate(rarrays)
res.mlir_data = transpose_val(MLIR.IR.result(call, i))
res.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, i))
end

@show blockdim
Expand Down
6 changes: 3 additions & 3 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ function transpose_ty(mlirty)
return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty))
end
"""
    transpose_val(val)

Return `val` with its dimensions in reverse order.

For a value whose type has size `()` (rank 0) the input is returned as-is and no
operation is emitted. Otherwise a `stablehlo.transpose` with permutation
`[N-1, ..., 1, 0]` is emitted, where `N` is the number of dimensions, and the first
result of that operation is returned.
"""
function transpose_val(val)
    val_size = size(MLIR.IR.type(val))
    # Rank-0 tensors have no axes to reverse, so avoid emitting a transpose for them.
    val_size == () && return val
    # Build the permutation attribute once, only after the rank-0 guard has passed.
    attr = MLIR.IR.DenseArrayAttribute(Int64[reverse(0:(length(val_size) - 1))...])
    return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
end

Expand Down

0 comments on commit 0a41c60

Please sign in to comment.