Skip to content

Commit 62aee92

Browse files
committed
WIP: move to linear IR
1 parent 9d4d02c commit 62aee92

File tree

4 files changed

+116
-47
lines changed

4 files changed

+116
-47
lines changed

base/codevalidation.jl

Lines changed: 67 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ const VALID_EXPR_HEADS = ObjectIdDict(
1010
:(=) => 2:2,
1111
:method => 1:4,
1212
:const => 1:1,
13-
:null => 0:0, # TODO from @vtjnash: remove this + any :null handling code in Base
1413
:new => 1:typemax(Int),
1514
:return => 1:1,
1615
:the_exception => 0:0,
@@ -41,6 +40,7 @@ const SLOTTYPES_MISMATCH_UNINFERRED = "uninferred CodeInfo slottypes field is no
4140
const SSAVALUETYPES_MISMATCH = "not all SSAValues in AST have a type in ssavaluetypes"
4241
const SSAVALUETYPES_MISMATCH_UNINFERRED = "uninferred CodeInfo ssavaluetypes field does not equal the number of present SSAValues"
4342
const NON_TOP_LEVEL_METHOD = "encountered `Expr` head `:method` in non-top-level code (i.e. `nargs` > 0)"
43+
const NON_TOP_LEVEL_GLOBAL = "encountered `Expr` head `:global` in non-top-level code (i.e. `nargs` > 0)"
4444
const SIGNATURE_NARGS_MISMATCH = "method signature does not match number of method arguments"
4545
const SLOTNAMES_NARGS_MISMATCH = "CodeInfo for method contains fewer slotnames than the number of method arguments"
4646

@@ -57,38 +57,72 @@ InvalidCodeError(kind::AbstractString) = InvalidCodeError(kind, nothing)
5757
Validate `c`, logging any violation by pushing an `InvalidCodeError` into `errors`.
5858
"""
5959
function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_level::Bool = false)
60+
# Recursively validate a value that appears inside a statement.
# Uses `errors` and `ssavals` from the enclosing `validate_code!` scope:
# problems are pushed onto `errors`, and every `SSAValue` seen is recorded
# in `ssavals`.
function validate_val!(@nospecialize(x))
61+
if isa(x, Expr)
62+
# Only `:call` and `:invoke` expressions are descended into; expressions
# with any other head are left unexamined here.
if x.head == :call || x.head == :invoke
63+
for arg in x.args
64+
if !is_valid_argument(arg)
65+
push!(errors, InvalidCodeError(INVALID_CALL_ARG, arg))
66+
else
67+
# A valid argument may itself be an SSAValue that must be recorded.
validate_val!(arg)
68+
end
69+
end
70+
end
71+
elseif isa(x, SSAValue)
72+
id = x.id + 1 # ensures that id > 0 for use with IntSet
73+
# Record each (shifted) SSA id only once.
!in(id, ssavals) && push!(ssavals, id)
74+
end
75+
end
76+
6077
ssavals = IntSet()
6178
lhs_slotnums = IntSet()
62-
walkast(c.code) do x
79+
for x in c.code
6380
if isa(x, Expr)
64-
!is_top_level && x.head == :method && push!(errors, InvalidCodeError(NON_TOP_LEVEL_METHOD))
81+
if !is_top_level
82+
x.head === :method && push!(errors, InvalidCodeError(NON_TOP_LEVEL_METHOD))
83+
x.head === :global && push!(errors, InvalidCodeError(NON_TOP_LEVEL_GLOBAL))
84+
end
6585
narg_bounds = get(VALID_EXPR_HEADS, x.head, -1:-1)
6686
nargs = length(x.args)
6787
if narg_bounds == -1:-1
6888
push!(errors, InvalidCodeError(INVALID_EXPR_HEAD, (x.head, x)))
6989
elseif !in(nargs, narg_bounds)
7090
push!(errors, InvalidCodeError(INVALID_EXPR_NARGS, (x.head, nargs, x)))
71-
elseif x.head == :(=)
91+
elseif x.head === :(=)
7292
lhs, rhs = x.args
7393
if !is_valid_lvalue(lhs)
7494
push!(errors, InvalidCodeError(INVALID_LVALUE, lhs))
7595
elseif isa(lhs, SlotNumber) && !in(lhs.id, lhs_slotnums)
7696
n = lhs.id
7797
push!(lhs_slotnums, n)
7898
end
79-
if !is_valid_rvalue(rhs)
99+
if !is_valid_rvalue(lhs, rhs)
80100
push!(errors, InvalidCodeError(INVALID_RVALUE, rhs))
81101
end
82-
elseif x.head == :call || x.head == :invoke
83-
for arg in x.args
84-
if !is_valid_rvalue(arg)
85-
push!(errors, InvalidCodeError(INVALID_CALL_ARG, arg))
86-
end
102+
validate_val!(lhs)
103+
validate_val!(rhs)
104+
elseif x.head === :gotoifnot
105+
if !is_valid_argument(x.args[1])
106+
push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.args[1]))
107+
end
108+
validate_val!(x.args[1])
109+
elseif x.head === :return
110+
if !is_valid_rvalue(nothing, x.args[1])
111+
push!(errors, InvalidCodeError(INVALID_RVALUE, x.args[1]))
87112
end
113+
validate_val!(x.args[1])
114+
else
115+
validate_val!(x)
88116
end
89-
elseif isa(x, SSAValue)
90-
id = x.id + 1 # ensures that id > 0 for use with IntSet
91-
!in(id, ssavals) && push!(ssavals, id)
117+
elseif isa(x, NewvarNode)
118+
elseif isa(x, LabelNode)
119+
elseif isa(x, GotoNode)
120+
elseif x === nothing
121+
elseif isa(x, SlotNumber)
122+
elseif isa(x, GlobalRef)
123+
elseif isa(x, LineNumberNode)
124+
else
125+
push!(errors, InvalidCodeError("invalid statement", x))
92126
end
93127
end
94128
nslotnames = length(c.slotnames)
@@ -133,18 +167,29 @@ end
133167

134168
validate_code(args...) = validate_code!(Vector{InvalidCodeError}(), args...)
135169

136-
function walkast(f, stmts::Array)
137-
for stmt in stmts
138-
f(stmt)
139-
isa(stmt, Expr) && walkast(f, stmt.args)
170+
# An assignment target must be a slot, an SSA value, or a global reference.
is_valid_lvalue(x) = isa(x, Union{SlotNumber, SSAValue, GlobalRef})
171+
172+
# Return `true` if `x` is acceptable in argument position (it is used to check
# the arguments of `:call`/`:invoke` and the condition of `:gotoifnot`).
function is_valid_argument(x)
173+
# Explicitly accepted: slots, SSA values, global references, quoted values,
# the `:static_parameter`/`:boundscheck`/`:copyast` expressions, literal-like
# values (numbers, strings, chars, tuples, types), `Core.Box`, modules and
# `nothing`.
if isa(x, SlotNumber) || isa(x, SSAValue) || isa(x, GlobalRef) || isa(x, QuoteNode) ||
174+
(isa(x,Expr) && (x.head in (:static_parameter, :boundscheck, :copyast))) ||
175+
isa(x, Number) || isa(x, AbstractString) || isa(x, Char) || isa(x, Tuple) ||
176+
isa(x, Type) || isa(x, Core.Box) || isa(x, Module) || x === nothing
177+
return true
140178
end
179+
# TODO: consider being stricter about what needs to be wrapped with QuoteNode
180+
# Any other value is accepted unless it is an expression, a bare symbol, or
# one of the statement-level node types listed below.
return !(isa(x,Expr) || isa(x,Symbol) || isa(x,GotoNode) || isa(x,LabelNode) ||
181+
isa(x,LineNumberNode) || isa(x,NewvarNode))
141182
end
142183

143-
is_valid_lvalue(x) = isa(x, SlotNumber) || isa(x, SSAValue) || isa(x, GlobalRef)
144-
145-
function is_valid_rvalue(x)
146-
isa(x, Expr) && return !in(x.head, (:gotoifnot, :line, :const, :meta))
147-
return !isa(x, GotoNode) && !isa(x, LabelNode) && !isa(x, LineNumberNode)
184+
# Return `true` if `x` may appear as the value assigned to `lhs`.
# `validate_code!` also uses this for `:return`, passing `nothing` as `lhs`.
function is_valid_rvalue(lhs, x)
185+
# Anything valid as an argument is also valid as an rvalue.
is_valid_argument(x) && return true
186+
if isa(x, Expr)
187+
# These expression heads are accepted only when the target is an SSAValue.
if isa(lhs, SSAValue) && x.head in (:call, :invoke, :foreigncall, :gc_preserve_begin)
188+
return true
189+
end
190+
# These heads are accepted regardless of the kind of `lhs`.
return x.head in (:new, :the_exception, :isdefined)
191+
end
192+
return false
148193
end
149194

150195
# Report whether every bit that is set in `flag` is also set in `byte`.
function is_flag_set(byte::UInt8, flag::UInt8)
    return flag == (flag & byte)
end

base/inference.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,8 +2383,6 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
23832383
e = e::Expr
23842384
if e.head === :call
23852385
t = abstract_eval_call(e, vtypes, sv)
2386-
elseif e.head === :null
2387-
t = Void
23882386
elseif e.head === :new
23892387
t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
23902388
for i = 2:length(e.args)
@@ -4036,7 +4034,7 @@ function inline_as_constant(@nospecialize(val), argexprs, sv::InferenceState, @n
40364034
end
40374035

40384036
function is_self_quoting(@nospecialize(x))
4039-
return isa(x,Number) || isa(x,AbstractString) || isa(x,Tuple) || isa(x,Type)
4037+
return isa(x,Number) || isa(x,AbstractString) || isa(x,Tuple) || isa(x,Type) || isa(x,Char) || x === nothing
40404038
end
40414039

40424040
function countunionsplit(atypes)

src/ast.scm

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@
234234
(ssavalue? e)))
235235

236236
(define (simple-atom? x)
237-
(or (number? x) (string? x) (char? x) (eq? x 'true) (eq? x 'false)))
237+
(or (number? x) (string? x) (char? x) (eq? x 'true) (eq? x 'false)
238+
(ssavalue? x) (eq? (typeof x) 'julia_value)))
238239

239240
;; identify some expressions that are safe to repeat
240241
(define (effect-free? e)

src/julia-syntax.scm

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2950,7 +2950,7 @@ f(x) = yt(x)
29502950
(capt (and vi (vinfo:asgn vi) (vinfo:capt vi))))
29512951
(if (and (not closed) (not capt) (equal? vt '(core Any)))
29522952
`(= ,var ,rhs0)
2953-
(let* ((rhs1 (if (or (ssavalue? rhs0) (simple-atom? rhs0)
2953+
(let* ((rhs1 (if (or (simple-atom? rhs0)
29542954
(equal? rhs0 '(the_exception)))
29552955
rhs0
29562956
(make-ssavalue)))
@@ -3415,7 +3415,7 @@ f(x) = yt(x)
34153415
;; pass 5: convert to linear IR
34163416

34173417
;; with this enabled, all nested calls are assigned to numbered locations
3418-
(define *very-linear-mode* #f)
3418+
(define *very-linear-mode* #t)
34193419

34203420
(define (linearize e)
34213421
(cond ((or (not (pair? e)) (quoted? e)) e)
@@ -3431,6 +3431,18 @@ f(x) = yt(x)
34313431
(error msg)
34323432
(io.write *stderr* msg)))
34333433

3434+
;; true if `e` is acceptable in argument position of linear IR: a simple atom,
;; a symbol, or a form whose head is one of the leaf kinds listed below.
(define (valid-ir-argument? e)
3435+
(or (simple-atom? e) (symbol? e)
3436+
(and (pair? e)
3437+
(memq (car e) '(quote inert top core globalref outerref null
3438+
slot static_parameter boundscheck copyast)))))
3439+
3440+
;; true if `e` may be assigned directly to `lhs`: either `e` is a valid
;; argument, or it is a compound form that is a `new`, `the_exception` or
;; `isdefined` expression, or `lhs` is an ssavalue (in which case any compound
;; form is accepted).
(define (valid-ir-rvalue? lhs e)
3440+
(or (valid-ir-argument? e)
3441+
(and (pair? e)
3442+
(or (memq (car e) '(new the_exception isdefined))
3443+
(ssavalue? lhs)))))
3445+
34343446
;; this pass behaves like an interpreter on the given code.
34353447
;; to perform stateful operations, it calls `emit` to record that something
34363448
;; needs to be done. in value position, it returns an expression computing
@@ -3464,16 +3476,20 @@ f(x) = yt(x)
34643476
(mark-label l)
34653477
l)))
34663478
(define (emit-return x)
3467-
(let ((rv (if (> handler-level 0)
3468-
(let ((tmp (if (or (simple-atom? x) (ssavalue? x) (equal? x '(null)))
3469-
#f (make-ssavalue))))
3470-
(if tmp (emit `(= ,tmp ,x)))
3471-
(emit `(leave ,handler-level))
3472-
(or tmp x))
3473-
x)))
3474-
(if rett
3475-
(emit `(return ,(convert-for-type-decl rv rett)))
3476-
(emit `(return ,rv)))))
3479+
(let ((x (if rett
3480+
(compile (convert-for-type-decl x rett) '() #t #f)
3481+
x)))
3482+
(let ((tmp (if (or (and (> handler-level 0)
3483+
(not (or (simple-atom? x) (equal? x '(null)))))
3484+
(not (or (valid-ir-rvalue? #f x)
3485+
;; returning lambda directly is needed for @generated
3486+
(and (pair? x) (eq? (car x) 'lambda)))))
3487+
(make-ssavalue)
3488+
#f)))
3489+
(if tmp (emit `(= ,tmp ,x)))
3490+
(if (> handler-level 0)
3491+
(emit `(leave ,handler-level)))
3492+
(emit `(return ,(or tmp x))))))
34773493
(define (new-mutable-var . name)
34783494
(let ((g (if (null? name) (gensy) (named-gensy (car name)))))
34793495
(set-car! (lam:vinfo lam) (append (car (lam:vinfo lam)) `((,g Any 2))))
@@ -3487,7 +3503,7 @@ f(x) = yt(x)
34873503
(and (pair? x) (eq? (car x) 'block))))
34883504
e))
34893505
(cdr lst))))
3490-
(simple? (every (lambda (x) (or (simple-atom? x) (symbol? x) (ssavalue? x)
3506+
(simple? (every (lambda (x) (or (simple-atom? x) (symbol? x)
34913507
(and (pair? x)
34923508
(memq (car x) '(quote inert top core globalref outerref copyast)))))
34933509
lst)))
@@ -3499,10 +3515,12 @@ f(x) = yt(x)
34993515
(aval (compile arg break-labels #t #f linearize)))
35003516
(loop (cdr lst)
35013517
(cons (if (and temps? linearize (not simple?)
3502-
(not (simple-atom? arg)) (not (ssavalue? arg))
3503-
(not (simple-atom? aval)) (not (ssavalue? aval))
3518+
(not (simple-atom? arg))
3519+
(not (simple-atom? aval))
35043520
(not (and (pair? arg)
35053521
(memq (car arg) '(& quote inert top core globalref outerref copyast))))
3522+
(not (and (symbol? aval) ;; function args are immutable and always assigned
3523+
(memq aval (lam:args lam))))
35063524
(not (and (symbol? arg)
35073525
(or (null? (cdr lst))
35083526
(null? vals)))))
@@ -3514,11 +3532,18 @@ f(x) = yt(x)
35143532
(define (compile-cond ex break-labels)
35153533
(let ((cnd (compile ex break-labels #t #f)))
35163534
(if (and *very-linear-mode*
3517-
(not (or (simple-atom? cnd) (ssavalue? cnd) (symbol? cnd))))
3535+
(not (valid-ir-argument? cnd)))
35183536
(let ((tmp (make-ssavalue)))
35193537
(emit `(= ,tmp ,cnd))
35203538
tmp)
35213539
cnd)))
3540+
;; emit the assignment `lhs = rhs`. if `rhs` is not a valid rvalue for `lhs`,
;; it is first assigned to a fresh ssavalue, which is then assigned to `lhs`.
;; the result of this function is always `(null)`.
(define (emit-assignment lhs rhs)
3541+
(if (valid-ir-rvalue? lhs rhs)
3542+
(emit `(= ,lhs ,rhs))
3543+
(let ((rr (make-ssavalue)))
3544+
(emit `(= ,rr ,rhs))
3545+
(emit `(= ,lhs ,rr))))
3546+
`(null))
35223547
;; the interpreter loop. `break-labels` keeps track of the labels to jump to
35233548
;; for all currently closing break-blocks.
35243549
;; `value` means we are in a context where a value is required; a meaningful
@@ -3595,7 +3620,7 @@ f(x) = yt(x)
35953620
(emit `(= ,lhs ,rr))
35963621
(if tail (emit-return rr))
35973622
rr)
3598-
(emit `(= ,lhs ,rhs)))))
3623+
(emit-assignment lhs rhs))))
35993624
((block body)
36003625
(let* ((last-fname filename)
36013626
(fnm (first-non-meta e))
@@ -3650,14 +3675,14 @@ f(x) = yt(x)
36503675
(val (if (and value (not tail)) (new-mutable-var) #f)))
36513676
(emit test)
36523677
(let ((v1 (compile (caddr e) break-labels value tail)))
3653-
(if val (emit `(= ,val ,v1)))
3678+
(if val (emit-assignment val v1))
36543679
(if (and (not tail) (or (length> e 3) val))
36553680
(emit end-jump))
36563681
(set-car! (cddr test) (make&mark-label))
36573682
(let ((v2 (if (length> e 3)
36583683
(compile (cadddr e) break-labels value tail)
36593684
'(null))))
3660-
(if val (emit `(= ,val ,v2)))
3685+
(if val (emit-assignment val v2))
36613686
(if (not tail)
36623687
(set-car! (cdr end-jump) (make&mark-label))
36633688
(if (length= e 3)
@@ -3722,7 +3747,7 @@ f(x) = yt(x)
37223747
break-labels value #f))
37233748
(val (if (and value (not tail))
37243749
(new-mutable-var) #f)))
3725-
(if val (emit `(= ,val ,v1)))
3750+
(if val (emit-assignment val v1))
37263751
(if tail
37273752
(begin (emit-return v1)
37283753
(set! endl #f))
@@ -3732,7 +3757,7 @@ f(x) = yt(x)
37323757
(mark-label catch)
37333758
(emit `(leave 1))
37343759
(let ((v2 (compile (caddr e) break-labels value tail)))
3735-
(if val (emit `(= ,val ,v2)))
3760+
(if val (emit-assignment val v2))
37363761
(if endl (mark-label endl))
37373762
val))))
37383763

0 commit comments

Comments
 (0)