Skip to content

Commit

Permalink
Merge pull request #17 from YingboMa/metalower
Browse files Browse the repository at this point in the history
Use Meta.lower to handle cases the same way base broadcasting does.
  • Loading branch information
chriselrod authored May 21, 2021
2 parents e5be8ec + 204ced9 commit ff914db
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 52 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FastBroadcast"
uuid = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
authors = ["Yingbo Ma <[email protected]> and Chris Elrod <[email protected]>"]
version = "0.1.7"
version = "0.1.8"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
113 changes: 62 additions & 51 deletions src/FastBroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,68 +205,79 @@ function walk_bc!(
return nothing
end

# From Julia Base
dottable(x) = x !== Base.maybeview
# don't add dots to dot operators
dottable(x::Symbol) = x !== :(:)

function undotop(ex)
Meta.isexpr(ex, :call) || return ex
str = string(ex.args[1])
if first(str) == '.' && (op = Symbol(str[2:end]); Base.isoperator(op))
return Expr(:call, op, ex.args[2:end]...)
function pushsymname!(ex::Expr, base::Symbol, @nospecialize(arg))
if arg isa Core.SSAValue
push!(ex.args, Symbol(base, '_', arg.id))
elseif arg isa Core.SlotNumber
push!(ex.args, Symbol(base, 's', arg.id))
else
return ex
push!(ex.args, arg)
end
end

function todotview(ex)
Meta.isexpr(ex, :ref) || return ex
q = Expr(:call, Base.Broadcast.dotview)
for a ex.args
push!(q.args, a)
end
q
function _goto(base::Symbol, i::Int, sym::Symbol)
Expr(:macrocall, sym, LineNumberNode(@__LINE__,Symbol(@__FILE__)), Symbol(base, "#label#", i))
end
goto(base::Symbol, i::Int) = _goto(base, i, Symbol("@goto"))
label(base::Symbol, i::Int) = _goto(base, i, Symbol("@label"))

function broadcasted_expr!(_ex)
Meta.isexpr(_ex, :$) && return _ex.args[1]
if Meta.isexpr(_ex, :.) && Meta.isexpr(_ex.args[2], :tuple)
_ex = Expr(:call, _ex.args[1], _ex.args[2].args...)
elseif Meta.isexpr(_ex, :ref)
return todotview(_ex)
end
_ex = undotop(_ex)
Meta.isexpr(_ex, :call) || return _ex
Meta.isexpr(_ex.args[1], :$) && return Expr(:call, _ex.args[1].args[1], _ex.args[2:end]...)
dottable(_ex.args[1]) || return _ex
ex::Expr = _ex
call = Expr(:call, Base.Broadcast.broadcasted, ex.args[1])
for n 2:length(ex.args)
push!(call.args, broadcasted_expr!(ex.args[n]))
end
call
function add_gotoifnot!(q::Expr, gotos::Vector{Int}, base::Symbol, cond, dest::Int)
ex = Expr(:||)
pushsymname!(ex, base, cond)
push!(ex.args, goto(base, dest))
push!(q.args, ex)
push!(gotos, dest)
nothing
end

function broadcast_expr!(ex)
ex isa Expr || return ex
if ex.head === :let
return Expr(:let, ex.args[1], broadcast_expr!(ex.args[2]))
end
update = findfirst(Base.Fix1(Base.sym_in, ex.head), ((:(+=),:(.+=)), (:(-=),:(.-=)), (:(*=),:(.*=)), (:(/=),:(./=)), (:(\=),:(.\=)), (:(^=),:(.^=)), (:(&=),:(.&=)), (:(|=),:(.|=)), (:(⊻=),:(.⊻=)), (:(÷=),:(.÷=))))
if update nothing
lhs = Expr(:call, (:(+), :(-), :(*), :(/), :(\), :(^), :(&), :(|), :(), :(÷))[update], ex.args[1], ex.args[2])
ex = Expr(:(=), ex.args[1], lhs)
end
if Meta.isexpr(ex, :(=), 2) || Meta.isexpr(ex, :(.=), 2)
return Expr(:call, fast_materialize!, todotview(ex.args[1]), broadcasted_expr!(ex.args[2]))
else
return Expr(:call, fast_materialize, broadcasted_expr!(ex))
function broadcast_codeinfo(ci)
q = Expr(:block)
base = gensym(:fastbroadcast)
gotos = Int[]
for (i, code) enumerate(ci.code)
k = findfirst(==(i), gotos)
if k nothing
push!(q.args, label(base, i))
end
if Meta.isexpr(code, :call)
ex = Expr(:call)
f = code.args[1]
if f === GlobalRef(Base, :materialize)
push!(ex.args, fast_materialize)
elseif f === GlobalRef(Base, :materialize!)
push!(ex.args, fast_materialize!)
elseif f === GlobalRef(Base, :getindex)
push!(ex.args, Base.Broadcast.dotview)
else
pushsymname!(ex, base, f)
end
for arg @view(code.args[2:end])
pushsymname!(ex, base, arg)
end
push!(q.args, Expr(:(=), Symbol(base, '_', i), ex))
elseif Meta.isexpr(code, :(=))
ex = Expr(:(=), Symbol(base, 's', code.args[1].id))
pushsymname!(ex, base, code.args[2])
push!(q.args, ex)
elseif VERSION v"1.6" && code isa Core.GotoIfNot
add_gotoifnot!(q, gotos, base, code.cond, code.dest)
elseif VERSION < v"1.6" && Meta.isexpr(code, :gotoifnot)
add_gotoifnot!(q, gotos, base, code.args[1], code.args[2])
elseif code isa Core.GotoNode
push!(q.args, goto(base, code.label))
push!(gotos, code.label)
elseif !(VERSION v"1.6" ? isa(code, Core.ReturnNode) : Meta.isexpr(code, :return))
ex = Expr(:(=), Symbol(base, '_', i))
pushsymname!(ex, base, code)
push!(q.args, ex)
end
end
q
end

macro (..)(ex)
esc(broadcast_expr!(macroexpand(__module__, ex)))
lowered = Meta.lower(__module__, Base.Broadcast.__dot__(ex))
lowered isa Expr || return esc(lowered)
esc(broadcast_codeinfo(lowered.args[1]))
end

end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ if GROUP == "All" || GROUP == "Core"
@test (@views @.. r[3:4] += x[2:3,1] + a) == [19,26]
@test (@.. @views r[3:4] += x[2:3,1] + a) == [25,35]
@test r == [5,6,25,35]
@test (@.. r + r[end]) == [40,41,60,70]

Q = rand(5,2); k = 1; v = 100 .* rand(5); d = 5;
@.. @view(Q[:, k]) = v / d
@test Q hcat(v ./ d, @view(Q[:,2]))
@testset "Sparse" begin
x = sparse([1, 2, 0, 4])
y = sparse([1, 0, 0, 4])
Expand Down

2 comments on commit ff914db

@chriselrod
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37189

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.8 -m "<description of version>" ff914dbb351dda2e262e72b991ee4a1023667c83
git push origin v0.1.8

Please sign in to comment.