Skip to content

Commit

Permalink
optimizer: fully support inlining of union-split, partially constant-…
Browse files Browse the repository at this point in the history
…prop' callsite (#43347)

Makes full use of constant-propagation, by addressing this 
[TODO](https://github.com/JuliaLang/julia/blob/00734c5fd045316a00d287ca2c0ec1a2eef6e4d1/base/compiler/ssair/inlining.jl#L1212).
Here is a performance improvement from #43287:
```julia
julia> using BenchmarkTools

julia> X = rand(ComplexF32, 64, 64);

julia> dst = reinterpret(reshape, Float32, X);

julia> src = copy(dst);

julia> @btime copyto!($dst, $src);
  50.819 μs (1 allocation: 32 bytes) # v1.6.4
  41.081 μs (0 allocations: 0 bytes) # this commit
```

fixes #43287
  • Loading branch information
aviatesk committed Jan 5, 2022
1 parent 85fc5c9 commit 590a384
Showing 4 changed files with 133 additions and 115 deletions.
1 change: 1 addition & 0 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
@@ -156,6 +156,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# by constant analysis, but let's create `ConstCallInfo` if any constant propagation
# succeeded, since other consumers may be interested in this
if any_const_result && seen == napplicable
@assert napplicable == nmatches(info) == length(const_results)
info = ConstCallInfo(info, const_results)
end

190 changes: 98 additions & 92 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
@@ -675,24 +675,17 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
new_stmt = Expr(:call, argexprs[2], def, state...)
state1 = insert_node!(ir, idx, NewInstruction(new_stmt, call.rt))
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
info = call.info
handled = false
if isa(info, ConstCallInfo)
if maybe_handle_const_call!(
ir, state1.id, new_stmt, info, new_sig,
istate, false, todo)
handled = true
else
info = info.call
end
end
if !handled && (isa(info, MethodMatchInfo) || isa(info, UnionSplitInfo))
info = isa(info, MethodMatchInfo) ?
MethodMatchInfo[info] : info.matches
new_info = call.info
if isa(new_info, ConstCallInfo)
handle_const_call!(
ir, state1.id, new_stmt, new_info,
new_sig, istate, todo)
elseif isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
new_infos = isa(new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info.matches
# See if we can inline this call to `iterate`
analyze_single_call!(
ir, todo, state1.id, new_stmt,
new_sig, info, istate)
new_sig, new_infos, istate)
end
if i != length(thisarginfo.each)
valT = getfield_tfunc(call.rt, Const(1))
@@ -910,7 +903,9 @@ function iterate(split::UnionSplitSignature, state::Vector{Int}...)
return (sig, state)
end

function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Pair{Int, Any}})
function handle_single_case!(
ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case),
todo::Vector{Pair{Int, Any}}, isinvoke::Bool = false)
if isa(case, ConstantCase)
ir[SSAValue(idx)] = case.val
elseif isa(case, MethodInstance)
@@ -1086,13 +1081,13 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result):
validate_sparams(mi.sparam_vals) || return nothing
if argtypes_to_type(atypes) <: mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state))
handle_single_case!(ir, stmt, idx, item, true, todo)
handle_single_case!(ir, stmt, idx, item, todo, true)
return nothing
end
end

result = analyze_method!(match, atypes, state)
handle_single_case!(ir, stmt, idx, result, true, todo)
handle_single_case!(ir, stmt, idx, result, todo, true)
return nothing
end

@@ -1200,49 +1195,39 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta
return sig
end

# TODO inline non-`isdispatchtuple`, union-split callsites
# TODO inline non-`isdispatchtuple`, union-split callsites?
function analyze_single_call!(
ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
(; atypes, atype)::Signature, infos::Vector{MethodMatchInfo}, state::InliningState)
sig::Signature, infos::Vector{MethodMatchInfo}, state::InliningState)
(; atypes, atype) = sig
cases = InliningCase[]
local signature_union = Bottom
local only_method = nothing # keep track of whether there is one matching method
local meth
local meth::MethodLookupResult
local fully_covered = true
for i in 1:length(infos)
info = infos[i]
meth = info.results
meth = infos[i].results
if meth.ambig
# Too many applicable methods
# Or there is a (partial?) ambiguity
return
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
continue
elseif length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
else
if length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
only_method = false
end
else
only_method = false
end
else
only_method = false
end
for match in meth
spec_types = match.spec_types
signature_union = Union{signature_union, spec_types}
if !isdispatchtuple(spec_types)
fully_covered = false
continue
end
item = analyze_method!(match, atypes, state)
if item === nothing
fully_covered = false
continue
elseif _any(case->case.sig === spec_types, cases)
continue
end
push!(cases, InliningCase(spec_types, item))
signature_union = Union{signature_union, match.spec_types}
fully_covered &= handle_match!(match, atypes, state, cases)
end
end

