Skip to content

Commit 642481a

Browse files
lalaniket8anikelal
and
anikelal
authored
[Clang][OpenCL][AMDGPU] Allow a kernel to call another kernel (#115821)
This feature is currently not supported in the compiler. To facilitate this we emit a stub version of each kernel function body with different name mangling scheme, and replaces the respective kernel call-sites appropriately. Fixes #60313 D120566 was an earlier attempt made to upstream a solution for this issue. --------- Co-authored-by: anikelal <[email protected]>
1 parent 65cede2 commit 642481a

33 files changed

+3375
-1375
lines changed

clang/include/clang/AST/Decl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3048,6 +3048,8 @@ class FunctionDecl : public DeclaratorDecl,
30483048
static FunctionDecl *castFromDeclContext(const DeclContext *DC) {
30493049
return static_cast<FunctionDecl *>(const_cast<DeclContext*>(DC));
30503050
}
3051+
3052+
bool isReferenceableKernel() const;
30513053
};
30523054

30533055
/// Represents a member of a struct/union/class.

clang/include/clang/AST/GlobalDecl.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ class GlobalDecl {
7070
GlobalDecl(const VarDecl *D) { Init(D);}
7171
GlobalDecl(const FunctionDecl *D, unsigned MVIndex = 0)
7272
: MultiVersionIndex(MVIndex) {
73-
if (!D->hasAttr<CUDAGlobalAttr>()) {
74-
Init(D);
73+
if (D->isReferenceableKernel()) {
74+
Value.setPointerAndInt(D, unsigned(getDefaultKernelReference(D)));
7575
return;
7676
}
77-
Value.setPointerAndInt(D, unsigned(getDefaultKernelReference(D)));
77+
Init(D);
7878
}
7979
GlobalDecl(const FunctionDecl *D, KernelReferenceKind Kind)
8080
: Value(D, unsigned(Kind)) {
81-
assert(D->hasAttr<CUDAGlobalAttr>() && "Decl is not a GPU kernel!");
81+
assert(D->isReferenceableKernel() && "Decl is not a GPU kernel!");
8282
}
8383
GlobalDecl(const NamedDecl *D) { Init(D); }
8484
GlobalDecl(const BlockDecl *D) { Init(D); }
@@ -131,12 +131,13 @@ class GlobalDecl {
131131

132132
KernelReferenceKind getKernelReferenceKind() const {
133133
assert(((isa<FunctionDecl>(getDecl()) &&
134-
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>()) ||
134+
cast<FunctionDecl>(getDecl())->isReferenceableKernel()) ||
135135
(isa<FunctionTemplateDecl>(getDecl()) &&
136136
cast<FunctionTemplateDecl>(getDecl())
137137
->getTemplatedDecl()
138138
->hasAttr<CUDAGlobalAttr>())) &&
139139
"Decl is not a GPU kernel!");
140+
140141
return static_cast<KernelReferenceKind>(Value.getInt());
141142
}
142143

@@ -160,8 +161,9 @@ class GlobalDecl {
160161
}
161162

162163
static KernelReferenceKind getDefaultKernelReference(const FunctionDecl *D) {
163-
return D->getLangOpts().CUDAIsDevice ? KernelReferenceKind::Kernel
164-
: KernelReferenceKind::Stub;
164+
return (D->hasAttr<OpenCLKernelAttr>() || D->getLangOpts().CUDAIsDevice)
165+
? KernelReferenceKind::Kernel
166+
: KernelReferenceKind::Stub;
165167
}
166168

167169
GlobalDecl getWithDecl(const Decl *D) {
@@ -197,7 +199,7 @@ class GlobalDecl {
197199

198200
GlobalDecl getWithKernelReferenceKind(KernelReferenceKind Kind) {
199201
assert(isa<FunctionDecl>(getDecl()) &&
200-
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>() &&
202+
cast<FunctionDecl>(getDecl())->isReferenceableKernel() &&
201203
"Decl is not a GPU kernel!");
202204
GlobalDecl Result(*this);
203205
Result.Value.setInt(unsigned(Kind));

clang/lib/AST/Decl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5468,6 +5468,10 @@ FunctionDecl *FunctionDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID) {
54685468
/*TrailingRequiresClause=*/{});
54695469
}
54705470

5471+
bool FunctionDecl::isReferenceableKernel() const {
5472+
return hasAttr<CUDAGlobalAttr>() || hasAttr<OpenCLKernelAttr>();
5473+
}
5474+
54715475
BlockDecl *BlockDecl::Create(ASTContext &C, DeclContext *DC, SourceLocation L) {
54725476
return new (C, DC) BlockDecl(DC, L);
54735477
}

clang/lib/AST/Expr.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -695,9 +695,9 @@ std::string PredefinedExpr::ComputeName(PredefinedIdentKind IK,
695695
GD = GlobalDecl(CD, Ctor_Base);
696696
else if (const CXXDestructorDecl *DD = dyn_cast<CXXDestructorDecl>(ND))
697697
GD = GlobalDecl(DD, Dtor_Base);
698-
else if (ND->hasAttr<CUDAGlobalAttr>())
699-
GD = GlobalDecl(cast<FunctionDecl>(ND));
700-
else
698+
else if (auto FD = dyn_cast<FunctionDecl>(ND)) {
699+
GD = FD->isReferenceableKernel() ? GlobalDecl(FD) : GlobalDecl(ND);
700+
} else
701701
GD = GlobalDecl(ND);
702702
MC->mangleName(GD, Out);
703703

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ class CXXNameMangler {
526526
void mangleSourceName(const IdentifierInfo *II);
527527
void mangleRegCallName(const IdentifierInfo *II);
528528
void mangleDeviceStubName(const IdentifierInfo *II);
529+
void mangleOCLDeviceStubName(const IdentifierInfo *II);
529530
void mangleSourceNameWithAbiTags(
530531
const NamedDecl *ND, const AbiTagList *AdditionalAbiTags = nullptr);
531532
void mangleLocalName(GlobalDecl GD,
@@ -1561,8 +1562,13 @@ void CXXNameMangler::mangleUnqualifiedName(
15611562
bool IsDeviceStub =
15621563
FD && FD->hasAttr<CUDAGlobalAttr>() &&
15631564
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
1565+
bool IsOCLDeviceStub =
1566+
FD && FD->hasAttr<OpenCLKernelAttr>() &&
1567+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
15641568
if (IsDeviceStub)
15651569
mangleDeviceStubName(II);
1570+
else if (IsOCLDeviceStub)
1571+
mangleOCLDeviceStubName(II);
15661572
else if (IsRegCall)
15671573
mangleRegCallName(II);
15681574
else
@@ -1780,6 +1786,15 @@ void CXXNameMangler::mangleDeviceStubName(const IdentifierInfo *II) {
17801786
<< II->getName();
17811787
}
17821788

1789+
void CXXNameMangler::mangleOCLDeviceStubName(const IdentifierInfo *II) {
1790+
// <source-name> ::= <positive length number> __clang_ocl_kern_imp_
1791+
// <identifier> <number> ::= [n] <non-negative decimal integer> <identifier>
1792+
// ::= <unqualified source code identifier>
1793+
StringRef OCLDeviceStubNamePrefix = "__clang_ocl_kern_imp_";
1794+
Out << II->getLength() + OCLDeviceStubNamePrefix.size()
1795+
<< OCLDeviceStubNamePrefix << II->getName();
1796+
}
1797+
17831798
void CXXNameMangler::mangleSourceName(const IdentifierInfo *II) {
17841799
// <source-name> ::= <positive length number> <identifier>
17851800
// <number> ::= [n] <non-negative decimal integer>

clang/lib/AST/Mangle.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,9 +540,9 @@ class ASTNameGenerator::Implementation {
540540
GD = GlobalDecl(CtorD, Ctor_Complete);
541541
else if (const auto *DtorD = dyn_cast<CXXDestructorDecl>(D))
542542
GD = GlobalDecl(DtorD, Dtor_Complete);
543-
else if (D->hasAttr<CUDAGlobalAttr>())
544-
GD = GlobalDecl(cast<FunctionDecl>(D));
545-
else
543+
else if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {
544+
GD = FD->isReferenceableKernel() ? GlobalDecl(FD) : GlobalDecl(D);
545+
} else
546546
GD = GlobalDecl(D);
547547
MC->mangleName(GD, OS);
548548
return false;

clang/lib/AST/MicrosoftMangle.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,9 +1162,15 @@ void MicrosoftCXXNameMangler::mangleUnqualifiedName(GlobalDecl GD,
11621162
->getTemplatedDecl()
11631163
->hasAttr<CUDAGlobalAttr>())) &&
11641164
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
1165+
bool IsOCLDeviceStub =
1166+
ND && isa<FunctionDecl>(ND) && ND->hasAttr<OpenCLKernelAttr>() &&
1167+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
11651168
if (IsDeviceStub)
11661169
mangleSourceName(
11671170
(llvm::Twine("__device_stub__") + II->getName()).str());
1171+
else if (IsOCLDeviceStub)
1172+
mangleSourceName(
1173+
(llvm::Twine("__clang_ocl_kern_imp_") + II->getName()).str());
11681174
else
11691175
mangleSourceName(II->getName());
11701176
break;

clang/lib/CodeGen/CGCall.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,8 @@ CodeGenTypes::arrangeCXXConstructorCall(const CallArgList &args,
499499
/// Arrange the argument and result information for the declaration or
500500
/// definition of the given function.
501501
const CGFunctionInfo &
502-
CodeGenTypes::arrangeFunctionDeclaration(const FunctionDecl *FD) {
502+
CodeGenTypes::arrangeFunctionDeclaration(const GlobalDecl GD) {
503+
const FunctionDecl *FD = cast<FunctionDecl>(GD.getDecl());
503504
if (const CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(FD))
504505
if (MD->isImplicitObjectMemberFunction())
505506
return arrangeCXXMethodDeclaration(MD);
@@ -509,6 +510,13 @@ CodeGenTypes::arrangeFunctionDeclaration(const FunctionDecl *FD) {
509510
assert(isa<FunctionType>(FTy));
510511
setCUDAKernelCallingConvention(FTy, CGM, FD);
511512

513+
if (FD->hasAttr<OpenCLKernelAttr>() &&
514+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
515+
const FunctionType *FT = FTy->getAs<FunctionType>();
516+
CGM.getTargetCodeGenInfo().setOCLKernelStubCallingConvention(FT);
517+
FTy = FT->getCanonicalTypeUnqualified();
518+
}
519+
512520
// When declaring a function without a prototype, always use a
513521
// non-variadic type.
514522
if (CanQual<FunctionNoProtoType> noProto = FTy.getAs<FunctionNoProtoType>()) {
@@ -581,13 +589,11 @@ CodeGenTypes::arrangeUnprototypedObjCMessageSend(QualType returnType,
581589
const CGFunctionInfo &
582590
CodeGenTypes::arrangeGlobalDeclaration(GlobalDecl GD) {
583591
// FIXME: Do we need to handle ObjCMethodDecl?
584-
const FunctionDecl *FD = cast<FunctionDecl>(GD.getDecl());
585-
586592
if (isa<CXXConstructorDecl>(GD.getDecl()) ||
587593
isa<CXXDestructorDecl>(GD.getDecl()))
588594
return arrangeCXXStructorDeclaration(GD);
589595

590-
return arrangeFunctionDeclaration(FD);
596+
return arrangeFunctionDeclaration(GD);
591597
}
592598

593599
/// Arrange a thunk that takes 'this' as the first parameter followed by
@@ -2391,7 +2397,6 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
23912397
// Collect function IR attributes from the callee prototype if we have one.
23922398
AddAttributesFromFunctionProtoType(getContext(), FuncAttrs,
23932399
CalleeInfo.getCalleeFunctionProtoType());
2394-
23952400
const Decl *TargetDecl = CalleeInfo.getCalleeDecl().getDecl();
23962401

23972402
// Attach assumption attributes to the declaration. If this is a call
@@ -2498,7 +2503,11 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
24982503
NumElemsParam);
24992504
}
25002505

2501-
if (TargetDecl->hasAttr<OpenCLKernelAttr>()) {
2506+
if (TargetDecl->hasAttr<OpenCLKernelAttr>() &&
2507+
CallingConv != CallingConv::CC_C &&
2508+
CallingConv != CallingConv::CC_SpirFunction) {
2509+
// Check CallingConv to avoid adding uniform-work-group-size attribute to
2510+
// OpenCL Kernel Stub
25022511
if (getLangOpts().OpenCLVersion <= 120) {
25032512
// OpenCL v1.2 Work groups are always uniform
25042513
FuncAttrs.addAttribute("uniform-work-group-size", "true");

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5752,6 +5752,12 @@ static CGCallee EmitDirectCallee(CodeGenFunction &CGF, GlobalDecl GD) {
57525752
return CGCallee::forDirect(CalleePtr, GD);
57535753
}
57545754

5755+
static GlobalDecl getGlobalDeclForDirectCall(const FunctionDecl *FD) {
5756+
if (FD->hasAttr<OpenCLKernelAttr>())
5757+
return GlobalDecl(FD, KernelReferenceKind::Stub);
5758+
return GlobalDecl(FD);
5759+
}
5760+
57555761
CGCallee CodeGenFunction::EmitCallee(const Expr *E) {
57565762
E = E->IgnoreParens();
57575763

@@ -5765,7 +5771,7 @@ CGCallee CodeGenFunction::EmitCallee(const Expr *E) {
57655771
// Resolve direct calls.
57665772
} else if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
57675773
if (auto FD = dyn_cast<FunctionDecl>(DRE->getDecl())) {
5768-
return EmitDirectCallee(*this, FD);
5774+
return EmitDirectCallee(*this, getGlobalDeclForDirectCall(FD));
57695775
}
57705776
} else if (auto ME = dyn_cast<MemberExpr>(E)) {
57715777
if (auto FD = dyn_cast<FunctionDecl>(ME->getMemberDecl())) {
@@ -6134,6 +6140,10 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
61346140

61356141
const auto *FnType = cast<FunctionType>(PointeeType);
61366142

6143+
if (const auto *FD = dyn_cast_or_null<FunctionDecl>(TargetDecl);
6144+
FD && FD->hasAttr<OpenCLKernelAttr>())
6145+
CGM.getTargetCodeGenInfo().setOCLKernelStubCallingConvention(FnType);
6146+
61376147
// If we are checking indirect calls and this call is indirect, check that the
61386148
// function pointer is a member of the bit set for the function type.
61396149
if (SanOpts.has(SanitizerKind::CFIICall) &&

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,26 @@ void CodeGenFunction::GenerateCode(GlobalDecl GD, llvm::Function *Fn,
15951595
// Implicit copy-assignment gets the same special treatment as implicit
15961596
// copy-constructors.
15971597
emitImplicitAssignmentOperatorBody(Args);
1598+
} else if (FD->hasAttr<OpenCLKernelAttr>() &&
1599+
GD.getKernelReferenceKind() == KernelReferenceKind::Kernel) {
1600+
CallArgList CallArgs;
1601+
for (unsigned i = 0; i < Args.size(); ++i) {
1602+
Address ArgAddr = GetAddrOfLocalVar(Args[i]);
1603+
QualType ArgQualType = Args[i]->getType();
1604+
RValue ArgRValue = convertTempToRValue(ArgAddr, ArgQualType, Loc);
1605+
CallArgs.add(ArgRValue, ArgQualType);
1606+
}
1607+
GlobalDecl GDStub = GlobalDecl(FD, KernelReferenceKind::Stub);
1608+
const FunctionType *FT = cast<FunctionType>(FD->getType());
1609+
CGM.getTargetCodeGenInfo().setOCLKernelStubCallingConvention(FT);
1610+
const CGFunctionInfo &FnInfo = CGM.getTypes().arrangeFreeFunctionCall(
1611+
CallArgs, FT, /*ChainCall=*/false);
1612+
llvm::FunctionType *FTy = CGM.getTypes().GetFunctionType(FnInfo);
1613+
llvm::Constant *GDStubFunctionPointer =
1614+
CGM.getRawFunctionPointer(GDStub, FTy);
1615+
CGCallee GDStubCallee = CGCallee::forDirect(GDStubFunctionPointer, GDStub);
1616+
EmitCall(FnInfo, GDStubCallee, ReturnValueSlot(), CallArgs, nullptr, false,
1617+
Loc);
15981618
} else if (Body) {
15991619
EmitFunctionBody(Body);
16001620
} else

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,6 +1903,9 @@ static std::string getMangledNameImpl(CodeGenModule &CGM, GlobalDecl GD,
19031903
} else if (FD && FD->hasAttr<CUDAGlobalAttr>() &&
19041904
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
19051905
Out << "__device_stub__" << II->getName();
1906+
} else if (FD && FD->hasAttr<OpenCLKernelAttr>() &&
1907+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
1908+
Out << "__clang_ocl_kern_imp_" << II->getName();
19061909
} else {
19071910
Out << II->getName();
19081911
}
@@ -3890,6 +3893,9 @@ void CodeGenModule::EmitGlobal(GlobalDecl GD) {
38903893

38913894
// Ignore declarations, they will be emitted on their first use.
38923895
if (const auto *FD = dyn_cast<FunctionDecl>(Global)) {
3896+
if (FD->hasAttr<OpenCLKernelAttr>() && FD->doesThisDeclarationHaveABody())
3897+
addDeferredDeclToEmit(GlobalDecl(FD, KernelReferenceKind::Stub));
3898+
38933899
// Update deferred annotations with the latest declaration if the function
38943900
// function was already used or defined.
38953901
if (FD->hasAttr<AnnotateAttr>()) {
@@ -4857,6 +4863,11 @@ CodeGenModule::GetAddrOfFunction(GlobalDecl GD, llvm::Type *Ty, bool ForVTable,
48574863
if (!Ty) {
48584864
const auto *FD = cast<FunctionDecl>(GD.getDecl());
48594865
Ty = getTypes().ConvertType(FD->getType());
4866+
if (FD->hasAttr<OpenCLKernelAttr>() &&
4867+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
4868+
const CGFunctionInfo &FI = getTypes().arrangeGlobalDeclaration(GD);
4869+
Ty = getTypes().GetFunctionType(FI);
4870+
}
48604871
}
48614872

48624873
// Devirtualized destructor calls may come through here instead of via

clang/lib/CodeGen/CodeGenTypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class CodeGenTypes {
207207

208208
/// Free functions are functions that are compatible with an ordinary
209209
/// C function pointer type.
210-
const CGFunctionInfo &arrangeFunctionDeclaration(const FunctionDecl *FD);
210+
const CGFunctionInfo &arrangeFunctionDeclaration(const GlobalDecl GD);
211211
const CGFunctionInfo &arrangeFreeFunctionCall(const CallArgList &Args,
212212
const FunctionType *Ty,
213213
bool ChainCall);

clang/lib/CodeGen/TargetInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ unsigned TargetCodeGenInfo::getOpenCLKernelCallingConv() const {
117117
return llvm::CallingConv::SPIR_KERNEL;
118118
}
119119

120+
void TargetCodeGenInfo::setOCLKernelStubCallingConvention(
121+
const FunctionType *&FT) const {
122+
FT = getABIInfo().getContext().adjustFunctionType(
123+
FT, FT->getExtInfo().withCallingConv(CC_C));
124+
}
125+
120126
llvm::Constant *TargetCodeGenInfo::getNullPointer(const CodeGen::CodeGenModule &CGM,
121127
llvm::PointerType *T, QualType QT) const {
122128
return llvm::ConstantPointerNull::get(T);

clang/lib/CodeGen/TargetInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ class TargetCodeGenInfo {
400400
virtual bool shouldEmitDWARFBitFieldSeparators() const { return false; }
401401

402402
virtual void setCUDAKernelCallingConvention(const FunctionType *&FT) const {}
403-
403+
virtual void setOCLKernelStubCallingConvention(const FunctionType *&FT) const;
404404
/// Return the device-side type for the CUDA device builtin surface type.
405405
virtual llvm::Type *getCUDADeviceBuiltinSurfaceDeviceType() const {
406406
// By default, no change from the original one.

clang/lib/CodeGen/Targets/SPIR.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class CommonSPIRTargetCodeGenInfo : public TargetCodeGenInfo {
5858
llvm::Type *getSPIRVImageTypeFromHLSLResource(
5959
const HLSLAttributedResourceType::Attributes &attributes,
6060
llvm::Type *ElementType, llvm::LLVMContext &Ctx) const;
61+
void
62+
setOCLKernelStubCallingConvention(const FunctionType *&FT) const override;
6163
};
6264
class SPIRVTargetCodeGenInfo : public CommonSPIRTargetCodeGenInfo {
6365
public:
@@ -230,6 +232,12 @@ void SPIRVTargetCodeGenInfo::setCUDAKernelCallingConvention(
230232
}
231233
}
232234

235+
void CommonSPIRTargetCodeGenInfo::setOCLKernelStubCallingConvention(
236+
const FunctionType *&FT) const {
237+
FT = getABIInfo().getContext().adjustFunctionType(
238+
FT, FT->getExtInfo().withCallingConv(CC_SpirFunction));
239+
}
240+
233241
LangAS
234242
SPIRVTargetCodeGenInfo::getGlobalVarAddressSpace(CodeGenModule &CGM,
235243
const VarDecl *D) const {

0 commit comments

Comments
 (0)