Skip to content

Commit 14a3b12

Browse files
Kenoalexcrichton
authored andcommitted
[WebAssembly] Fix conflict between ret legalization and sjlj
Summary: When the WebAssembly backend encounters a return type that doesn't fit within i32, SelectionDAG performs sret demotion, adding an additional argument to the start of the function that contains a pointer to an sret buffer to use instead. However, this conflicts with the emscripten sjlj lowering pass. There we translate calls like: ``` call {i32, i32} @foo() ``` into (in pseudo-llvm) ``` %addr = @foo call {i32, i32} @__invoke_{i32,i32}(%addr) ``` i.e. we perform an indirect call through an extra function. However, the sret transform now transforms this into the equivalent of ``` %addr = @foo %sret = alloca {i32, i32} call {i32, i32} @__invoke_{i32,i32}(%sret, %addr) ``` (while simultaneously translation the implementation of @foo as well). Unfortunately, this doesn't work out. The __invoke_ ABI expected the function address to be the first argument, causing crashes. There is several possible ways to fix this: 1. Implementing the sret rewrite at the IR level as well and performing it as part of lowering to __invoke 2. Fixing the wasm backend to recognize that __invoke has a special ABI 3. A change to the binaryen/emscripten ABI to recognize this situation This revision implements the middle option, teaching the backend to treat __invoke_ functions specially in sret lowering. This is achieved by 1) Introducing a new CallingConv ID for invoke functions 2) When this CallingConv ID is seen in the backend and the first argument is marked as sret (a function pointer would never be marked as sret), swapping the first two arguments. Reviewed By: tlively, aheejin Differential Revision: https://reviews.llvm.org/D65463 llvm-svn: 367935
1 parent 8473db5 commit 14a3b12

File tree

7 files changed

+55
-9
lines changed

7 files changed

+55
-9
lines changed

llvm/include/llvm/IR/CallingConv.h

+8
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,14 @@ namespace CallingConv {
222222
// Calling convention between AArch64 Advanced SIMD functions
223223
AArch64_VectorCall = 97,
224224

225+
/// Calling convention between AArch64 SVE functions
226+
AArch64_SVE_VectorCall = 98,
227+
228+
/// Calling convention for emscripten __invoke_* functions. The first
229+
/// argument is required to be the function ptr being indirectly called.
230+
/// The remainder matches the regular calling convention.
231+
WASM_EmscriptenInvoke = 99,
232+
225233
/// The highest possible calling convention ID. Must be some 2^k - 1.
226234
MaxID = 1023
227235
};

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,8 @@ static bool callingConvSupported(CallingConv::ID CallConv) {
613613
CallConv == CallingConv::Cold ||
614614
CallConv == CallingConv::PreserveMost ||
615615
CallConv == CallingConv::PreserveAll ||
616-
CallConv == CallingConv::CXX_FAST_TLS;
616+
CallConv == CallingConv::CXX_FAST_TLS ||
617+
CallConv == CallingConv::WASM_EmscriptenInvoke;
617618
}
618619

619620
SDValue
@@ -649,6 +650,16 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
649650