@@ -1253,9 +1238,8 @@ function analyze_single_call!(
if length(infos) > 1
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp, only_method, true)
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
else
meth = meth::MethodLookupResult
@assert length(meth) == 1
match = meth[1]
end
@@ -1268,46 +1252,41 @@ function analyze_single_call!(
fully_covered = false
end

# If we only have one case and that case is fully covered, we may either
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if fully_covered && length(cases) == 1
handle_single_case!(ir, stmt, idx, cases[1].item, false, todo)
elseif length(cases) > 0
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
end
return nothing
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
end

# try to create `InliningCase`s using constant-prop'ed results
# currently it works only when constant-prop' succeeded for all (union-split) signatures
# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out
function maybe_handle_const_call!(
ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo, (; atypes, atype)::Signature,
state::InliningState, isinvoke::Bool, todo::Vector{Pair{Int, Any}})
cases = InliningCase[] # TODO avoid this allocation for single cases ?
# similar to `analyze_single_call!`, but with constant results
function handle_const_call!(
ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo,
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
(; atypes, atype) = sig
(; call, results) = cinfo
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
cases = InliningCase[]
local fully_covered = true
local signature_union = Bottom
for result in results
isa(result, InferenceResult) || return false
(; mi) = item = InliningTodo(result, atypes)
spec_types = mi.specTypes
signature_union = Union{signature_union, spec_types}
if !isdispatchtuple(spec_types)
fully_covered = false
continue
end
if !validate_sparams(mi.sparam_vals)
fully_covered = false
local j = 0
for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
# Too many applicable methods
# Or there is a (partial?) ambiguity
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
continue
end
state.mi_cache !== nothing && (item = resolve_todo(item, state))
if item === nothing
fully_covered = false
continue
for match in meth
j += 1
result = results[j]
if result === nothing
signature_union = Union{signature_union, match.spec_types}
fully_covered &= handle_match!(match, atypes, state, cases)
else
signature_union = Union{signature_union, result.linfo.specTypes}
fully_covered &= handle_const_result!(result, atypes, state, cases)
end
end
push!(cases, InliningCase(spec_types, item))
end

# if the signature is fully covered and there is only one applicable method,
@@ -1316,25 +1295,54 @@ function maybe_handle_const_call!(
if length(cases) == 0 && length(results) == 1
(; mi) = item = InliningTodo(results[1]::InferenceResult, atypes)
state.mi_cache !== nothing && (item = resolve_todo(item, state))
validate_sparams(mi.sparam_vals) || return true
item === nothing && return true
validate_sparams(mi.sparam_vals) || return nothing
item === nothing && return nothing
push!(cases, InliningCase(mi.specTypes, item))
fully_covered = true
end
else
fully_covered = false
end

handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
end

# Try to turn a single method `match` into an `InliningCase` appended to `cases`.
# Returns `false` when the match cannot be handled (its signature is not a dispatch
# tuple, or `analyze_method!` produced nothing), and `true` otherwise -- including
# when a case with the identical signature is already present, in which case nothing
# new is pushed. Callers fold the result into their `fully_covered` flag.
function handle_match!(
    match::MethodMatch, argtypes::Vector{Any}, state::InliningState,
    cases::Vector{InliningCase})
    spec_types = match.spec_types
    if !isdispatchtuple(spec_types)
        return false
    end
    # NOTE `analyze_method!` runs before the duplicate check on purpose:
    # a `nothing` result must be reported as "not covered" even for a seen signature
    item = analyze_method!(match, argtypes, state)
    if item === nothing
        return false
    end
    already_seen = _any(case->case.sig === spec_types, cases)
    if !already_seen
        push!(cases, InliningCase(spec_types, item))
    end
    return true
end

# Counterpart of `handle_match!` for a constant-prop'ed inference `result`:
# builds an `InliningTodo` from it and, when usable, appends the corresponding
# `InliningCase` to `cases`.
# Returns `true` only when a case was pushed; `false` when the specialized signature
# is not a dispatch tuple, the static parameters fail `validate_sparams`, or
# `resolve_todo` yields `nothing`.
function handle_const_result!(
    result::InferenceResult, argtypes::Vector{Any}, state::InliningState,
    cases::Vector{InliningCase})
    item = InliningTodo(result, argtypes)
    # grab the method instance up front, `item` may be replaced by `resolve_todo` below
    mi = item.mi
    spec_types = mi.specTypes
    if !isdispatchtuple(spec_types)
        return false
    end
    if !validate_sparams(mi.sparam_vals)
        return false
    end
    if state.mi_cache !== nothing
        item = resolve_todo(item, state)
    end
    if item === nothing
        return false
    end
    push!(cases, InliningCase(spec_types, item))
    return true
end

function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, sig::Signature,
cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}})
# If we only have one case and that case is fully covered, we may either
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if fully_covered && length(cases) == 1
handle_single_case!(ir, stmt, idx, cases[1].item, isinvoke, todo)
handle_single_case!(ir, stmt, idx, cases[1].item, todo)
elseif length(cases) > 0
isinvoke && rewrite_invoke_exprargs!(stmt)
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases))
end
return true
return nothing
end

