Skip to content

Commit 2582436

Browse files
committed
[dcompute] has->getKernelAttr, return expression literal
Needed for Vulkan where information about the number of threads to use is supplied to the attribute.
1 parent 90e39b6 commit 2582436

File tree

9 files changed

+20
-14
lines changed

9 files changed

+20
-14
lines changed

gen/abi/nvptx.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct NVPTXTargetABI : TargetABI {
2323
return llvm::CallingConv::PTX_Device;
2424
}
2525
llvm::CallingConv::ID callingConv(FuncDeclaration *fdecl) override {
26-
return hasKernelAttr(fdecl) ? llvm::CallingConv::PTX_Kernel
26+
return getKernelAttr(fdecl) ? llvm::CallingConv::PTX_Kernel
2727
: llvm::CallingConv::PTX_Device;
2828
}
2929
bool passByVal(TypeFunction *, Type *t) override {

gen/abi/spirv.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct SPIRVTargetABI : TargetABI {
2323
return llvm::CallingConv::SPIR_FUNC;
2424
}
2525
llvm::CallingConv::ID callingConv(FuncDeclaration *fdecl) override {
26-
return hasKernelAttr(fdecl) ? llvm::CallingConv::SPIR_KERNEL
26+
return getKernelAttr(fdecl) ? llvm::CallingConv::SPIR_KERNEL
2727
: llvm::CallingConv::SPIR_FUNC;
2828
}
2929
bool passByVal(TypeFunction *, Type *t) override {

gen/dcompute/target.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class DComputeTarget {
4949
void writeModule();
5050

5151
virtual void addMetadata() = 0;
52-
virtual void addKernelMetadata(FuncDeclaration *df, llvm::Function *llf) = 0;
52+
virtual void addKernelMetadata(FuncDeclaration *df,
53+
llvm::Function *llf,
54+
StructLiteralExp *kernAttr) = 0;
5355
};
5456

5557
#if LDC_LLVM_SUPPORTED_TARGET_NVPTX

gen/dcompute/targetCUDA.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class TargetCUDA : public DComputeTarget {
5252
// sm version?
5353
}
5454

55-
void addKernelMetadata(FuncDeclaration *df, llvm::Function *llf) override {
55+
void addKernelMetadata(FuncDeclaration *df, llvm::Function *llf, StructLiteralExp *_unused_) override {
5656
// TODO: Handle Function attibutes
5757
llvm::NamedMDNode *na =
5858
_ir->module.getOrInsertNamedMetadata("nvvm.annotations");

gen/dcompute/targetOCL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class TargetOCL : public DComputeTarget {
138138
KernArgMD_name,
139139
count_KernArgMD
140140
};
141-
void addKernelMetadata(FuncDeclaration *fd, llvm::Function *llf) override {
141+
void addKernelMetadata(FuncDeclaration *fd, llvm::Function *llf, StructLiteralExp *_unused_) override {
142142
// By the time we get here the ABI should have rewritten the function
143143
// type so that the magic types in ldc.dcompute are transformed into
144144
// what the LLVM backend expects.

gen/functions.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,9 +1342,12 @@ void DtoDefineFunction(FuncDeclaration *fd, bool linkageAvailableExternally) {
13421342
allocaPoint = nullptr;
13431343
}
13441344

1345-
if (gIR->dcomputetarget && hasKernelAttr(fd)) {
1346-
auto fn = gIR->module.getFunction(fd->mangleString);
1347-
gIR->dcomputetarget->addKernelMetadata(fd, fn);
1345+
if (gIR->dcomputetarget) {
1346+
auto kernAttr = getKernelAttr(fd);
1347+
if (kernAttr) {
1348+
auto fn = gIR->module.getFunction(fd->mangleString);
1349+
gIR->dcomputetarget->addKernelMetadata(fd, fn, kernAttr);
1350+
}
13481351
}
13491352

13501353
if (func->getLinkage() == LLGlobalValue::WeakAnyLinkage &&

gen/semantic-dcompute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ struct DComputeSemanticAnalyser : public StoppableVisitor {
237237
}
238238

239239
void visit(FuncDeclaration *fd) override {
240-
if (hasKernelAttr(fd) && fd->vthis) {
240+
if (getKernelAttr(fd) && fd->vthis) {
241241
error(fd->loc, "`@kernel` functions must not require `this`");
242242
stop = true;
243243
return;

gen/uda.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -654,11 +654,11 @@ extern "C" DComputeCompileFor hasComputeAttr(Dsymbol *sym) {
654654
return static_cast<DComputeCompileFor>(1 + (*sle->elements)[0]->toInteger());
655655
}
656656

657-
/// Checks whether 'sym' has the @ldc.dcompute._kernel() UDA applied.
658-
bool hasKernelAttr(Dsymbol *sym) {
657+
/// Returns whether `sym` has the `@ldc.dcompute._kernel()` UDA applied.
658+
StructLiteralExp *getKernelAttr(Dsymbol *sym) {
659659
auto sle = getMagicAttribute(sym, Id::udaKernel, Id::dcompute);
660660
if (!sle)
661-
return false;
661+
return nullptr;
662662

663663
checkStructElems(sle, {});
664664

@@ -668,7 +668,7 @@ bool hasKernelAttr(Dsymbol *sym) {
668668
" in modules marked `@ldc.dcompute.compute`");
669669
}
670670

671-
return true;
671+
return sle;
672672
}
673673

674674
/// Check whether `fd` has the `@ldc.attributes.noSplitStack` UDA applied.

gen/uda.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
class Dsymbol;
2020
class FuncDeclaration;
2121
class VarDeclaration;
22+
class StructLiteralExp;
2223
struct IrFunction;
2324
namespace llvm {
2425
class GlobalVariable;
@@ -29,7 +30,7 @@ void applyVarDeclUDAs(VarDeclaration *decl, llvm::GlobalVariable *gvar);
2930

3031
bool hasCallingConventionUDA(FuncDeclaration *fd, llvm::CallingConv::ID *callconv);
3132
bool hasWeakUDA(Dsymbol *sym);
32-
bool hasKernelAttr(Dsymbol *sym);
33+
StructLiteralExp *getKernelAttr(Dsymbol *sym);
3334
/// Must match ldc.dcompute.Compilefor + 1 == DComputeCompileFor
3435
enum class DComputeCompileFor : int
3536
{

0 commit comments

Comments
 (0)