Skip to content

Commit bfdae14

Browse files
committed
linearize
1 parent dc71ab2 commit bfdae14

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

base/inference.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3391,6 +3391,7 @@ function optimize(me::InferenceState)
33913391
alloc_elim_pass!(me)
33923392
getfield_elim_pass!(me)
33933393
copy_duplicated_expr_pass!(me)
3394+
linearize_pass!(me)
33943395
# Clean up for `alloc_elim_pass!` and `getfield_elim_pass!`
33953396
void_use_elim_pass!(me)
33963397
filter!(x -> x !== nothing, code)
@@ -5946,6 +5947,114 @@ function alloc_elim_pass!(sv::InferenceState)
59465947
end
59475948
end
59485949

5950+
function is_ccall_static(e::Expr, sv::InferenceState)
5951+
if e.head === :call
5952+
is_known_call(e, tuple, sv.src, sv.mod) || return false
5953+
length(e.args) == 3 || return false
5954+
for i in 2:3
5955+
a = e.args[i]
5956+
(isa(a, Expr) || isa(a, Slot) || isa(a, SSAValue)) && return false
5957+
end
5958+
return true
5959+
elseif e.head === :static_parameter
5960+
return true
5961+
end
5962+
return false
5963+
end
5964+
5965+
function linearize_arg!(args, i, stmts, sv::InferenceState)
5966+
a = args[i]
5967+
if isa(a, Symbol)
5968+
a = a::Symbol
5969+
isdefined(sv.mod, a) && isconst(sv.mod, a) && return
5970+
typ = Any
5971+
elseif isa(a, GlobalRef)
5972+
a = a::GlobalRef
5973+
isdefined(a.mod, a.name) && isconst(a.mod, a.name) && return
5974+
typ = Any
5975+
elseif isa(a, Expr)
5976+
typ = (a::Expr).typ
5977+
else
5978+
return
5979+
end
5980+
ssa = newvar!(sv, typ)
5981+
push!(stmts, :($ssa = $a))
5982+
args[i] = ssa
5983+
return
5984+
end
5985+
5986+
# Temporary pass to linearize the IR before `alloc_elim_pass!` before we do so in lowering
5987+
function linearize_pass!(sv::InferenceState)
5988+
body = sv.src.code
5989+
len = length(body)
5990+
next_i = 1
5991+
stmts = []
5992+
while next_i <= len
5993+
i = next_i
5994+
next_i += 1
5995+
ex = body[i]
5996+
isa(ex, Expr) || continue
5997+
ex = ex::Expr
5998+
head = ex.head
5999+
is_meta_expr_head(head) && continue
6000+
if head === :(=)
6001+
ex = ex.args[2]
6002+
isa(ex, Expr) || continue
6003+
ex = ex::Expr
6004+
head = ex.head
6005+
end
6006+
args = ex.args
6007+
if head === :foreigncall
6008+
if isa(args[1], Expr) && !is_ccall_static(args[1]::Expr, sv)
6009+
linearize_arg!(args, 1, stmts, sv)
6010+
end
6011+
for j in 2:length(args)
6012+
a = args[j]
6013+
isa(a, Expr) || continue
6014+
if a.head === :&
6015+
linearize_arg!(a.args, 1, stmts, sv)
6016+
else
6017+
linearize_arg!(args, j, stmts, sv)
6018+
end
6019+
end
6020+
elseif head === :isdefined || head === :const || is_meta_expr_head(head)
6021+
continue
6022+
elseif head === :call
6023+
if is_known_call(ex, Intrinsics.llvmcall, sv.src, sv.mod)
6024+
for j in 5:length(args)
6025+
linearize_arg!(args, j, stmts, sv)
6026+
end
6027+
elseif is_known_call(ex, Intrinsics.cglobal, sv.src, sv.mod)
6028+
if isa(args[2], Expr) && !is_ccall_static(args[2]::Expr, sv)
6029+
linearize_arg!(args, 2, stmts, sv)
6030+
end
6031+
for j in 3:length(args)
6032+
linearize_arg!(args, j, stmts, sv)
6033+
end
6034+
else
6035+
for j in 1:length(args)
6036+
linearize_arg!(args, j, stmts, sv)
6037+
end
6038+
end
6039+
else
6040+
for j in 1:length(args)
6041+
if j == 1 && head === :method
6042+
argj = args[j]
6043+
if isa(argj, Slot) || isa(argj, Symbol) || isa(argj, GlobalRef)
6044+
continue
6045+
end
6046+
end
6047+
linearize_arg!(args, j, stmts, sv)
6048+
end
6049+
end
6050+
isempty(stmts) && continue
6051+
next_i = i
6052+
splice!(body, i:(i - 1), stmts)
6053+
len += length(stmts)
6054+
empty!(stmts)
6055+
end
6056+
end
6057+
59496058
# Return the number of expressions added before `i0`
59506059
function replace_newvar_node!(body, orig, new_slots, i0)
59516060
nvars = length(new_slots)

0 commit comments

Comments
 (0)