function handle_const_opaque_closure_call!(
@@ -1346,7 +1354,7 @@ function handle_const_opaque_closure_call!(
isdispatchtuple(item.mi.specTypes) || return
validate_sparams(item.mi.sparam_vals) || return
state.mi_cache !== nothing && (item = resolve_todo(item, state))
handle_single_case!(ir, stmt, idx, item, false, todo)
handle_single_case!(ir, stmt, idx, item, todo)
return nothing
end

@@ -1371,9 +1379,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE
info = info.info
end

# Inference determined this couldn't be analyzed. Don't question it.
if info === false
# Inference determined this couldn't be analyzed. Don't question it.
continue
end

@@ -1386,16 +1393,15 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
sig, state, todo)
continue
else
maybe_handle_const_call!(
handle_const_call!(
ir, idx, stmt, info, sig,
state, sig.f === Core.invoke, todo) && continue
state, todo)
end
info = info.call # cascade to the non-constant handling
end

if isa(info, OpaqueClosureCallInfo)
item = analyze_method!(info.match, sig.atypes, state)
handle_single_case!(ir, stmt, idx, item, false, todo)
handle_single_case!(ir, stmt, idx, item, todo)
continue
end

33 changes: 21 additions & 12 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
@@ -40,6 +40,27 @@ struct UnionSplitInfo
matches::Vector{MethodMatchInfo}
end

# Number of method matches recorded in a call info object.
# For a single `MethodMatchInfo` this is the length of its lookup results.
function nmatches(info::MethodMatchInfo)
    return length(info.results)
end

# For a `UnionSplitInfo` this is the total over all of its split `MethodMatchInfo`s.
function nmatches(info::UnionSplitInfo)
    splits = info.matches
    total = 0
    for i in 1:length(splits)
        total += nmatches(splits[i])
    end
    return total
end

"""
    info::ConstCallInfo

The precision of this call was improved using constant information.
In addition to the original call information `info.call`, this info also keeps
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.

`info.results` holds one entry per method match of `info.call` (its length equals
`nmatches(info.call)`), and an entry is `nothing` when no constant-prop'ed result is
available for the corresponding match.
"""
struct ConstCallInfo
    # the original call information, computed without constant information
    call::Any
    # per-match results of constant propagation; `nothing` means "use the plain match"
    results::Vector{Union{Nothing,InferenceResult}}
end

"""
struct CallMeta
@@ -88,18 +109,6 @@ struct UnionSplitApplyCallInfo
infos::Vector{ApplyCallInfo}
end

"""
struct ConstCallInfo
Precision for this call was improved using constant information. This info
keeps a reference to the result that was used (or created for these)
constant information.
"""
struct ConstCallInfo
call::Any
results::Vector{Union{Nothing,InferenceResult}}
end

"""
struct InvokeCallInfo
24 changes: 13 additions & 11 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
@@ -438,17 +438,19 @@ end
import Base: @constprop

# test union-split callsite with successful and unsuccessful constant-prop' results
@constprop :aggressive @inline f42840(xs, a::Int) = xs[a] # should be successful, and inlined
@constprop :none @noinline f42840(xs::AbstractVector, a::Int) = xs[a] # should be unsuccessful, but still statically resolved
let src = code_typed1((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
f42840(xs, 2)
end
@test count(src.code) do @nospecialize x
iscall((src, getfield), x) # `(xs::Tuple{Int,Int,Int})[a::Const(2)]` => `getfield(xs, 2)`
end == 1
@test count(src.code) do @nospecialize x
isinvoke(:f42840, x)
end == 1
# (also for https://github.com/JuliaLang/julia/issues/43287)
@constprop :aggressive @inline f42840(cond::Bool, xs::Tuple, a::Int) = # should be successful, and inlined with constant prop' result
cond ? xs[a] : length(xs)
@constprop :none @noinline f42840(::Bool, xs::AbstractVector, a::Int) = # should be unsuccessful, but still statically resolved
xs[a]
let src = code_typed((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
f42840(true, xs, 2)
end |> only |> first
# `f42840(true, xs::Tuple{Int,Int,Int}, 2)` => `getfield(xs, 2)`
# `f42840(true, xs::Vector{Int}, 2)` => `:invoke f42840(true, xs, 2)`
@test count(x->iscall((src, getfield),x), src.code) == 1
@test count(x->isinvoke(:length, x), src.code) == 0
@test count(x->isinvoke(:f42840, x), src.code) == 1
end
# a bit weird, but should handle this kind of case as well
@constprop :aggressive @noinline g42840(xs, a::Int) = xs[a] # should be successful, but only statically resolved

0 comments on commit 590a384

Please sign in to comment.