Skip to content

Commit 283b9ef

Browse files
author
anikelal
committed
[Clang][OpenCL][AMDGPU] Allow a kernel to call another kernel
This feature is currently not supported in the compiler. To facilitate this, we emit a stub version of each kernel function body with a different name-mangling scheme and replace the respective kernel call sites appropriately. Fixes #60313. D120566 was an earlier attempt to upstream a solution for this issue.
1 parent 08a3c53 commit 283b9ef

14 files changed

+176
-40
lines changed

clang/include/clang/AST/GlobalDecl.h

+26-11
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,19 @@ class GlobalDecl {
7171
GlobalDecl(const FunctionDecl *D, unsigned MVIndex = 0)
7272
: MultiVersionIndex(MVIndex) {
7373
if (!D->hasAttr<CUDAGlobalAttr>()) {
74+
if (D->hasAttr<OpenCLKernelAttr>()) {
75+
Value.setPointerAndInt(D, unsigned(KernelReferenceKind::Kernel));
76+
return;
77+
}
7478
Init(D);
7579
return;
7680
}
7781
Value.setPointerAndInt(D, unsigned(getDefaultKernelReference(D)));
7882
}
7983
GlobalDecl(const FunctionDecl *D, KernelReferenceKind Kind)
8084
: Value(D, unsigned(Kind)) {
81-
assert(D->hasAttr<CUDAGlobalAttr>() && "Decl is not a GPU kernel!");
85+
assert((D->hasAttr<CUDAGlobalAttr>() && "Decl is not a GPU kernel!") ||
86+
(D->hasAttr<OpenCLKernelAttr>() && "Decl is not a OpenCL kernel!"));
8287
}
8388
GlobalDecl(const NamedDecl *D) { Init(D); }
8489
GlobalDecl(const BlockDecl *D) { Init(D); }
@@ -130,13 +135,15 @@ class GlobalDecl {
130135
}
131136

132137
KernelReferenceKind getKernelReferenceKind() const {
133-
assert(((isa<FunctionDecl>(getDecl()) &&
134-
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>()) ||
135-
(isa<FunctionTemplateDecl>(getDecl()) &&
136-
cast<FunctionTemplateDecl>(getDecl())
137-
->getTemplatedDecl()
138-
->hasAttr<CUDAGlobalAttr>())) &&
139-
"Decl is not a GPU kernel!");
138+
assert((((isa<FunctionDecl>(getDecl()) &&
139+
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>()) ||
140+
(isa<FunctionTemplateDecl>(getDecl()) &&
141+
cast<FunctionTemplateDecl>(getDecl())
142+
->getTemplatedDecl()
143+
->hasAttr<CUDAGlobalAttr>())) &&
144+
"Decl is not a GPU kernel!") ||
145+
(isDeclOpenCLKernel() && "Decl is not a OpenCL kernel!"));
146+
140147
return static_cast<KernelReferenceKind>(Value.getInt());
141148
}
142149

@@ -196,13 +203,21 @@ class GlobalDecl {
196203
}
197204

198205
GlobalDecl getWithKernelReferenceKind(KernelReferenceKind Kind) {
199-
assert(isa<FunctionDecl>(getDecl()) &&
200-
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>() &&
201-
"Decl is not a GPU kernel!");
206+
assert((isa<FunctionDecl>(getDecl()) &&
207+
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>() &&
208+
"Decl is not a GPU kernel!") ||
209+
(isDeclOpenCLKernel() && "Decl is not a OpenCL kernel!"));
202210
GlobalDecl Result(*this);
203211
Result.Value.setInt(unsigned(Kind));
204212
return Result;
205213
}
214+
215+
bool isDeclOpenCLKernel() const {
216+
auto FD = dyn_cast<FunctionDecl>(getDecl());
217+
if (FD)
218+
return FD->hasAttr<OpenCLKernelAttr>();
219+
return FD;
220+
}
206221
};
207222

208223
} // namespace clang

