From ab0105528e268b59c4dda43fb1acf9762c3e4ea6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 26 May 2024 19:05:51 -0700 Subject: [PATCH] Patches for Lux integration --- Project.toml | 10 ++++++---- src/Reactant.jl | 10 ++++++++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index ba91dec15..04178de3b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy "] -version = "0.1.0" +version = "0.1.1" [deps] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Cassette = "7057c7e9-c182-5462-911a-8362d720325c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -18,10 +19,11 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" ReactantNNlibExt = "NNlib" [compat] -Cassette = "0.3" +ArrayInterface = "7.10" CEnum = "0.4, 0.5" +Cassette = "0.3" Enzyme = "0.11, 0.12" -Reactant_jll = "0.0.6" +NNlib = "0.9.17" Preferences = "1.4" +Reactant_jll = "0.0.6" julia = "1" - diff --git a/src/Reactant.jl b/src/Reactant.jl index 1dbb846cd..e91118f4f 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -1,11 +1,16 @@ module Reactant +using ArrayInterface: ArrayInterface + include("mlir/MLIR.jl") include("XLA.jl") include("utils.jl") abstract type RArray{ElType,Shape,N} <: AbstractArray{ElType, N} end +ArrayInterface.can_setindex(::Type{<:RArray}) = false +ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false + @inline Base.eltype(::RArray{ElType,Shape}) where {ElType, Shape} = ElType @inline Base.size(::RArray{ElType,Shape}) where {ElType, Shape} = Shape @inline Base.ndims(::RArray{ElType,Shape, N}) where {ElType, Shape, N} = N @@ -166,6 +171,7 @@ include("overloads.jl") using Enzyme @inline val_value(::Val{T}) where T = T +@inline val_value(::Type{Val{T}}) where T = T @enum TraceMode begin ConcreteToTraced = 1 @@ -599,7 +605,7 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea end res = Symbol("arg_$(path[2])") for p in path[3:end] - res = :(Base.getfield($res, $p)) + res = :(Base.getfield($res, $(Meta.quot(p)))) end sym = Symbol("sbuf_$i") sbuf = :($sym = XLA.synced_buffer($res.data)) @@ -725,7 +731,7 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea return end - error("canot copy $T") + error("cannot copy $T") end create_result(concrete_result, :result, ())