Skip to content

Commit

Permalink
Fix artifacts introduced when rebasing #45
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Aug 14, 2024
1 parent bd6b643 commit 5eccc66
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 35 deletions.
5 changes: 3 additions & 2 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ for (jlop, hloop) in (
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
),
(),
)
end
end

NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T} = max(x, zero(T))
NNlib.relu(x::Reactant.TracedRArray{T,0}) where {T} = max(x, zero(T))

NNlib.gelu(x::Reactant.TracedRArray{T,(),0}) where {T} = x * sigmoid(T(1.702) * x)
NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T} = x * sigmoid(T(1.702) * x)

# TODO handle non finite cases
function NNlib.softmax!(
Expand Down
5 changes: 5 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ function mlir_type(x::RArray{T,N}) where {T,N}
return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T))
end

function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N}
@assert length(shape) == N
return MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
end

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
)::RT where {copy_if_inactive,RT<:RArray}
Expand Down
109 changes: 76 additions & 33 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,21 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.convert(
rhs.mlir_data; result=mlir_type(TracedRArray{T,N})
rhs.mlir_data; result=mlir_type(TracedRArray{T,N}, size(rhs))
),
1,
),
size(rhs),
)
end
if isa(rhs, Number)
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRArray{T,N}))
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRArray{T,N}, size(rhs)))
ta = TracedRArray{T,N}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
)
return ta
end
attr = MLIR.IR.DenseElementsAttribute(mlir_type(TracedRArray{T,N}), rhs)
attr = MLIR.IR.DenseElementsAttribute(mlir_type(TracedRArray{T,N}, size(rhs)), rhs)
return TracedRArray{T,N}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
)
Expand All @@ -115,11 +116,11 @@ function promote_to(lhs::TracedRArray{T,N}, rhs) where {T,N}
return promote_to(TracedRArray{T,N}, rhs)
end

for (jlop, hloop) in (
(:(Base.min), :minimum),
(:(Base.max), :maximum),
(:(Base.:+), :add),
(:(Base.:-), :subtract),
for (jlop, hloop, RT) in (
(:(Base.min), :minimum, :T),
(:(Base.max), :maximum, :T),
(:(Base.:+), :add, :T),
(:(Base.:-), :subtract, :T),
)
@eval begin
function $jlop(lhs::TracedRArray{T,N}, rhs::TracedRArray{T2,N}) where {T,T2,N}
Expand All @@ -136,40 +137,45 @@ for (jlop, hloop) in (
end

function $jlop(lhs::TracedRArray{T,N}, rhs::TracedRArray{T,N}) where {T,N}
return TracedRArray{T,N}(
return TracedRArray{$RT,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
size(lhs),
)
end
end

function $jlop(lhs::TracedRArray{T,N}, rhs) where {T,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{T,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
size(lhs),
)
end
for otherType in (Number, Any) #=TracedRArray{S,0} where {S}=#
@eval begin
function $jlop(lhs::TracedRArray{T,N}, rhs::$otherType) where {T,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
size(lhs),
)
end

function $jlop(lhs, rhs::TracedRArray{T,N}) where {T,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{T,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
size(lhs),
)
function $jlop(lhs::$otherType, rhs::TracedRArray{T,N}) where {T,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
size(lhs),
)
end
end
end
end

for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^), :power))
for (jlop, hloop, RT) in
((:(Base.:*), :multiply, :T), (:(Base.:/), :divide, :T), (:(Base.:^), :power, :T))
@eval begin
function $jlop(lhs::TracedRArray{T,0}, rhs::TracedRArray{T2,0}) where {T,T2}
commonTy = TracedRArray{Base.promote_type(T, T2),0}
Expand All @@ -185,7 +191,7 @@ for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^
end

function $jlop(lhs::TracedRArray{T,0}, rhs::TracedRArray{T,0}) where {T}
return TracedRArray{T,0}(
return TracedRArray{$RT,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
Expand All @@ -196,7 +202,7 @@ for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^

function $jlop(lhs::TracedRArray{T,0}, rhs) where {T}
rhs = promote_to(lhs, rhs)
return TracedRArray{T,0}(
return TracedRArray{$RT,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
Expand All @@ -207,7 +213,30 @@ for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^

function $jlop(lhs, rhs::TracedRArray{T,0}) where {T}
lhs = promote_to(rhs, lhs)
return TracedRArray{T,0}(
return TracedRArray{$RT,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
(),
)
end

# Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
function $jlop(lhs::TracedRArray{T,0}, rhs::Number) where {T}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
(),
)
end

function $jlop(lhs::Number, rhs::TracedRArray{T,0}) where {T}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
Expand All @@ -218,6 +247,20 @@ for (jlop, hloop) in ((:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^
end
end

function Base.ifelse(
pred::TracedRArray{Bool,0}, x::TracedRArray{T1,0}, y::TracedRArray{T2,0}
) where {T1,T2}
return TracedRArray{promote_type(T1, T2),0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
),
size(pred),
)
end

Base.abs2(x::Reactant.TracedRArray{T,0}) where {T} = x * conj(x)

function Base.literal_pow(
::Base.RefValue{typeof(^)}, x::TracedRArray{T,0}, ::Base.RefValue{Val{P}}
) where {T,P}
Expand Down

0 comments on commit 5eccc66

Please sign in to comment.