clang/lib/AST/Expr.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,8 @@ std::string PredefinedExpr::ComputeName(PredefinedIdentKind IK,
692692
GD = GlobalDecl(CD, Ctor_Base);
693693
else if (const CXXDestructorDecl *DD = dyn_cast<CXXDestructorDecl>(ND))
694694
GD = GlobalDecl(DD, Dtor_Base);
695-
else if (ND->hasAttr<CUDAGlobalAttr>())
695+
else if (ND->hasAttr<CUDAGlobalAttr>() ||
696+
ND->hasAttr<OpenCLKernelAttr>())
696697
GD = GlobalDecl(cast<FunctionDecl>(ND));
697698
else
698699
GD = GlobalDecl(ND);

clang/lib/AST/ItaniumMangle.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ class CXXNameMangler {
526526
void mangleSourceName(const IdentifierInfo *II);
527527
void mangleRegCallName(const IdentifierInfo *II);
528528
void mangleDeviceStubName(const IdentifierInfo *II);
529+
void mangleOCLDeviceStubName(const IdentifierInfo *II);
529530
void mangleSourceNameWithAbiTags(
530531
const NamedDecl *ND, const AbiTagList *AdditionalAbiTags = nullptr);
531532
void mangleLocalName(GlobalDecl GD,
@@ -1562,8 +1563,13 @@ void CXXNameMangler::mangleUnqualifiedName(
15621563
bool IsDeviceStub =
15631564
FD && FD->hasAttr<CUDAGlobalAttr>() &&
15641565
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
1566+
bool IsOCLDeviceStub =
1567+
FD && FD->hasAttr<OpenCLKernelAttr>() &&
1568+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
15651569
if (IsDeviceStub)
15661570
mangleDeviceStubName(II);
1571+
else if (IsOCLDeviceStub)
1572+
mangleOCLDeviceStubName(II);
15671573
else if (IsRegCall)
15681574
mangleRegCallName(II);
15691575
else
@@ -1781,6 +1787,15 @@ void CXXNameMangler::mangleDeviceStubName(const IdentifierInfo *II) {
17811787
<< II->getName();
17821788
}
17831789

1790+
void CXXNameMangler::mangleOCLDeviceStubName(const IdentifierInfo *II) {
1791+
// <source-name> ::= <positive length number> __clang_ocl_kern_imp_
1792+
// <identifier> <number> ::= [n] <non-negative decimal integer> <identifier>
1793+
// ::= <unqualified source code identifier>
1794+
StringRef OCLDeviceStubNamePrefix = "__clang_ocl_kern_imp_";
1795+
Out << II->getLength() + OCLDeviceStubNamePrefix.size() - 1
1796+
<< OCLDeviceStubNamePrefix << II->getName();
1797+
}
1798+
17841799
void CXXNameMangler::mangleSourceName(const IdentifierInfo *II) {
17851800
// <source-name> ::= <positive length number> <identifier>
17861801
// <number> ::= [n] <non-negative decimal integer>

clang/lib/AST/Mangle.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ class ASTNameGenerator::Implementation {
540540
GD = GlobalDecl(CtorD, Ctor_Complete);
541541
else if (const auto *DtorD = dyn_cast<CXXDestructorDecl>(D))
542542
GD = GlobalDecl(DtorD, Dtor_Complete);
543-
else if (D->hasAttr<CUDAGlobalAttr>())
543+
else if (D->hasAttr<CUDAGlobalAttr>() || D->hasAttr<OpenCLKernelAttr>())
544544
GD = GlobalDecl(cast<FunctionDecl>(D));
545545
else
546546
GD = GlobalDecl(D);

clang/lib/AST/MicrosoftMangle.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -1162,9 +1162,15 @@ void MicrosoftCXXNameMangler::mangleUnqualifiedName(GlobalDecl GD,
11621162
->getTemplatedDecl()
11631163
->hasAttr<CUDAGlobalAttr>())) &&
11641164
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
1165+
bool IsOCLDeviceStub =
1166+
ND && (isa<FunctionDecl>(ND) && ND->hasAttr<OpenCLKernelAttr>()) &&
1167+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
11651168
if (IsDeviceStub)
11661169
mangleSourceName(
11671170
(llvm::Twine("__device_stub__") + II->getName()).str());
1171+
else if (IsOCLDeviceStub)
1172+
mangleSourceName(
1173+
(llvm::Twine("__clang_ocl_kern_imp_") + II->getName()).str());
11681174
else
11691175
mangleSourceName(II->getName());
11701176
break;

clang/lib/CodeGen/CGBlocks.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ CGBlockInfo::CGBlockInfo(const BlockDecl *block, StringRef name)
4848
BlockByrefHelpers::~BlockByrefHelpers() {}
4949

5050
/// Build the given block as a global block.
51-
static llvm::Constant *buildGlobalBlock(CodeGenModule &CGM,
51+
static llvm::Constant *buildGlobalBlock(CodeGenModule &CGM, GlobalDecl GD,
5252
const CGBlockInfo &blockInfo,
5353
llvm::Constant *blockFn);
5454

@@ -1085,8 +1085,10 @@ llvm::Value *CodeGenFunction::EmitBlockLiteral(const CGBlockInfo &blockInfo) {
10851085
blockAddr.getPointer(), ConvertType(blockInfo.getBlockExpr()->getType()));
10861086

10871087
if (IsOpenCL) {
1088-
CGM.getOpenCLRuntime().recordBlockInfo(blockInfo.BlockExpression, InvokeFn,
1089-
result, blockInfo.StructureType);
1088+
CGM.getOpenCLRuntime().recordBlockInfo(
1089+
blockInfo.BlockExpression, InvokeFn, result, blockInfo.StructureType,
1090+
CurGD && CurGD.isDeclOpenCLKernel() &&
1091+
(CurGD.getKernelReferenceKind() == KernelReferenceKind::Kernel));
10901092
}
10911093

10921094
return result;
@@ -1264,7 +1266,7 @@ CodeGenModule::GetAddrOfGlobalBlock(const BlockExpr *BE,
12641266
return getAddrOfGlobalBlockIfEmitted(BE);
12651267
}
12661268

1267-
static llvm::Constant *buildGlobalBlock(CodeGenModule &CGM,
1269+
static llvm::Constant *buildGlobalBlock(CodeGenModule &CGM, GlobalDecl GD,
12681270
const CGBlockInfo &blockInfo,
12691271
llvm::Constant *blockFn) {
12701272
assert(blockInfo.CanBeGlobal);
@@ -1357,7 +1359,9 @@ static llvm::Constant *buildGlobalBlock(CodeGenModule &CGM,
13571359
CGM.getOpenCLRuntime().recordBlockInfo(
13581360
blockInfo.BlockExpression,
13591361
cast<llvm::Function>(blockFn->stripPointerCasts()), Result,
1360-
literal->getValueType());
1362+
literal->getValueType(),
1363+
GD && GD.isDeclOpenCLKernel() &&
1364+
(GD.getKernelReferenceKind() == KernelReferenceKind::Kernel));
13611365
return Result;
13621366
}
13631367

@@ -1466,7 +1470,7 @@ llvm::Function *CodeGenFunction::GenerateBlockFunction(
14661470
auto GenVoidPtrTy = getContext().getLangOpts().OpenCL
14671471
? CGM.getOpenCLRuntime().getGenericVoidPointerType()
14681472
: VoidPtrTy;
1469-
buildGlobalBlock(CGM, blockInfo,
1473+
buildGlobalBlock(CGM, CurGD, blockInfo,
14701474
llvm::ConstantExpr::getPointerCast(fn, GenVoidPtrTy));
14711475
}
14721476

clang/lib/CodeGen/CGCall.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -2382,6 +2382,15 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
23822382
// Collect function IR attributes from the CC lowering.
23832383
// We'll collect the parameter and result attributes later.
23842384
CallingConv = FI.getEffectiveCallingConvention();
2385+
GlobalDecl GD = CalleeInfo.getCalleeDecl();
2386+
const Decl *TargetDecl = CalleeInfo.getCalleeDecl().getDecl();
2387+
if (TargetDecl) {
2388+
if (auto FD = dyn_cast<FunctionDecl>(TargetDecl)) {
2389+
if (FD->hasAttr<OpenCLKernelAttr>() &&
2390+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub)
2391+
CallingConv = llvm::CallingConv::C;
2392+
}
2393+
}
23852394
if (FI.isNoReturn())
23862395
FuncAttrs.addAttribute(llvm::Attribute::NoReturn);
23872396
if (FI.isCmseNSCall())
@@ -2391,8 +2400,6 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
23912400
AddAttributesFromFunctionProtoType(getContext(), FuncAttrs,
23922401
CalleeInfo.getCalleeFunctionProtoType());
23932402

2394-
const Decl *TargetDecl = CalleeInfo.getCalleeDecl().getDecl();
2395-
23962403
// Attach assumption attributes to the declaration. If this is a call
23972404
// site, attach assumptions from the caller to the call as well.
23982405
AddAttributesFromOMPAssumes(FuncAttrs, TargetDecl);

clang/lib/CodeGen/CGExpr.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -5758,7 +5758,10 @@ CGCallee CodeGenFunction::EmitCallee(const Expr *E) {
57585758
// Resolve direct calls.
57595759
} else if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
57605760
if (auto FD = dyn_cast<FunctionDecl>(DRE->getDecl())) {
5761-
return EmitDirectCallee(*this, FD);
5761+
auto CalleeDecl = FD->hasAttr<OpenCLKernelAttr>()
5762+
? GlobalDecl(FD, KernelReferenceKind::Stub)
5763+
: FD;
5764+
return EmitDirectCallee(*this, CalleeDecl);
57625765
}
57635766
} else if (auto ME = dyn_cast<MemberExpr>(E)) {
57645767
if (auto FD = dyn_cast<FunctionDecl>(ME->getMemberDecl())) {

clang/lib/CodeGen/CGOpenCLRuntime.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,14 @@ static const BlockExpr *getBlockExpr(const Expr *E) {
126126
/// corresponding block expression.
127127
void CGOpenCLRuntime::recordBlockInfo(const BlockExpr *E,
128128
llvm::Function *InvokeF,
129-
llvm::Value *Block, llvm::Type *BlockTy) {
130-
assert(!EnqueuedBlockMap.contains(E) && "Block expression emitted twice");
129+
llvm::Value *Block, llvm::Type *BlockTy,
130+
bool isBlkExprInOCLKern) {
131+
132+
// FIXME: Since OpenCL Kernels are emitted twice (kernel version and stub
133+
// version), its constituent BlockExpr will also be emitted twice.
134+
assert((!EnqueuedBlockMap.contains(E) ||
135+
EnqueuedBlockMap[E].isBlkExprInOCLKern != isBlkExprInOCLKern) &&
136+
"Block expression emitted twice");
131137
assert(isa<llvm::Function>(InvokeF) && "Invalid invoke function");
132138
assert(Block->getType()->isPointerTy() && "Invalid block literal type");
133139
EnqueuedBlockInfo &BlockInfo = EnqueuedBlockMap[E];

clang/lib/CodeGen/CGOpenCLRuntime.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class CGOpenCLRuntime {
4646
llvm::Value *KernelHandle; /// Enqueued block kernel reference.
4747
llvm::Value *BlockArg; /// The first argument to enqueued block kernel.
4848
llvm::Type *BlockTy; /// Type of the block argument.
49+
bool isBlkExprInOCLKern; /// Does the BlockExpr reside in an OpenCL Kernel.
4950
};
5051
/// Maps block expression to block information.
5152
llvm::DenseMap<const Expr *, EnqueuedBlockInfo> EnqueuedBlockMap;
@@ -93,7 +94,8 @@ class CGOpenCLRuntime {
9394
/// \param InvokeF invoke function emitted for the block expression.
9495
/// \param Block block literal emitted for the block expression.
9596
void recordBlockInfo(const BlockExpr *E, llvm::Function *InvokeF,
96-
llvm::Value *Block, llvm::Type *BlockTy);
97+
llvm::Value *Block, llvm::Type *BlockTy,
98+
bool isBlkExprInOCLKern);
9799

98100
/// \return LLVM block invoke function emitted for an expression derived from
99101
/// the block expression.

clang/lib/CodeGen/CodeGenModule.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -1904,6 +1904,9 @@ static std::string getMangledNameImpl(CodeGenModule &CGM, GlobalDecl GD,
19041904
} else if (FD && FD->hasAttr<CUDAGlobalAttr>() &&
19051905
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
19061906
Out << "__device_stub__" << II->getName();
1907+
} else if (FD && FD->hasAttr<OpenCLKernelAttr>() &&
1908+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
1909+
Out << "__clang_ocl_kern_imp_" << II->getName();
19071910
} else {
19081911
Out << II->getName();
19091912
}
@@ -3892,6 +3895,10 @@ void CodeGenModule::EmitGlobal(GlobalDecl GD) {
38923895

38933896
// Ignore declarations, they will be emitted on their first use.
38943897
if (const auto *FD = dyn_cast<FunctionDecl>(Global)) {
3898+
3899+
if (FD->hasAttr<OpenCLKernelAttr>() && FD->doesThisDeclarationHaveABody())
3900+
addDeferredDeclToEmit(GlobalDecl(FD, KernelReferenceKind::Stub));
3901+
38953902
// Update deferred annotations with the latest declaration if the function
38963903
// was already used or defined.
38973904
if (FD->hasAttr<AnnotateAttr>()) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: %clang_cc1 -triple amdgcn-unknown-unknown -emit-llvm -o - %s | FileCheck %s
2+
3+
// CHECK: define dso_local amdgpu_kernel void @callee_kern({{.*}})
4+
__attribute__((noinline)) kernel void callee_kern(global int *A){
5+
*A = 1;
6+
}
7+
8+
__attribute__((noinline)) kernel void ext_callee_kern(global int *A);
9+
10+
// CHECK: define dso_local void @callee_func({{.*}})
11+
__attribute__((noinline)) void callee_func(global int *A){
12+
*A = 2;
13+
}
14+
15+
// CHECK: define dso_local amdgpu_kernel void @caller_kern({{.*}})
16+
kernel void caller_kern(global int* A){
17+
callee_kern(A);
18+
// CHECK: tail call void @__clang_ocl_kern_imp_callee_kern({{.*}})
19+
ext_callee_kern(A);
20+
// CHECK: tail call void @__clang_ocl_kern_imp_ext_callee_kern({{.*}})
21+
callee_func(A);
22+
// CHECK: tail call void @callee_func({{.*}})
23+
24+
}
25+
26+
// CHECK: define dso_local void @__clang_ocl_kern_imp_callee_kern({{.*}})
27+
28+
// CHECK: declare void @__clang_ocl_kern_imp_ext_callee_kern({{.*}})
29+
30+
// CHECK: define dso_local void @caller_func({{.*}})
31+
void caller_func(global int* A){
32+
callee_kern(A);
33+
// CHECK: tail call void @__clang_ocl_kern_imp_callee_kern({{.*}}) #7
34+
ext_callee_kern(A);
35+
// CHECK: tail call void @__clang_ocl_kern_imp_ext_callee_kern({{.*}}) #8
36+
callee_func(A);
37+
// CHECK: tail call void @callee_func({{.*}})
38+
}
39+
40+
// CHECK: define dso_local void @__clang_ocl_kern_imp_caller_kern({{.*}})
41+
// CHECK: tail call void @__clang_ocl_kern_imp_callee_kern({{.*}})
42+
// CHECK: tail call void @__clang_ocl_kern_imp_ext_callee_kern({{.*}})
43+
// CHECK: tail call void @callee_func({{.*}})

clang/test/CodeGenOpenCL/spir-calling-conv.cl

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ kernel void foo(global int *A)
1111
// CHECK: %{{[a-z0-9_]+}} = tail call spir_func i32 @get_dummy_id(i32 noundef 0)
1212
A[id] = id;
1313
bar(A);
14-
// CHECK: tail call spir_kernel void @bar(ptr addrspace(1) noundef align 4 %A)
14+
// CHECK: tail call void @__clang_ocl_kern_imp_bar(ptr addrspace(1) noundef align 4 %A)
1515
}
1616

1717
// CHECK: declare spir_func i32 @get_dummy_id(i32 noundef)
18-
// CHECK: declare spir_kernel void @bar(ptr addrspace(1) noundef align 4)
18+
// CHECK: declare void @__clang_ocl_kern_imp_bar(ptr addrspace(1) noundef align 4)

0 commit comments

Comments
 (0)