Skip to content

Commit

Permalink
feat: support setindex with views (#240)
Browse files Browse the repository at this point in the history
* feat: support setindex with views

* Update test/basic.jl

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
avik-pal and mofeing authored Nov 7, 2024
1 parent c4d0b50 commit f34ec8f
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
30 changes: 27 additions & 3 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]
"""
    get_mlir_data(x::TracedRArray)

Return the value stored in `x.mlir_data`.
"""
function get_mlir_data(x::TracedRArray)
    return x.mlir_data
end

"""
    get_mlir_data(x::AnyTracedRArray)

Materialize `x` with `materialize_traced_array` and return the MLIR data of the result.
"""
function get_mlir_data(x::AnyTracedRArray)
    materialized = materialize_traced_array(x)
    return get_mlir_data(materialized)
end

"""
    set_mlir_data!(x::TracedRArray, data)

Assign `data` to `x.mlir_data` and return `x`.
"""
function set_mlir_data!(x::TracedRArray, data)
    # `x.mlir_data = data` spelled out as the property-assignment call it lowers to.
    setproperty!(x, :mlir_data, data)
    return x
end
"""
    set_mlir_data!(x::AnyTracedRArray, data)

Wrap the MLIR value `data` in a fresh `TracedRArray` and write it into `x` over all of
`axes(x)` via `setindex!`. Return `x`.

The element type, rank and size of the temporary array are all read from the MLIR type
of `data`.
"""
function set_mlir_data!(x::AnyTracedRArray, data)
    mlir_type = MLIR.IR.type(data)
    T = eltype(MLIR.IR.julia_type(mlir_type))
    N = ndims(mlir_type)
    # First constructor argument is an empty tuple, as in the other constructions of
    # `TracedRArray` visible in this file.
    wrapped = TracedRArray{T,N}((), data, size(mlir_type))
    x[axes(x)...] = wrapped
    return x
end

"""
    ancestor(x::TracedRArray)

A plain `TracedRArray` is its own ancestor; return `x` unchanged.
"""
function ancestor(x::TracedRArray)
    return x
end

"""
    ancestor(x::WrappedTracedRArray)

Follow `parent` links recursively until the `TracedRArray` method applies.
"""
function ancestor(x::WrappedTracedRArray)
    inner = parent(x)
    return ancestor(inner)
end

Expand Down Expand Up @@ -115,12 +128,23 @@ function Base.setindex!(
i in indices
]
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.dynamic_update_slice(a.mlir_data, v.mlir_data, indices), 1
MLIR.Dialects.stablehlo.dynamic_update_slice(
a.mlir_data, get_mlir_data(v), indices
),
1,
)
a.mlir_data = res
return v
end

"""
    setindex!(a::AnyTracedRArray, v, indices...)

Write `v` into `a` by translating `indices` with `get_ancestor_indices` and forwarding
the assignment to `ancestor(a)`. Return `a`.
"""
function Base.setindex!(
    a::AnyTracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N}
) where {T,N}
    translated = get_ancestor_indices(a, indices...)
    root = ancestor(a)
    # Indexed assignment lowers to `setindex!(root, v, translated...)`.
    root[translated...] = v
    return a
end

"""
    size(x::TracedRArray)

Return the value stored in `x.shape`.
"""
function Base.size(x::TracedRArray)
    return x.shape
end

"""
    copy(A::TracedRArray{T,N})

Construct a new `TracedRArray{T,N}` from an empty tuple, `A.mlir_data` and `size(A)`.
"""
function Base.copy(A::TracedRArray{T,N}) where {T,N}
    dims = size(A)
    return TracedRArray{T,N}((), A.mlir_data, dims)
end
Expand Down Expand Up @@ -727,7 +751,7 @@ function broadcast_to_size_internal(x::TracedRArray, rsize)
)
end

function _copyto!(dest::TracedRArray, bc::Broadcasted)
function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest

