Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patches for Lux integration #7

Merged
merged 6 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand All @@ -12,16 +12,19 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"

[weakdeps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[extensions]
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantNNlibExt = "NNlib"

[compat]
Cassette = "0.3"
ArrayInterface = "7.10"
CEnum = "0.4, 0.5"
Cassette = "0.3"
Enzyme = "0.11, 0.12"
Reactant_jll = "0.0.6"
NNlib = "0.9"
Preferences = "1.4"
Reactant_jll = "0.0.6"
julia = "1"

9 changes: 9 additions & 0 deletions ext/ReactantArrayInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module ReactantArrayInterfaceExt

using ArrayInterface: ArrayInterface
using Reactant: RArray

ArrayInterface.can_setindex(::Type{<:RArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false

end
27 changes: 13 additions & 14 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@ module ReactantNNlibExt
using NNlib
using Reactant

function __init__()
for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh))
@eval begin
if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
function Reactant.elem_apply(
::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
return Reactant.TracedRArray{ElType,Shape,N}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1
),
)
end
for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh))
@eval begin
if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
function Reactant.elem_apply(
::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
return Reactant.TracedRArray{ElType,Shape,N}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1
),
)
end
end
end
Expand All @@ -36,4 +34,5 @@ function NNlib.softmax!(
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
return out ./= tmp
end

end
12 changes: 10 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ abstract type RArray{ElType,Shape,N} <: AbstractArray{ElType,N} end

@inline Base.eltype(::RArray{ElType,Shape}) where {ElType,Shape} = ElType
@inline Base.size(::RArray{ElType,Shape}) where {ElType,Shape} = Shape
@inline Base.size(::Type{<:RArray{ElType,Shape}}) where {ElType,Shape} = Shape
@inline Base.ndims(::RArray{ElType,Shape,N}) where {ElType,Shape,N} = N
@inline Base.ndims(::Type{<:RArray{ElType,Shape,N}}) where {ElType,Shape,N} = N

Expand Down Expand Up @@ -162,6 +163,12 @@ end
return res
end

if RT <: TracedRArray
res = broadcast_to_size(eltype(RT)(0), size(prev))
seen[prev] = res
return res
end

attr = fill(MLIR.IR.Attribute(eltype(RT)(0)), mlir_type(prev))
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
res = RT((), cst)
Expand Down Expand Up @@ -189,6 +196,7 @@ include("overloads.jl")
using Enzyme

@inline val_value(::Val{T}) where {T} = T
@inline val_value(::Type{Val{T}}) where {T} = T

@enum TraceMode begin
ConcreteToTraced = 1
Expand Down Expand Up @@ -649,7 +657,7 @@ function generate_jlfunc(
end
res = Symbol("arg_$(path[2])")
for p in path[3:end]
res = :(Base.getfield($res, $p))
res = :(Base.getfield($res, $(Meta.quot(p))))
end
sym = Symbol("sbuf_$i")
sbuf = :($sym = XLA.synced_buffer($res.data))
Expand Down Expand Up @@ -819,7 +827,7 @@ function generate_jlfunc(
return nothing
end

return error("canot copy $T")
return error("cannot copy $T")
end

create_result(concrete_result, :result, ())
Expand Down
7 changes: 7 additions & 0 deletions src/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,13 @@ Cassette.overdub(context::TraceCtx, f::typeof(Base.materialize!), args...) = f(a

@inline Base.copyto!(dest::TracedRArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict

@inline function Base.copyto!(
dest::TracedRArray{ElType,Shape,N}, src::TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
dest.mlir_data = src.mlir_data
return dest
end

@inline function broadcast_to_size(arg::AbstractArray, rsize)
attr = MLIR.IR.DenseElementsAttribute(arg)
len = ndims(arg)
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function transpose_val(val)
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
end

function apply(f, args...; kwargs)
function apply(f, args...; kwargs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for what needed this. I somehow recall the other way being intentional but maybe not

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This came up in computing the gradient of the model with Enzyme. But that code is not completely functional yet, see LuxDL/Lux.jl#665 (comment)

return f(args...; kwargs...)
end

Expand Down
Loading