Skip to content

[SPIRV] Added support for 2 kernel query builtins #142280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,18 +372,15 @@ static MachineInstr *getBlockStructInstr(Register ParamReg,
// We expect the following sequence of instructions:
// %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
// or = G_GLOBAL_VALUE @block_literal_global
// %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
// %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
// %1:_(p4) = G_ADDRSPACE_CAST %0:_(pN)
MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
MI->getOperand(1).isReg());
Register BitcastReg = MI->getOperand(1).getReg();
MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
BitcastMI->getOperand(2).isReg());
Register ValueReg = BitcastMI->getOperand(2).getReg();
MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
return ValueMI;
Register PtrReg = MI->getOperand(1).getReg();
MachineInstr *PtrMI = MRI->getUniqueVRegDef(PtrReg);
assert(PtrMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
isSpvIntrinsic(*PtrMI, Intrinsic::spv_alloca));
return PtrMI;
}

// Return an integer constant corresponding to the given register and
Expand Down Expand Up @@ -2509,6 +2506,59 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
return true;
}

static bool buildNDRangeSubGroup(const SPIRV::IncomingCall *Call,
unsigned Opcode, MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
const DataLayout &DL = MIRBuilder.getDataLayout();

auto MIB = MIRBuilder.buildInstr(Opcode)
.addDef(Call->ReturnRegister)
.addUse(GR->getSPIRVTypeID(Call->ReturnType))
.addUse(Call->Arguments[0]);
unsigned int BlockFIdx = 1;
MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
// Invoke: Pointer to invoke function.
Register BlockFReg = BlockMI->getOperand(0).getReg();
MIB.addUse(BlockFReg);
MRI->setRegClass(BlockFReg, &SPIRV::pIDRegClass);

Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
// Param: Pointer to block literal.
MIB.addUse(BlockLiteralReg);
BlockMI = MRI->getUniqueVRegDef(BlockLiteralReg);
Register BlockMIReg =
stripAddrspaceCast(BlockMI->getOperand(1).getReg(), *MRI);
BlockMI = MRI->getUniqueVRegDef(BlockMIReg);

if (BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) {
// Size and align are given explicitly here.
const GlobalValue *GV = BlockMI->getOperand(1).getGlobal();

const GlobalVariable *BlockGV = dyn_cast<GlobalVariable>(GV);
assert(BlockGV->hasInitializer() &&
"Block literal should have an initializer");
const Constant *Init = BlockGV->getInitializer();
const ConstantStruct *CS = dyn_cast<ConstantStruct>(Init);
// Extract fields
const ConstantInt *SizeConst = dyn_cast<ConstantInt>(CS->getOperand(0));
const ConstantInt *AlignConst = dyn_cast<ConstantInt>(CS->getOperand(1));
uint64_t BlockSize = SizeConst->getZExtValue();
uint64_t BlockAlign = AlignConst->getZExtValue();
MIB.addUse(buildConstantIntReg32(BlockSize, MIRBuilder, GR));
MIB.addUse(buildConstantIntReg32(BlockAlign, MIRBuilder, GR));
} else {
Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
// Fallback to default if not found
MIB.addUse(
buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
MIRBuilder, GR));
}
return true;
}

static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
Expand Down Expand Up @@ -2544,6 +2594,9 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
return buildNDRange(Call, MIRBuilder, GR);
case SPIRV::OpEnqueueKernel:
return buildEnqueueKernel(Call, MIRBuilder, GR);
case SPIRV::OpGetKernelNDrangeSubGroupCount:
case SPIRV::OpGetKernelNDrangeMaxSubGroupSize:
return buildNDRangeSubGroup(Call, Opcode, MIRBuilder, GR);
default:
return false;
}
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,9 @@ defm : DemangledNativeBuiltin<"__spirv_GetDefaultQueue", OpenCL_std, Enqueue, 0,
defm : DemangledNativeBuiltin<"ndrange_1D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
defm : DemangledNativeBuiltin<"ndrange_2D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
defm : DemangledNativeBuiltin<"ndrange_3D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
defm : DemangledNativeBuiltin<"__get_kernel_sub_group_count_for_ndrange_impl", OpenCL_std, Enqueue, 2, 2, OpGetKernelNDrangeSubGroupCount>;
defm : DemangledNativeBuiltin<"__get_kernel_max_sub_group_size_for_ndrange_impl", OpenCL_std, Enqueue, 2, 2, OpGetKernelNDrangeMaxSubGroupSize>;


