Skip to content

Commit cbfdca3

Browse files
committed
[NVPTX] Fix internal indirect call prototypes not obeying the ABI
Summary: The NVPTX backend optimizes the ABI for functions that are internal, however, this is not legal for indirect call prototypes. Previously, we would modify the ABI on an aggregate byval type passed to an indirect call prototype, which would make PTXAS error. This patch just passes the function as a nullptr to force strict ABI compliance without modification in the helper function. Fixes #100055
1 parent c747300 commit cbfdca3

File tree

3 files changed

+100
-13
lines changed

3 files changed

+100
-13
lines changed

libc/config/gpu/entrypoints.txt

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
if(LIBC_TARGET_ARCHITECTURE_IS_AMDGPU)
2-
set(extra_entrypoints
3-
# stdio.h entrypoints
4-
libc.src.stdio.snprintf
5-
libc.src.stdio.sprintf
6-
libc.src.stdio.vsnprintf
7-
libc.src.stdio.vsprintf
8-
)
9-
endif()
10-
111
set(TARGET_LIBC_ENTRYPOINTS
122
# assert.h entrypoints
133
libc.src.assert.__assert_fail
@@ -186,13 +176,16 @@ set(TARGET_LIBC_ENTRYPOINTS
186176
libc.src.errno.errno
187177

188178
# stdio.h entrypoints
189-
${extra_entrypoints}
190179
libc.src.stdio.clearerr
191180
libc.src.stdio.fclose
192181
libc.src.stdio.printf
193182
libc.src.stdio.vprintf
194183
libc.src.stdio.fprintf
195184
libc.src.stdio.vfprintf
185+
libc.src.stdio.snprintf
186+
libc.src.stdio.sprintf
187+
libc.src.stdio.vsnprintf
188+
libc.src.stdio.vsprintf
196189
libc.src.stdio.feof
197190
libc.src.stdio.ferror
198191
libc.src.stdio.fflush

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,7 +1429,6 @@ std::string NVPTXTargetLowering::getPrototype(
14291429

14301430
bool first = true;
14311431

1432-
const Function *F = CB.getFunction();
14331432
unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
14341433
for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
14351434
Type *Ty = Args[i].Ty;
@@ -1471,10 +1470,11 @@ std::string NVPTXTargetLowering::getPrototype(
14711470
continue;
14721471
}
14731472

1473+
// Indirect calls need strict ABI alignment so we disable optimizations.
14741474
Type *ETy = Args[i].IndirectType;
14751475
Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
14761476
Align ParamByValAlign =
1477-
getFunctionByValParamAlign(F, ETy, InitialAlign, DL);
1477+
getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
14781478

14791479
O << ".param .align " << ParamByValAlign.value() << " .b8 ";
14801480
O << "_";
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_52 -mattr=+ptx64 | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_52 -mattr=+ptx64 | %ptxas-verify %}
4+
5+
target triple = "nvptx64-nvidia-cuda"
6+
7+
%struct.S = type { i8 }
8+
%struct.U = type { i64 }
9+
10+
@ptr = external global ptr, align 8
11+
12+
define internal i32 @foo() {
13+
; CHECK-LABEL: foo(
14+
; CHECK: {
15+
; CHECK-NEXT: .local .align 1 .b8 __local_depot0[2];
16+
; CHECK-NEXT: .reg .b64 %SP;
17+
; CHECK-NEXT: .reg .b64 %SPL;
18+
; CHECK-NEXT: .reg .b16 %rs<2>;
19+
; CHECK-NEXT: .reg .b32 %r<3>;
20+
; CHECK-NEXT: .reg .b64 %rd<3>;
21+
; CHECK-EMPTY:
22+
; CHECK-NEXT: // %bb.0: // %entry
23+
; CHECK-NEXT: mov.u64 %SPL, __local_depot0;
24+
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
25+
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
26+
; CHECK-NEXT: ld.u8 %rs1, [%SP+1];
27+
; CHECK-NEXT: add.u64 %rd2, %SP, 0;
28+
; CHECK-NEXT: { // callseq 0, 0
29+
; CHECK-NEXT: .param .align 1 .b8 param0[1];
30+
; CHECK-NEXT: st.param.b8 [param0+0], %rs1;
31+
; CHECK-NEXT: .param .b64 param1;
32+
; CHECK-NEXT: st.param.b64 [param1+0], %rd2;
33+
; CHECK-NEXT: .param .b32 retval0;
34+
; CHECK-NEXT: prototype_0 : .callprototype (.param .b32 _) _ (.param .align 1 .b8 _[1], .param .b64 _);
35+
; CHECK-NEXT: call (retval0),
36+
; CHECK-NEXT: %rd1,
37+
; CHECK-NEXT: (
38+
; CHECK-NEXT: param0,
39+
; CHECK-NEXT: param1
40+
; CHECK-NEXT: )
41+
; CHECK-NEXT: , prototype_0;
42+
; CHECK-NEXT: ld.param.b32 %r1, [retval0+0];
43+
; CHECK-NEXT: } // callseq 0
44+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
45+
; CHECK-NEXT: ret;
46+
entry:
47+
%s = alloca %struct.S, align 1
48+
%agg.tmp = alloca %struct.S, align 1
49+
%0 = load ptr, ptr @ptr, align 8
50+
%call = call i32 %0(ptr byval(%struct.S) align 1 %agg.tmp, ptr noundef %s)
51+
ret i32 %call
52+
}
53+
54+
define internal i32 @bar() {
55+
; CHECK-LABEL: bar(
56+
; CHECK: // @bar
57+
; CHECK-NEXT: {
58+
; CHECK-NEXT: .local .align 8 .b8 __local_depot1[16];
59+
; CHECK-NEXT: .reg .b64 %SP;
60+
; CHECK-NEXT: .reg .b64 %SPL;
61+
; CHECK-NEXT: .reg .b32 %r<3>;
62+
; CHECK-NEXT: .reg .b64 %rd<4>;
63+
; CHECK-EMPTY:
64+
; CHECK-NEXT: // %bb.0: // %entry
65+
; CHECK-NEXT: mov.u64 %SPL, __local_depot1;
66+
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
67+
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
68+
; CHECK-NEXT: ld.u64 %rd2, [%SP+8];
69+
; CHECK-NEXT: add.u64 %rd3, %SP, 0;
70+
; CHECK-NEXT: { // callseq 1, 0
71+
; CHECK-NEXT: .param .align 8 .b8 param0[8];
72+
; CHECK-NEXT: st.param.b64 [param0+0], %rd2;
73+
; CHECK-NEXT: .param .b64 param1;
74+
; CHECK-NEXT: st.param.b64 [param1+0], %rd3;
75+
; CHECK-NEXT: .param .b32 retval0;
76+
; CHECK-NEXT: prototype_1 : .callprototype (.param .b32 _) _ (.param .align 8 .b8 _[8], .param .b64 _);
77+
; CHECK-NEXT: call (retval0),
78+
; CHECK-NEXT: %rd1,
79+
; CHECK-NEXT: (
80+
; CHECK-NEXT: param0,
81+
; CHECK-NEXT: param1
82+
; CHECK-NEXT: )
83+
; CHECK-NEXT: , prototype_1;
84+
; CHECK-NEXT: ld.param.b32 %r1, [retval0+0];
85+
; CHECK-NEXT: } // callseq 1
86+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
87+
; CHECK-NEXT: ret;
88+
entry:
89+
%s = alloca %struct.U, align 8
90+
%agg.tmp = alloca %struct.U, align 8
91+
%0 = load ptr, ptr @ptr, align 8
92+
%call = call noundef i32 %0(ptr byval(%struct.U) align 8 %agg.tmp, ptr %s)
93+
ret i32 %call
94+
}

0 commit comments

Comments
 (0)