diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index d11bb43c03ee0..ba3edcdf2d73b 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1342,6 +1342,14 @@ function ssa_def_slot(@nospecialize(arg), sv::InferenceState) return arg end +struct AbstractIterationResult + cti::Vector{Any} + info::MaybeAbstractIterationInfo + ai_effects::Effects +end +AbstractIterationResult(cti::Vector{Any}, info::MaybeAbstractIterationInfo) = + AbstractIterationResult(cti, info, EFFECTS_TOTAL) + # `typ` is the inferred type for expression `arg`. # if the expression constructs a container (e.g. `svec(x,y,z)`), # refine its type to an array of element types. @@ -1352,14 +1360,14 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft) if isa(typ, PartialStruct) widet = typ.typ if isa(widet, DataType) && widet.name === Tuple.name - return typ.fields, nothing + return AbstractIterationResult(typ.fields, nothing) end end if isa(typ, Const) val = typ.val if isa(val, SimpleVector) || isa(val, Tuple) - return Any[ Const(val[i]) for i in 1:length(val) ], nothing # avoid making a tuple Generator here! + return AbstractIterationResult(Any[ Const(val[i]) for i in 1:length(val) ], nothing) # avoid making a tuple Generator here! end end @@ -1374,12 +1382,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft) if isa(tti, Union) utis = uniontypes(tti) if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis) - return Any[Vararg{Any}], nothing + return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′) end ltp = length((utis[1]::DataType).parameters) for t in utis if length((t::DataType).parameters) != ltp - return Any[Vararg{Any}], nothing + return AbstractIterationResult(Any[Vararg{Any}], nothing) end end result = Any[ Union{} for _ in 1:ltp ] @@ -1390,12 +1398,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft) result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0)) end end - return result, nothing + return AbstractIterationResult(result, nothing) elseif tti0 <: Tuple if isa(tti0, DataType) - return Any[ p for p in tti0.parameters ], nothing + return AbstractIterationResult(Any[ p for p in tti0.parameters ], nothing) elseif !isa(tti, DataType) - return Any[Vararg{Any}], nothing + return AbstractIterationResult(Any[Vararg{Any}], nothing) else len = length(tti.parameters) last = tti.parameters[len] @@ -1404,12 +1412,14 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft) if va elts[len] = Vararg{elts[len]} end - return elts, nothing + return AbstractIterationResult(elts, nothing) end - elseif tti0 === SimpleVector || tti0 === Any - return Any[Vararg{Any}], nothing + elseif tti0 === SimpleVector + return AbstractIterationResult(Any[Vararg{Any}], nothing) + elseif tti0 === Any + return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′) elseif tti0 <: Array - return Any[Vararg{eltype(tti0)}], nothing + return AbstractIterationResult(Any[Vararg{eltype(tti0)}], nothing) else return abstract_iteration(interp, itft, typ, sv) end @@ -1420,7 +1430,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n if isa(itft, Const) iteratef = itft.val else - return Any[Vararg{Any}], nothing + return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′) end @assert !isvarargtype(itertype) call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), StmtInfo(true), sv) @@ -1430,7 +1440,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n # WARNING: Changes to the iteration protocol must be reflected here, # this is not just an optimization. # TODO: this doesn't realize that Array, SimpleVector, Tuple, and NamedTuple do not use the iterate protocol - stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, call.effects, info)]) + stateordonet === Bottom && return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, call.effects, info)], true)) valtype = statetype = Bottom ret = Any[] calls = CallMeta[call] @@ -1440,7 +1450,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n # length iterators, or interesting prefix while true if stateordonet_widened === Nothing - return ret, AbstractIterationInfo(calls) + return AbstractIterationResult(ret, AbstractIterationInfo(calls, true)) end if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).max_tuple_splat break @@ -1452,7 +1462,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n # If there's no new information in this statetype, don't bother continuing, # the iterator won't be finite. if ⊑(typeinf_lattice(interp), nstatetype, statetype) - return Any[Bottom], nothing + return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_THROWS) end valtype = getfield_tfunc(typeinf_lattice(interp), stateordonet, Const(1)) push!(ret, valtype) @@ -1482,7 +1492,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n # ... but cannot terminate if !may_have_terminated # ... and cannot have terminated prior to this loop - return Any[Bottom], nothing + return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_UNKNOWN′) else # iterator may have terminated prior to this loop, but not during it valtype = Bottom @@ -1492,13 +1502,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n end valtype = tmerge(valtype, nounion.parameters[1]) statetype = tmerge(statetype, nounion.parameters[2]) - stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv).rt + call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv) + push!(calls, call) + stateordonet = call.rt stateordonet_widened = widenconst(stateordonet) end if valtype !== Union{} push!(ret, Vararg{valtype}) end - return ret, nothing + return AbstractIterationResult(ret, AbstractIterationInfo(calls, false)) end # do apply(af, fargs...), where af is a function value @@ -1529,13 +1541,9 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si:: infos′ = Vector{MaybeAbstractIterationInfo}[] for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]]) if !isvarargtype(ti) - cti_info = precise_container_type(interp, itft, ti, sv) - cti = cti_info[1]::Vector{Any} - info = cti_info[2]::MaybeAbstractIterationInfo + (;cti, info, ai_effects) = precise_container_type(interp, itft, ti, sv) else - cti_info = precise_container_type(interp, itft, unwrapva(ti), sv) - cti = cti_info[1]::Vector{Any} - info = cti_info[2]::MaybeAbstractIterationInfo + (;cti, info, ai_effects) = precise_container_type(interp, itft, unwrapva(ti), sv) # We can't represent a repeating sequence of the same types, # so tmerge everything together to get one type that represents # everything. @@ -1548,6 +1556,12 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si:: end cti = Any[Vararg{argt}] end + effects = merge_effects(effects, ai_effects) + if info !== nothing + for call in info.each + effects = merge_effects(effects, call.effects) + end + end if any(@nospecialize(t) -> t === Bottom, cti) continue end diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 43b9caa1b3154..63319509b672b 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -729,7 +729,7 @@ function rewrite_apply_exprargs!(todo::Vector{Pair{Int,Any}}, def = argexprs[i] def_type = argtypes[i] thisarginfo = arginfos[i-arg_start] - if thisarginfo === nothing + if thisarginfo === nothing || !thisarginfo.complete if def_type isa PartialStruct # def_type.typ <: Tuple is assumed def_argtypes = def_type.fields @@ -1134,9 +1134,9 @@ function inline_apply!(todo::Vector{Pair{Int,Any}}, for i = (arg_start + 1):length(argtypes) thisarginfo = nothing if !is_valid_type_for_apply_rewrite(argtypes[i], state.params) - if isa(info, ApplyCallInfo) && info.arginfo[i-arg_start] !== nothing - thisarginfo = info.arginfo[i-arg_start] - else + isa(info, ApplyCallInfo) || return nothing + thisarginfo = info.arginfo[i-arg_start] + if thisarginfo === nothing || !thisarginfo.complete return nothing end end diff --git a/base/compiler/ssair/irinterp.jl b/base/compiler/ssair/irinterp.jl index ad10064b2d74a..89e0851e84a60 100644 --- a/base/compiler/ssair/irinterp.jl +++ b/base/compiler/ssair/irinterp.jl @@ -66,7 +66,17 @@ function kill_def_use!(tpdum::TwoPhaseDefUseMap, def::Int, use::Int) if !tpdum.complete tpdum.ssa_uses[def] -= 1 else - @assert false && "TODO" + range = tpdum.ssa_uses[def]:(def == length(tpdum.ssa_uses) ? length(tpdum.data) : (tpdum.ssa_uses[def + 1] - 1)) + # TODO: Sorted + useidx = findfirst(idx->tpdum.data[idx] == use, range) + @assert useidx !== nothing + idx = range[useidx] + while idx < lastindex(range) + ndata = tpdum.data[idx+1] + ndata == 0 && break + tpdum.data[idx] = ndata + end + tpdum.data[idx + 1] = 0 end end kill_def_use!(tpdum::TwoPhaseDefUseMap, def::SSAValue, use::Int) = @@ -262,11 +272,11 @@ function process_terminator!(ir::IRCode, idx::Int, bb::Int, end return false elseif isa(inst, GotoNode) - backedge = inst.label < bb + backedge = inst.label <= bb !backedge && push!(ip, inst.label) return backedge elseif isa(inst, GotoIfNot) - backedge = inst.dest < bb + backedge = inst.dest <= bb !backedge && push!(ip, inst.dest) push!(ip, bb + 1) return backedge diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index 556c0082e4532..23f8c3aba908e 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -114,6 +114,7 @@ Each (abstract) call to `iterate`, corresponds to one entry in `ainfo.each::Vect """ struct AbstractIterationInfo each::Vector{CallMeta} + complete::Bool end const MaybeAbstractIterationInfo = Union{Nothing, AbstractIterationInfo} diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index d440b0097ac53..a6a8684f67595 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -4633,3 +4633,19 @@ end |> only === Type{Float64} # Issue #46839: `abstract_invoke` should handle incorrect call type @test only(Base.return_types(()->invoke(BitSet, Any, x), ())) === Union{} @test only(Base.return_types(()->invoke(BitSet, Union{Tuple{Int32},Tuple{Int64}}, 1), ())) === Union{} + +# Issue #47688: Abstract iteration should take into account `iterate` effects +global it_count47688 = 0 +struct CountsIterate47688{N}; end +function Base.iterate(::CountsIterate47688{N}, n=0) where N + global it_count47688 += 1 + n <= N ? (n, n+1) : nothing +end +foo47688() = tuple(CountsIterate47688{5}()...) +bar47688() = foo47688() +@test only(Base.return_types(bar47688)) == NTuple{6, Int} +@test it_count47688 == 0 +@test isa(bar47688(), NTuple{6, Int}) +@test it_count47688 == 7 +@test isa(foo47688(), NTuple{6, Int}) +@test it_count47688 == 14