// Spec constant builtin records:
defm : DemangledNativeBuiltin<"__spirv_SpecConstant", OpenCL_std, SpecConstant, 2, 2, OpSpecConstant>;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,10 @@ def OpGetDefaultQueue: Op<303, (outs ID:$res), (ins TYPE:$type),
"$res = OpGetDefaultQueue $type">;
def OpBuildNDRange: Op<304, (outs ID:$res), (ins TYPE:$type, ID:$GWS, ID:$LWS, ID:$GWO),
"$res = OpBuildNDRange $type $GWS $LWS $GWO">;
def OpGetKernelNDrangeSubGroupCount: Op<304, (outs ID:$res), (ins TYPE:$type, ID:$NDR, ID:$Invoke, ID:$Param, ID:$ParamSize, ID:$ParamAlign),
"$res = OpGetKernelNDrangeSubGroupCount $type $NDR $Invoke $Param $ParamSize $ParamAlign">;
def OpGetKernelNDrangeMaxSubGroupSize: Op<304, (outs ID:$res), (ins TYPE:$type, ID:$NDR, ID:$Invoke, ID:$Param, ID:$ParamSize, ID:$ParamAlign),
"$res = OpGetKernelNDrangeMaxSubGroupSize $type $NDR $Invoke $Param $ParamSize $ParamAlign">;

// TODO: 3.42.23. Pipe Instructions

Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1848,6 +1848,11 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
break;
}
case SPIRV::OpGetKernelNDrangeMaxSubGroupSize:
case SPIRV::OpGetKernelNDrangeSubGroupCount: {
Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
break;
}

default:
break;
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,16 @@ bool isEntryPoint(const Function &F) {
return false;
}

Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI) {
while (true) {
MachineInstr *Def = MRI.getVRegDef(Reg);
if (!Def || Def->getOpcode() != TargetOpcode::G_ADDRSPACE_CAST)
break;
Reg = Def->getOperand(1).getReg(); // Unwrap the cast
}
return Reg;
}

Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
TypeName.consume_front("atomic_");
if (TypeName.consume_front("void"))
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ bool isSpecialOpaqueType(const Type *Ty);

// Check if the function is an SPIR-V entry point
bool isEntryPoint(const Function &F);

Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI);
// Parse basic scalar type name, substring TypeName, and return LLVM type.
Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx);

Expand Down
95 changes: 95 additions & 0 deletions llvm/test/CodeGen/SPIRV/transcoding/kernel_query.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

%struct.ndrange_t = type { i32 }
%1 = type <{ i32, i32 }>

@__block_literal_global = internal addrspace(1) constant { i32, i32 } { i32 8, i32 4 }, align 4
@__block_literal_global.1 = internal addrspace(1) constant { i32, i32 } { i32 8, i32 4 }, align 4

; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#C4:]] = OpConstant %[[#Int32Ty]] 4
; CHECK-DAG: %[[#C8:]] = OpConstant %[[#Int32Ty]] 8
; CHECK-DAG: %[[#NDRangeTy:]] = OpTypeStruct %[[#Int32Ty]]
; CHECK-DAG: %[[#NDRangePtrTy:]] = OpTypePointer Function %[[#NDRangeTy]]

; Function Attrs: convergent noinline nounwind optnone
define spir_kernel void @device_side_enqueue() #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !2 !kernel_arg_type !2 !kernel_arg_base_type !2 !kernel_arg_type_qual !2 {
entry:

; CHECK: %[[#NDRange:]] = OpVariable %[[#NDRangePtrTy]]

%ndrange = alloca %struct.ndrange_t, align 4

; CHECK: %[[#BlockLit1:]] = OpPtrCastToGeneric %[[#]] %[[#]]
; CHECK: %[[#]] = OpGetKernelNDrangeMaxSubGroupSize %[[#Int32Ty]] %[[#NDRange]] %[[#]] %[[#BlockLit1]] %[[#C8]] %[[#C4]]

%0 = call i32 @__get_kernel_max_sub_group_size_for_ndrange_impl(ptr %ndrange, ptr addrspace(4) addrspacecast (ptr @__device_side_enqueue_block_invoke_kernel to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @__block_literal_global to ptr addrspace(4)))

; CHECK: %[[#BlockLit2:]] = OpPtrCastToGeneric %[[#]] %[[#]]
; CHECK: %[[#]] = OpGetKernelNDrangeSubGroupCount %[[#Int32Ty]] %[[#NDRange]] %[[#]] %[[#BlockLit2]] %[[#C8]] %[[#C4]]

%1 = call i32 @__get_kernel_sub_group_count_for_ndrange_impl(ptr %ndrange, ptr addrspace(4) addrspacecast (ptr @__device_side_enqueue_block_invoke_1_kernel to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @__block_literal_global.1 to ptr addrspace(4)))
ret void
}

declare i32 @__get_kernel_preferred_work_group_size_multiple_impl(ptr addrspace(4), ptr addrspace(4))

; Function Attrs: convergent noinline nounwind optnone
define internal spir_func void @__device_side_enqueue_block_invoke(ptr addrspace(4) %.block_descriptor) #1 {
entry:
%.block_descriptor.addr = alloca ptr addrspace(4), align 4
%block.addr = alloca ptr addrspace(4), align 4
store ptr addrspace(4) %.block_descriptor, ptr %.block_descriptor.addr, align 4
store ptr addrspace(4) %.block_descriptor, ptr %block.addr, align 4
ret void
}

; Function Attrs: nounwind
define internal spir_kernel void @__device_side_enqueue_block_invoke_kernel(ptr addrspace(4)) #2 {
entry:
call void @__device_side_enqueue_block_invoke(ptr addrspace(4) %0)
ret void
}

declare i32 @__get_kernel_max_sub_group_size_for_ndrange_impl(ptr, ptr addrspace(4), ptr addrspace(4))

; Function Attrs: convergent noinline nounwind optnone
define internal spir_func void @__device_side_enqueue_block_invoke_1(ptr addrspace(4) %.block_descriptor) #1 {
entry:
%.block_descriptor.addr = alloca ptr addrspace(4), align 4
%block.addr = alloca ptr addrspace(4), align 4
store ptr addrspace(4) %.block_descriptor, ptr %.block_descriptor.addr, align 4
store ptr addrspace(4) %.block_descriptor, ptr %block.addr, align 4
ret void
}

; Function Attrs: nounwind
define internal spir_kernel void @__device_side_enqueue_block_invoke_1_kernel(ptr addrspace(4)) #2 {
entry:
call void @__device_side_enqueue_block_invoke_1(ptr addrspace(4) %0)
ret void
}

declare i32 @__get_kernel_sub_group_count_for_ndrange_impl(ptr, ptr addrspace(4), ptr addrspace(4))

attributes #0 = { convergent noinline nounwind optnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { convergent noinline nounwind optnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #2 = { nounwind }
attributes #3 = { argmemonly nounwind }

!llvm.module.flags = !{!0}
!opencl.enable.FP_CONTRACT = !{}
!opencl.ocl.version = !{!1}
!opencl.spir.version = !{!1}
!opencl.used.extensions = !{!2}
!opencl.used.optional.core.features = !{!2}
!opencl.compiler.options = !{!2}
!llvm.ident = !{!3}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 2, i32 0}
!2 = !{}
!3 = !{!"clang version 7.0.0"}