Expand All @@ -736,7 +760,7 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
args = (broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

res = elem_apply(bc.f, args...)
dest.mlir_data = res.mlir_data
set_mlir_data!(dest, res.mlir_data)
return dest
end

Expand Down
37 changes: 37 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,43 @@ end
# get_view_compiled = @compile get_view(x_concrete)
end

# Out-of-place mask: allocate `similar(x)`, then fill rows 1:2 with 0 and rows 3:4
# with 1 through broadcast assignment into row slices. `x` itself is not written to.
function masking(x)
    y = similar(x)
    for (rows, value) in ((1:2, 0), (3:4, 1))
        y[rows, :] .= value
    end
    return y
end

# In-place mask: overwrite rows 1:2 of `x` with 0 and rows 3:4 with 1 through
# broadcast assignment into row slices, then return the same `x`.
function masking!(x)
    for (rows, value) in ((1:2, 0), (3:4, 1))
        x[rows, :] .= value
    end
    return x
end

@testset "setindex! with views" begin
    # Every entry is >= 2.0, so no entry is 0 or 1 before a mask is applied.
    x = rand(4, 4) .+ 2.0
    x_ra = Reactant.to_rarray(x)

    # Out-of-place: the compiled result must match the plain-Julia result.
    y = masking(x)
    y_ra = @jit(masking(x_ra))
    @test y ≈ y_ra

    # `masking` writes into `similar(x)`, so the input must be left untouched.
    x_ra_array = Array(x_ra)
    @test !(any(iszero, x_ra_array[1, :]))
    @test !(any(iszero, x_ra_array[2, :]))
    @test !(any(isone, x_ra_array[3, :]))
    @test !(any(isone, x_ra_array[4, :]))

    # In-place: every row of the input is overwritten, so the returned array equals
    # the mask pattern already held in `y`.
    y_ra = @jit(masking!(x_ra))
    @test y ≈ y_ra

    # The mutation must be visible on the original `x_ra`.
    x_ra_array = Array(x_ra)
    @test all(iszero, x_ra_array[1, :])
    @test all(iszero, x_ra_array[2, :])
    @test all(isone, x_ra_array[3, :])
    @test all(isone, x_ra_array[4, :])
end

# Wrap `x` in two nested NamedTuples without copying it.
tuple_byref(x) = (; a=(; b=x))

# Return the element-wise `abs2` of `x` together with the nested by-reference wrapper
# of `x`. The second element calls `tuple_byref`; calling `tuple_byref2` here would
# recurse without a base case.
tuple_byref2(x) = abs2.(x), tuple_byref(x)

Expand Down

1 comment on commit f34ec8f

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: f34ec8f Previous: 06f2c36 Ratio
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 7027955167 ns 5239134607 ns 1.34
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5911790203 ns 5354972275 ns 1.10
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5039003346 ns 5066098092 ns 0.99
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 7620308614 ns 7784224752 ns 0.98
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 31357074001 ns 34116607328 ns 0.92
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1572787384 ns 1591873731 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1548995246 ns 1568425232 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1555509740 ns 1561105174 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3467552492 ns 3376666620 ns 1.03
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 3419731143 ns 3453230369.5 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2180418814 ns 2206471816 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2173460673 ns 2201801664 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2158353524 ns 2205219361 ns 0.98
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3927830322 ns 4126344490 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 6253378741.5 ns 6003789905.5 ns 1.04
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1440134479 ns 1489915239 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1421557207 ns 1467369457 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1428360591 ns 1479287632 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3208488833 ns 3357859792 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1139085595.5 ns 1245410178 ns 0.91
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1712651271 ns 1761398781 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1706763760 ns 1751130703 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1704312368 ns 1750308121 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3479746822 ns 3578816935 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3154626911 ns 3100003150 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2175602283 ns 2228938613 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2162683110 ns 2219891927 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2187544254 ns 2222482559 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3928792736 ns 4053746596 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 6172504320 ns 6810510970 ns 0.91
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2997646442 ns 3083583194 ns 0.97
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 2957532725 ns 3065688335 ns 0.96
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2960803098 ns 3060109436 ns 0.97
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4863944327 ns 5026909678 ns 0.97
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 22191688400 ns 14005795901 ns 1.58
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3157390991 ns 3259385374 ns 0.97
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3212937029 ns 3276966977 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3419812778 ns 3261022001 ns 1.05
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 5010112071 ns 5268751195 ns 0.95
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 10940648930 ns 13606139679 ns 0.80
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1843742771 ns 1917023667 ns 0.96
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1827524417 ns 1900958836 ns 0.96
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1842645194 ns 1908808319 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3593953850 ns 3796722792 ns 0.95
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 3544018917 ns 3995439412 ns 0.89

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.