650651
SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
651652
SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
653+
654+
// The generic code may have added an sret argument. If we're lowering an
655+
// invoke function, the ABI requires that the function pointer be the first
656+
// argument, so we may have to swap the arguments.
657+
if (CallConv == CallingConv::WASM_EmscriptenInvoke && Outs.size() >= 2 &&
658+
Outs[0].Flags.isSRet()) {
659+
std::swap(Outs[0], Outs[1]);
660+
std::swap(OutVals[0], OutVals[1]);
661+
}
662+
652663
unsigned NumFixedArgs = 0;
653664
for (unsigned I = 0; I < Outs.size(); ++I) {
654665
const ISD::OutputArg &Out = Outs[I];

llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ Value *WebAssemblyLowerEmscriptenEHSjLj::wrapInvoke(CallOrInvoke *CI) {
418418
Args.append(CI->arg_begin(), CI->arg_end());
419419
CallInst *NewCall = IRB.CreateCall(getInvokeWrapper(CI), Args);
420420
NewCall->takeName(CI);
421-
NewCall->setCallingConv(CI->getCallingConv());
421+
NewCall->setCallingConv(CallingConv::WASM_EmscriptenInvoke);
422422
NewCall->setDebugLoc(CI->getDebugLoc());
423423

424424
// Because we added the pointer to the callee as first argument, all

llvm/test/CodeGen/WebAssembly/lower-em-exceptions-whitelist.ll

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ entry:
3838
to label %invoke.cont unwind label %lpad
3939
; CHECK: entry:
4040
; CHECK-NEXT: store i32 0, i32*
41-
; CHECK-NEXT: call void @__invoke_void(void ()* @foo)
41+
; CHECK-NEXT: call cc{{.*}} void @__invoke_void(void ()* @foo)
4242

4343
invoke.cont: ; preds = %entry
4444
br label %try.cont

llvm/test/CodeGen/WebAssembly/lower-em-exceptions.ll

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ entry:
1616
to label %invoke.cont unwind label %lpad
1717
; CHECK: entry:
1818
; CHECK-NEXT: store i32 0, i32* @__THREW__
19-
; CHECK-NEXT: call void @__invoke_void_i32(void (i32)* @foo, i32 3)
19+
; CHECK-NEXT: call cc{{.*}} void @__invoke_void_i32(void (i32)* @foo, i32 3)
2020
; CHECK-NEXT: %[[__THREW__VAL:.*]] = load i32, i32* @__THREW__
2121
; CHECK-NEXT: store i32 0, i32* @__THREW__
2222
; CHECK-NEXT: %cmp = icmp eq i32 %[[__THREW__VAL]], 1
@@ -72,7 +72,7 @@ entry:
7272
to label %invoke.cont unwind label %lpad
7373
; CHECK: entry:
7474
; CHECK-NEXT: store i32 0, i32* @__THREW__
75-
; CHECK-NEXT: call void @__invoke_void_i32(void (i32)* @foo, i32 3)
75+
; CHECK-NEXT: call cc{{.*}} void @__invoke_void_i32(void (i32)* @foo, i32 3)
7676
; CHECK-NEXT: %[[__THREW__VAL:.*]] = load i32, i32* @__THREW__
7777
; CHECK-NEXT: store i32 0, i32* @__THREW__
7878
; CHECK-NEXT: %cmp = icmp eq i32 %[[__THREW__VAL]], 1
@@ -123,7 +123,7 @@ entry:
123123
to label %invoke.cont unwind label %lpad
124124
; CHECK: entry:
125125
; CHECK-NEXT: store i32 0, i32* @__THREW__
126-
; CHECK-NEXT: %0 = call noalias i8* @"__invoke_i8*_i8_i8"(i8* (i8, i8)* @bar, i8 signext 1, i8 zeroext 2)
126+
; CHECK-NEXT: %0 = call cc{{.*}} noalias i8* @"__invoke_i8*_i8_i8"(i8* (i8, i8)* @bar, i8 signext 1, i8 zeroext 2)
127127

128128
invoke.cont: ; preds = %entry
129129
br label %try.cont
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; RUN: llc < %s -asm-verbose=false -enable-emscripten-sjlj -wasm-keep-registers | FileCheck %s
2+
3+
target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
4+
target triple = "wasm32-unknown-unknown"
5+
6+
%struct.__jmp_buf_tag = type { [6 x i32], i32, [32 x i32] }
7+
8+
declare i32 @setjmp(%struct.__jmp_buf_tag*) #0
9+
declare {i32, i32} @returns_struct()
10+
11+
; Test the combination of backend legalization of large return types and the
12+
; Emscripten sjlj transformation
13+
define {i32, i32} @legalized_to_sret() {
14+
entry:
15+
%env = alloca [1 x %struct.__jmp_buf_tag], align 16
16+
%arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %env, i32 0, i32 0
17+
%call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0
18+
; This is the function pointer to pass to invoke.
19+
; It needs to be the first argument (that's what we're testing here)
20+
; CHECK: i32.const $push[[FPTR:[0-9]+]]=, returns_struct
21+
; This is the sret stack region (as an offset from the stack pointer local)
22+
; CHECK: call "__invoke_{i32.i32}", $pop[[FPTR]]
23+
%ret = call {i32, i32} @returns_struct()
24+
ret {i32, i32} %ret
25+
}
26+
27+
attributes #0 = { returns_twice }

llvm/test/CodeGen/WebAssembly/lower-em-sjlj.ll

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ entry:
3434
; CHECK-NEXT: phi i32 [ 0, %entry ], [ %[[LONGJMP_RESULT:.*]], %if.end ]
3535
; CHECK-NEXT: %[[ARRAYDECAY1:.*]] = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %[[BUF]], i32 0, i32 0
3636
; CHECK-NEXT: store i32 0, i32* @__THREW__
37-
; CHECK-NEXT: call void @"__invoke_void_%struct.__jmp_buf_tag*_i32"(void (%struct.__jmp_buf_tag*, i32)* @emscripten_longjmp_jmpbuf, %struct.__jmp_buf_tag* %[[ARRAYDECAY1]], i32 1)
37+
; CHECK-NEXT: call cc{{.*}} void @"__invoke_void_%struct.__jmp_buf_tag*_i32"(void (%struct.__jmp_buf_tag*, i32)* @emscripten_longjmp_jmpbuf, %struct.__jmp_buf_tag* %[[ARRAYDECAY1]], i32 1)
3838
; CHECK-NEXT: %[[__THREW__VAL:.*]] = load i32, i32* @__THREW__
3939
; CHECK-NEXT: store i32 0, i32* @__THREW__
4040
; CHECK-NEXT: %[[CMP0:.*]] = icmp ne i32 %__THREW__.val, 0
@@ -85,7 +85,7 @@ entry:
8585
; CHECK: %[[SETJMP_TABLE:.*]] = call i32* @saveSetjmp(
8686

8787
; CHECK: entry.split:
88-
; CHECK: call void @__invoke_void(void ()* @foo)
88+
; CHECK: @__invoke_void(void ()* @foo)
8989

9090
; CHECK: entry.split.split:
9191
; CHECK-NEXT: %[[BUF:.*]] = bitcast i32* %[[SETJMP_TABLE]] to i8*
@@ -105,7 +105,7 @@ entry:
105105

106106
; CHECK: entry.split:
107107
; CHECK: store i32 0, i32* @__THREW__
108-
; CHECK-NEXT: call void @__invoke_void(void ()* @foo)
108+
; CHECK-NEXT: call cc{{.*}} void @__invoke_void(void ()* @foo)
109109
; CHECK-NEXT: %[[__THREW__VAL:.*]] = load i32, i32* @__THREW__
110110
; CHECK-NEXT: store i32 0, i32* @__THREW__
111111
; CHECK-NEXT: %[[CMP0:.*]] = icmp ne i32 %[[__THREW__VAL]], 0

0 commit comments

Comments
 (0)