Skip to content

Commit

Permalink
Patches for Lux integration
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 27, 2024
1 parent 292dc03 commit 1757780
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand All @@ -18,10 +19,11 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
ReactantNNlibExt = "NNlib"

[compat]
Cassette = "0.3"
ArrayInterface = "7.10"
CEnum = "0.4, 0.5"
Cassette = "0.3"
Enzyme = "0.11, 0.12"
Reactant_jll = "0.0.6"
NNlib = "0.9.17"
Preferences = "1.4"
Reactant_jll = "0.0.6"
julia = "1"

10 changes: 8 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
module Reactant

using ArrayInterface: ArrayInterface

include("mlir/MLIR.jl")
include("XLA.jl")
include("utils.jl")

abstract type RArray{ElType,Shape,N} <: AbstractArray{ElType, N} end

ArrayInterface.can_setindex(::Type{<:RArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false

@inline Base.eltype(::RArray{ElType,Shape}) where {ElType, Shape} = ElType
@inline Base.size(::RArray{ElType,Shape}) where {ElType, Shape} = Shape
@inline Base.ndims(::RArray{ElType,Shape, N}) where {ElType, Shape, N} = N
Expand Down Expand Up @@ -166,6 +171,7 @@ include("overloads.jl")
using Enzyme

@inline val_value(::Val{T}) where T = T
@inline val_value(::Type{Val{T}}) where T = T

@enum TraceMode begin
ConcreteToTraced = 1
Expand Down Expand Up @@ -599,7 +605,7 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea
end
res = Symbol("arg_$(path[2])")
for p in path[3:end]
res = :(Base.getfield($res, $p))
res = :(Base.getfield($res, $(Meta.quot(p))))
end
sym = Symbol("sbuf_$i")
sbuf = :($sym = XLA.synced_buffer($res.data))
Expand Down Expand Up @@ -725,7 +731,7 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea
return
end

error("canot copy $T")
error("cannot copy $T")
end

create_result(concrete_result, :result, ())
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function transpose_val(val)
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val, permutation=attr), 1)
end

function apply(f, args...; kwargs)
function apply(f, args...; kwargs...)
f(args...; kwargs...)
end

Expand Down

0 comments on commit 1757780

Please sign in to comment.