Skip to content

Commit 58bedde

Browse files
authored
Add a type annotation to return_call_ref (#5068)
The GC spec has been updated to have heap type annotations on call_ref and return_call_ref. To avoid breaking users, we will have a graceful, multi-step upgrade to the annotated version of call_ref, but since return_call_ref has no users yet, update it in a single step.
1 parent b1ba257 commit 58bedde

10 files changed

+126
-61
lines changed

src/ir/module-utils.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ struct CodeScanner
5353
void visitExpression(Expression* curr) {
5454
if (auto* call = curr->dynCast<CallIndirect>()) {
5555
counts.note(call->heapType);
56+
} else if (auto* call = curr->dynCast<CallRef>()) {
57+
if (call->isReturn && call->target->type.isFunction()) {
58+
counts.note(call->target->type);
59+
}
5660
} else if (curr->is<RefNull>()) {
5761
counts.note(curr->type);
5862
} else if (auto* make = curr->dynCast<StructNew>()) {

src/passes/Print.cpp

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,9 +2043,31 @@ struct PrintExpressionContents
20432043
void visitI31Get(I31Get* curr) {
20442044
printMedium(o, curr->signed_ ? "i31.get_s" : "i31.get_u");
20452045
}
2046+
2047+
// If we cannot print a valid unreachable instruction (say, a struct.get,
2048+
// where if the ref is unreachable, we don't know what heap type to print),
2049+
// then print the children in a block, which is good enough as this
2050+
// instruction is never reached anyhow.
2051+
//
2052+
// This function checks if the input is in fact unreachable, and if so, begins
2053+
// to emit a replacement for it and returns true.
2054+
bool printUnreachableReplacement(Expression* curr) {
2055+
if (curr->type == Type::unreachable) {
2056+
printMedium(o, "block");
2057+
return true;
2058+
}
2059+
return false;
2060+
}
2061+
20462062
void visitCallRef(CallRef* curr) {
20472063
if (curr->isReturn) {
2048-
printMedium(o, "return_call_ref");
2064+
if (printUnreachableReplacement(curr->target)) {
2065+
return;
2066+
}
2067+
printMedium(o, "return_call_ref ");
2068+
assert(curr->target->type != Type::unreachable);
2069+
// TODO: Workaround if target has bottom type.
2070+
printHeapType(o, curr->target->type.getHeapType(), wasm);
20492071
} else {
20502072
printMedium(o, "call_ref");
20512073
}
@@ -2106,22 +2128,6 @@ struct PrintExpressionContents
21062128
}
21072129
printName(curr->name, o);
21082130
}
2109-
2110-
// If we cannot print a valid unreachable instruction (say, a struct.get,
2111-
// where if the ref is unreachable, we don't know what heap type to print),
2112-
// then print the children in a block, which is good enough as this
2113-
// instruction is never reached anyhow.
2114-
//
2115-
// This function checks if the input is in fact unreachable, and if so, begins
2116-
// to emit a replacement for it and returns true.
2117-
bool printUnreachableReplacement(Expression* curr) {
2118-
if (curr->type == Type::unreachable) {
2119-
printMedium(o, "block");
2120-
return true;
2121-
}
2122-
return false;
2123-
}
2124-
21252131
void visitStructNew(StructNew* curr) {
21262132
if (printUnreachableReplacement(curr)) {
21272133
return;
@@ -2748,6 +2754,13 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
27482754
}
27492755
decIndent();
27502756
}
2757+
void visitCallRef(CallRef* curr) {
2758+
if (curr->isReturn) {
2759+
maybePrintUnreachableReplacement(curr, curr->target->type);
2760+
} else {
2761+
visitExpression(curr);
2762+
}
2763+
}
27512764
void visitStructNew(StructNew* curr) {
27522765
maybePrintUnreachableReplacement(curr, curr->type);
27532766
}

src/wasm-binary.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1727,7 +1727,8 @@ class WasmBinaryBuilder {
17271727
void visitTryOrTryInBlock(Expression*& out);
17281728
void visitThrow(Throw* curr);
17291729
void visitRethrow(Rethrow* curr);
1730-
void visitCallRef(CallRef* curr);
1730+
void visitCallRef(CallRef* curr,
1731+
std::optional<HeapType> maybeType = std::nullopt);
17311732
void visitRefAs(RefAs* curr, uint8_t code);
17321733

17331734
[[noreturn]] void throwError(std::string text);

src/wasm/wasm-binary.cpp

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3783,7 +3783,7 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) {
37833783
auto call = allocator.alloc<CallRef>();
37843784
call->isReturn = true;
37853785
curr = call;
3786-
visitCallRef(call);
3786+
visitCallRef(call, getTypeByIndex(getU32LEB()));
37873787
break;
37883788
}
37893789
case BinaryConsts::AtomicPrefix: {
@@ -6777,30 +6777,44 @@ void WasmBinaryBuilder::visitRethrow(Rethrow* curr) {
67776777
curr->finalize();
67786778
}
67796779

6780-
void WasmBinaryBuilder::visitCallRef(CallRef* curr) {
6780+
void WasmBinaryBuilder::visitCallRef(CallRef* curr,
6781+
std::optional<HeapType> maybeType) {
67816782
BYN_TRACE("zz node: CallRef\n");
67826783
curr->target = popNonVoidExpression();
6783-
auto type = curr->target->type;
6784-
if (type == Type::unreachable) {
6785-
// If our input is unreachable, then we cannot even find out how many inputs
6786-
// we have, and just set ourselves to unreachable as well.
6787-
curr->finalize(type);
6788-
return;
6789-
}
6790-
if (!type.isRef()) {
6791-
throwError("Non-ref type for a call_ref: " + type.toString());
6784+
HeapType heapType;
6785+
if (maybeType) {
6786+
heapType = *maybeType;
6787+
if (!Type::isSubType(curr->target->type, Type(heapType, Nullable))) {
6788+
throwError("Call target has invalid type: " +
6789+
curr->target->type.toString());
6790+
}
6791+
} else {
6792+
auto type = curr->target->type;
6793+
if (type == Type::unreachable) {
6794+
// If our input is unreachable, then we cannot even find out how many
6795+
// inputs we have, and just set ourselves to unreachable as well.
6796+
curr->finalize(type);
6797+
return;
6798+
}
6799+
if (!type.isRef()) {
6800+
throwError("Non-ref type for a call_ref: " + type.toString());
6801+
}
6802+
heapType = type.getHeapType();
67926803
}
6793-
auto heapType = type.getHeapType();
67946804
if (!heapType.isSignature()) {
6795-
throwError("Invalid reference type for a call_ref: " + type.toString());
6805+
throwError("Invalid reference type for a call_ref: " + heapType.toString());
67966806
}
67976807
auto sig = heapType.getSignature();
67986808
auto num = sig.params.size();
67996809
curr->operands.resize(num);
68006810
for (size_t i = 0; i < num; i++) {
68016811
curr->operands[num - i - 1] = popNonVoidExpression();
68026812
}
6803-
curr->finalize(sig.results);
6813+
if (maybeType) {
6814+
curr->finalize();
6815+
} else {
6816+
curr->finalize(sig.results);
6817+
}
68046818
}
68056819

68066820
bool WasmBinaryBuilder::maybeVisitI31New(Expression*& out, uint32_t code) {

src/wasm/wasm-s-parser.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2830,9 +2830,24 @@ Expression* SExpressionWasmBuilder::makeTupleExtract(Element& s) {
28302830
}
28312831

28322832
Expression* SExpressionWasmBuilder::makeCallRef(Element& s, bool isReturn) {
2833+
Index operandsStart = 1;
2834+
HeapType sigType;
2835+
if (isReturn) {
2836+
sigType = parseHeapType(*s[1]);
2837+
operandsStart = 2;
2838+
}
28332839
std::vector<Expression*> operands;
2834-
parseOperands(s, 1, s.size() - 1, operands);
2840+
parseOperands(s, operandsStart, s.size() - 1, operands);
28352841
auto* target = parseExpression(s[s.size() - 1]);
2842+
2843+
if (isReturn) {
2844+
if (!sigType.isSignature()) {
2845+
throw ParseException(
2846+
"return_call_ref type annotation should be a signature", s.line, s.col);
2847+
}
2848+
return Builder(wasm).makeCallRef(
2849+
target, operands, sigType.getSignature().results, isReturn);
2850+
}
28362851
return ValidatingBuilder(wasm, s.line, s.col)
28372852
.validateAndMakeCallRef(target, operands, isReturn);
28382853
}

src/wasm/wasm-stack.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,8 +2013,14 @@ void BinaryInstWriter::visitI31Get(I31Get* curr) {
20132013
}
20142014

20152015
void BinaryInstWriter::visitCallRef(CallRef* curr) {
2016-
o << int8_t(curr->isReturn ? BinaryConsts::RetCallRef
2017-
: BinaryConsts::CallRef);
2016+
if (curr->isReturn) {
2017+
assert(curr->target->type != Type::unreachable);
2018+
// TODO: `emitUnreachable` if target has bottom type.
2019+
o << int8_t(BinaryConsts::RetCallRef);
2020+
parent.writeIndexedHeapType(curr->target->type.getHeapType());
2021+
return;
2022+
}
2023+
o << int8_t(BinaryConsts::CallRef);
20182024
}
20192025

20202026
void BinaryInstWriter::visitRefTest(RefTest* curr) {

test/lit/passes/dae-gc-refine-return.wast

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -586,20 +586,20 @@
586586
)
587587
;; CHECK: (func $tail-caller-call_ref-yes (result (ref ${}))
588588
;; CHECK-NEXT: (local $return_{} (ref null $return_{}))
589-
;; CHECK-NEXT: (return_call_ref
589+
;; CHECK-NEXT: (return_call_ref $return_{}
590590
;; CHECK-NEXT: (local.get $return_{})
591591
;; CHECK-NEXT: )
592592
;; CHECK-NEXT: )
593593
;; NOMNL: (func $tail-caller-call_ref-yes (type $return_{}) (result (ref ${}))
594594
;; NOMNL-NEXT: (local $return_{} (ref null $return_{}))
595-
;; NOMNL-NEXT: (return_call_ref
595+
;; NOMNL-NEXT: (return_call_ref $return_{}
596596
;; NOMNL-NEXT: (local.get $return_{})
597597
;; NOMNL-NEXT: )
598598
;; NOMNL-NEXT: )
599599
(func $tail-caller-call_ref-yes (result anyref)
600600
(local $return_{} (ref null $return_{}))
601601

602-
(return_call_ref (local.get $return_{}))
602+
(return_call_ref $return_{} (local.get $return_{}))
603603
)
604604
;; CHECK: (func $tail-caller-call_ref-no (result anyref)
605605
;; CHECK-NEXT: (local $any anyref)
@@ -610,7 +610,7 @@
610610
;; CHECK-NEXT: (local.get $any)
611611
;; CHECK-NEXT: )
612612
;; CHECK-NEXT: )
613-
;; CHECK-NEXT: (return_call_ref
613+
;; CHECK-NEXT: (return_call_ref $return_{}
614614
;; CHECK-NEXT: (local.get $return_{})
615615
;; CHECK-NEXT: )
616616
;; CHECK-NEXT: )
@@ -623,7 +623,7 @@
623623
;; NOMNL-NEXT: (local.get $any)
624624
;; NOMNL-NEXT: )
625625
;; NOMNL-NEXT: )
626-
;; NOMNL-NEXT: (return_call_ref
626+
;; NOMNL-NEXT: (return_call_ref $return_{}
627627
;; NOMNL-NEXT: (local.get $return_{})
628628
;; NOMNL-NEXT: )
629629
;; NOMNL-NEXT: )
@@ -634,18 +634,26 @@
634634
(if (i32.const 1)
635635
(return (local.get $any))
636636
)
637-
(return_call_ref (local.get $return_{}))
637+
(return_call_ref $return_{} (local.get $return_{}))
638638
)
639-
;; CHECK: (func $tail-caller-call_ref-unreachable
640-
;; CHECK-NEXT: (unreachable)
639+
;; CHECK: (func $tail-caller-call_ref-unreachable (result anyref)
640+
;; CHECK-NEXT: (block ;; (replaces something unreachable we can't emit)
641+
;; CHECK-NEXT: (drop
642+
;; CHECK-NEXT: (unreachable)
643+
;; CHECK-NEXT: )
644+
;; CHECK-NEXT: )
641645
;; CHECK-NEXT: )
642-
;; NOMNL: (func $tail-caller-call_ref-unreachable (type $none_=>_none)
643-
;; NOMNL-NEXT: (unreachable)
646+
;; NOMNL: (func $tail-caller-call_ref-unreachable (type $none_=>_anyref) (result anyref)
647+
;; NOMNL-NEXT: (block ;; (replaces something unreachable we can't emit)
648+
;; NOMNL-NEXT: (drop
649+
;; NOMNL-NEXT: (unreachable)
650+
;; NOMNL-NEXT: )
651+
;; NOMNL-NEXT: )
644652
;; NOMNL-NEXT: )
645653
(func $tail-caller-call_ref-unreachable (result anyref)
646654
;; An unreachable means there is no function signature to even look at. We
647655
;; should not hit an assertion on such things.
648-
(return_call_ref (unreachable))
656+
(return_call_ref $return_{} (unreachable))
649657
)
650658
;; CHECK: (func $tail-call-caller-call_ref
651659
;; CHECK-NEXT: (drop
@@ -654,7 +662,9 @@
654662
;; CHECK-NEXT: (drop
655663
;; CHECK-NEXT: (call $tail-caller-call_ref-no)
656664
;; CHECK-NEXT: )
657-
;; CHECK-NEXT: (call $tail-caller-call_ref-unreachable)
665+
;; CHECK-NEXT: (drop
666+
;; CHECK-NEXT: (call $tail-caller-call_ref-unreachable)
667+
;; CHECK-NEXT: )
658668
;; CHECK-NEXT: )
659669
;; NOMNL: (func $tail-call-caller-call_ref (type $none_=>_none)
660670
;; NOMNL-NEXT: (drop
@@ -663,7 +673,9 @@
663673
;; NOMNL-NEXT: (drop
664674
;; NOMNL-NEXT: (call $tail-caller-call_ref-no)
665675
;; NOMNL-NEXT: )
666-
;; NOMNL-NEXT: (call $tail-caller-call_ref-unreachable)
676+
;; NOMNL-NEXT: (drop
677+
;; NOMNL-NEXT: (call $tail-caller-call_ref-unreachable)
678+
;; NOMNL-NEXT: )
667679
;; NOMNL-NEXT: )
668680
(func $tail-call-caller-call_ref
669681
(drop

test/lit/passes/inlining_all-features.wast

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@
135135
(export "func_36_invoker" (func $1))
136136

137137
(func $0
138-
(return_call_ref
138+
(return_call_ref $none_=>_none
139139
(ref.null $none_=>_none)
140140
)
141141
)

test/lit/passes/optimize-instructions-call_ref.wast

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@
316316
(func $return_call_ref-to-select (param $x i32) (param $y i32)
317317
;; As above, but with a return call. We optimize this too, and turn a
318318
;; return_call_ref over a select into an if over return_calls.
319-
(return_call_ref
319+
(return_call_ref $i32_i32_=>_none
320320
(local.get $x)
321321
(local.get $y)
322322
(select

test/lit/types-function-references.wast

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
;; RUN: cat %t.text.wast | filecheck %s --check-prefix=CHECK-TEXT
1111

1212
(module
13-
;; inline ref type in result
14-
(type $_=>_eqref (func (result eqref)))
1513
;; CHECK-BINARY: (type $mixed_results (func (result anyref f32 anyref f32)))
1614

17-
;; CHECK-BINARY: (type $none_=>_none (func))
15+
;; CHECK-BINARY: (type $void (func))
16+
;; CHECK-TEXT: (type $mixed_results (func (result anyref f32 anyref f32)))
1817

18+
;; CHECK-TEXT: (type $void (func))
19+
(type $void (func))
20+
21+
;; inline ref type in result
22+
(type $_=>_eqref (func (result eqref)))
1923
;; CHECK-BINARY: (type $i32-i32 (func (param i32) (result i32)))
2024

2125
;; CHECK-BINARY: (type $=>eqref (func (result eqref)))
@@ -27,10 +31,6 @@
2731
;; CHECK-BINARY: (type $none_=>_i32 (func (result i32)))
2832

2933
;; CHECK-BINARY: (type $f64_=>_ref_null<_->_eqref> (func (param f64) (result (ref null $=>eqref))))
30-
;; CHECK-TEXT: (type $mixed_results (func (result anyref f32 anyref f32)))
31-
32-
;; CHECK-TEXT: (type $none_=>_none (func))
33-
3434
;; CHECK-TEXT: (type $i32-i32 (func (param i32) (result i32)))
3535

3636
;; CHECK-TEXT: (type $=>eqref (func (result eqref)))
@@ -77,17 +77,17 @@
7777
(call_ref (ref.func $call-ref))
7878
)
7979
;; CHECK-BINARY: (func $return-call-ref
80-
;; CHECK-BINARY-NEXT: (return_call_ref
80+
;; CHECK-BINARY-NEXT: (return_call_ref $void
8181
;; CHECK-BINARY-NEXT: (ref.func $call-ref)
8282
;; CHECK-BINARY-NEXT: )
8383
;; CHECK-BINARY-NEXT: )
8484
;; CHECK-TEXT: (func $return-call-ref
85-
;; CHECK-TEXT-NEXT: (return_call_ref
85+
;; CHECK-TEXT-NEXT: (return_call_ref $void
8686
;; CHECK-TEXT-NEXT: (ref.func $call-ref)
8787
;; CHECK-TEXT-NEXT: )
8888
;; CHECK-TEXT-NEXT: )
8989
(func $return-call-ref
90-
(return_call_ref (ref.func $call-ref))
90+
(return_call_ref $void (ref.func $call-ref))
9191
)
9292
;; CHECK-BINARY: (func $call-ref-more (param $0 i32) (result i32)
9393
;; CHECK-BINARY-NEXT: (call_ref
@@ -405,7 +405,7 @@
405405
;; CHECK-NODEBUG-NEXT: )
406406

407407
;; CHECK-NODEBUG: (func $1
408-
;; CHECK-NODEBUG-NEXT: (return_call_ref
408+
;; CHECK-NODEBUG-NEXT: (return_call_ref $none_=>_none
409409
;; CHECK-NODEBUG-NEXT: (ref.func $0)
410410
;; CHECK-NODEBUG-NEXT: )
411411
;; CHECK-NODEBUG-NEXT: )

0 commit comments

Comments
 (0)