From a7e53d142f339c2edd817111246d80ddec2d3492 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 6 Dec 2021 22:04:26 +0900 Subject: [PATCH] optimizer: fully support inlining of union-split, partially constant-prop' callsite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Makes full use of constant-propagation, by addressing this [TODO](https://github.com/JuliaLang/julia/blob/00734c5fd045316a00d287ca2c0ec1a2eef6e4d1/base/compiler/ssair/inlining.jl#L1212). Here is a performance improvement from #43287: ```julia ulia> using BenchmarkTools julia> X = rand(ComplexF32, 64, 64); julia> dst = reinterpret(reshape, Float32, X); julia> src = copy(dst); julia> @btime copyto!($dst, $src); 50.819 μs (1 allocation: 32 bytes) # v1.6.4 41.081 μs (0 allocations: 0 bytes) # this commit ``` fixes #43287 --- base/compiler/abstractinterpretation.jl | 1 + base/compiler/ssair/inlining.jl | 156 +++++++++++++----------- base/compiler/stmtinfo.jl | 33 +++-- test/compiler/inline.jl | 13 +- 4 files changed, 113 insertions(+), 90 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index b1d4eceefc2cf..d5eda991c641f 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -139,6 +139,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # by constant analysis, but let's create `ConstCallInfo` if there has been any successful # constant propagation happened since other consumers may be interested in this if any_const_result && seen == napplicable + @assert napplicable == nmatches(info) == length(const_results) info = ConstCallInfo(info, const_results) end diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 55445f5c8032b..fceffa39de578 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -672,19 +672,16 @@ function rewrite_apply_exprargs!( new_sig = with_atype(call_sig(ir, new_stmt)::Signature) new_info = call.info if isa(new_info, ConstCallInfo) - maybe_handle_const_call!( + handle_const_call!( ir, state1.id, new_stmt, new_info, flag, - new_sig, istate, todo) && @goto analyzed - new_info = new_info.call # cascade to the non-constant handling - end - if isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo) + new_sig, istate, todo) + elseif isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo) new_infos = isa(new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info.matches # See if we can inline this call to `iterate` analyze_single_call!( ir, state1.id, new_stmt, new_infos, flag, new_sig, istate, todo) end - @label analyzed if i != length(thisarginfo.each) valT = getfield_tfunc(call.rt, Const(1)) val_extracted = insert_node!(ir, idx, NewInstruction( @@ -1126,7 +1123,7 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto return stmt, sig end -# TODO inline non-`isdispatchtuple`, union-split callsites +# TODO inline non-`isdispatchtuple`, union-split callsites? function analyze_single_call!( ir::IRCode, idx::Int, stmt::Expr, infos::Vector{MethodMatchInfo}, flag::UInt8, sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}}) @@ -1134,46 +1131,35 @@ function analyze_single_call!( cases = InliningCase[] local signature_union = Bottom local only_method = nothing # keep track of whether there is one matching method - local meth + local meth::MethodLookupResult local fully_covered = true for i in 1:length(infos) - info = infos[i] - meth = info.results + meth = infos[i].results if meth.ambig # Too many applicable methods # Or there is a (partial?) ambiguity - return + return nothing elseif length(meth) == 0 # No applicable methods; try next union split continue - elseif length(meth) == 1 && only_method !== false - if only_method === nothing - only_method = meth[1].method - elseif only_method !== meth[1].method + else + if length(meth) == 1 && only_method !== false + if only_method === nothing + only_method = meth[1].method + elseif only_method !== meth[1].method + only_method = false + end + else only_method = false end - else - only_method = false end for match in meth - spec_types = match.spec_types - signature_union = Union{signature_union, spec_types} - if !isdispatchtuple(spec_types) - fully_covered = false - continue - end - item = analyze_method!(match, argtypes, flag, state) - if item === nothing - fully_covered = false - continue - elseif _any(case->case.sig === spec_types, cases) - continue - end - push!(cases, InliningCase(spec_types, item)) + signature_union = Union{signature_union, match.spec_types} + fully_covered &= handle_match!(match, argtypes, flag, state, cases) end end - # if the signature is fully or mostly covered and there is only one applicable method, + # if the signature is fully covered and there is only one applicable method, # we can try to inline it even if the signature is not a dispatch tuple if length(cases) == 0 && only_method isa Method if length(infos) > 1 @@ -1181,60 +1167,52 @@ function analyze_single_call!( atype, only_method.sig)::SimpleVector match = MethodMatch(metharg, methsp::SimpleVector, only_method, true) else - meth = meth::MethodLookupResult @assert length(meth) == 1 match = meth[1] end item = analyze_method!(match, argtypes, flag, state) - item === nothing && return + item === nothing && return nothing push!(cases, InliningCase(match.spec_types, item)) fully_covered = match.fully_covers else fully_covered &= atype <: signature_union end - # If we only have one case and that case is fully covered, we may either - # be able to do the inlining now (for constant cases), or push it directly - # onto the todo list - if fully_covered && length(cases) == 1 - handle_single_case!(ir, idx, stmt, cases[1].item, todo) - elseif length(cases) > 0 - push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) - end - return nothing + handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo) end -# try to create `InliningCase`s using constant-prop'ed results -# currently it works only when constant-prop' succeeded for all (union-split) signatures -# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later -# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out -function maybe_handle_const_call!( - ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, flag::UInt8, +# similar to `analyze_single_call!`, but with constant results +function handle_const_call!( + ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo, flag::UInt8, sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}}) (; argtypes, atype) = sig - results = info.results - cases = InliningCase[] # TODO avoid this allocation for single cases ? + (; call, results) = cinfo + infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches + cases = InliningCase[] local fully_covered = true local signature_union = Bottom - for result in results - isa(result, InferenceResult) || return false - (; mi) = item = InliningTodo(result, argtypes) - spec_types = mi.specTypes - signature_union = Union{signature_union, spec_types} - if !isdispatchtuple(spec_types) - fully_covered = false - continue - end - if !validate_sparams(mi.sparam_vals) - fully_covered = false + local j = 0 + for i in 1:length(infos) + meth = infos[i].results + if meth.ambig + # Too many applicable methods + # Or there is a (partial?) ambiguity + return nothing + elseif length(meth) == 0 + # No applicable methods; try next union split continue end - state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) - if item === nothing - fully_covered = false - continue + for match in meth + j += 1 + result = results[j] + if result === nothing + signature_union = Union{signature_union, match.spec_types} + fully_covered &= handle_match!(match, argtypes, flag, state, cases) + else + signature_union = Union{signature_union, result.linfo.specTypes} + fully_covered &= handle_const_result!(result, argtypes, flag, state, cases) + end end - push!(cases, InliningCase(spec_types, item)) end # if the signature is fully covered and there is only one applicable method, @@ -1242,23 +1220,53 @@ function maybe_handle_const_call!( if length(cases) == 0 && length(results) == 1 (; mi) = item = InliningTodo(results[1]::InferenceResult, argtypes) state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) - validate_sparams(mi.sparam_vals) || return true - item === nothing && return true + validate_sparams(mi.sparam_vals) || return nothing + item === nothing && return nothing push!(cases, InliningCase(mi.specTypes, item)) fully_covered = atype <: mi.specTypes else fully_covered &= atype <: signature_union end + handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo) +end + +function handle_match!( + match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState, + cases::Vector{InliningCase}) + spec_types = match.spec_types + isdispatchtuple(spec_types) || return false + item = analyze_method!(match, argtypes, flag, state) + item === nothing && return false + _any(case->case.sig === spec_types, cases) && return true + push!(cases, InliningCase(spec_types, item)) + return true +end + +function handle_const_result!( + result::InferenceResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState, + cases::Vector{InliningCase}) + (; mi) = item = InliningTodo(result, argtypes) + spec_types = mi.specTypes + isdispatchtuple(spec_types) || return false + validate_sparams(mi.sparam_vals) || return false + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) + item === nothing && return false + push!(cases, InliningCase(spec_types, item)) + return true +end + +function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, sig::Signature, + cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}}) # If we only have one case and that case is fully covered, we may either # be able to do the inlining now (for constant cases), or push it directly # onto the todo list if fully_covered && length(cases) == 1 handle_single_case!(ir, idx, stmt, cases[1].item, todo) elseif length(cases) > 0 - push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) + push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases)) end - return true + return nothing end function handle_const_opaque_closure_call!( @@ -1324,10 +1332,10 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) # if inference arrived here with constant-prop'ed result(s), # we can perform a specialized analysis for just this case if isa(info, ConstCallInfo) - maybe_handle_const_call!( + handle_const_call!( ir, idx, stmt, info, flag, - sig, state, todo) && continue - info = info.call # cascade to the non-constant handling + sig, state, todo) + continue end # Ok, now figure out what method to call diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index 6360f1697d417..ca8c7d0d27d56 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -38,6 +38,27 @@ struct UnionSplitInfo matches::Vector{MethodMatchInfo} end +nmatches(info::MethodMatchInfo) = length(info.results) +function nmatches(info::UnionSplitInfo) + n = 0 + for mminfo in info.matches + n += nmatches(mminfo) + end + return n +end + +""" + info::ConstCallInfo + +The precision of this call was improved using constant information. +In addition to the original call information `info.call`, this info also keeps +the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`. +""" +struct ConstCallInfo + call::Union{MethodMatchInfo,UnionSplitInfo} + results::Vector{Union{Nothing,InferenceResult}} +end + """ info::MethodResultPure @@ -92,18 +113,6 @@ struct UnionSplitApplyCallInfo infos::Vector{ApplyCallInfo} end -""" - info::ConstCallInfo - -The precision of this call was improved using constant information. -In addition to the original call information `info.call`, this info also keeps -the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`. -""" -struct ConstCallInfo - call::Union{MethodMatchInfo,UnionSplitInfo} - results::Vector{Union{Nothing,InferenceResult}} -end - """ info::InvokeCallInfo diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 83780ca8b1ac5..8fd8e4236d988 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -758,13 +758,18 @@ end import Base: @constprop # test union-split callsite with successful and unsuccessful constant-prop' results -@constprop :aggressive @inline f42840(xs, a::Int) = xs[a] # should be successful, and inlined -@constprop :none @noinline f42840(xs::AbstractVector, a::Int) = xs[a] # should be unsuccessful, but still statically resolved +# (also for https://github.com/JuliaLang/julia/issues/43287) +@constprop :aggressive @inline f42840(cond::Bool, xs::Tuple, a::Int) = # should be successful, and inlined with constant prop' result + cond ? xs[a] : @noinline(length(xs)) +@constprop :none @noinline f42840(::Bool, xs::AbstractVector, a::Int) = # should be unsuccessful, but still statically resolved + xs[a] let src = code_typed((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs - f42840(xs, 2) + f42840(true, xs, 2) end |> only |> first - # `(xs::Tuple{Int,Int,Int})[a::Const(2)]` => `getfield(xs, 2)` + # `f43287(true, xs::Tuple{Int,Int,Int}, 2)` => `getfield(xs, 2)` + # `f43287(true, xs::Vector{Int}, 2)` => `:invoke f43287(true, xs, 2)` @test count(iscall((src, getfield)), src.code) == 1 + @test count(isinvoke(:length), src.code) == 0 @test count(isinvoke(:f42840), src.code) == 1 end # a bit weird, but should handle this kind of case as well