Skip to content

Commit 60bbf11

Browse files
V-FEXrtinbelicpow2clk
authored andcommitted
[HLSL] Implement WaveActiveAnyTrue intrinsic (#115902)
Resolves llvm/llvm-project#99160 - [x] Implement `WaveActiveAnyTrue` clang builtin, - [x] Link `WaveActiveAnyTrue` clang builtin with `hlsl_intrinsics.h` - [x] Add sema checks for `WaveActiveAnyTrue` to `CheckHLSLBuiltinFunctionCall` in `SemaChecking.cpp` - [x] Add codegen for `WaveActiveAnyTrue` to `EmitHLSLBuiltinExpr` in `CGBuiltin.cpp` - [x] Add codegen tests to `clang/test/CodeGenHLSL/builtins/WaveActiveAnyTrue.hlsl` - [x] Add sema tests to `clang/test/SemaHLSL/BuiltIns/WaveActiveAnyTrue-errors.hlsl` - [x] Create the `int_dx_WaveActiveAnyTrue` intrinsic in `IntrinsicsDirectX.td` - [x] Create the `DXILOpMapping` of `int_dx_WaveActiveAnyTrue` to `113` in `DXIL.td` - [x] Create the `WaveActiveAnyTrue.ll` and `WaveActiveAnyTrue_errors.ll` tests in `llvm/test/CodeGen/DirectX/` - [x] Create the `int_spv_WaveActiveAnyTrue` intrinsic in `IntrinsicsSPIRV.td` - [x] In SPIRVInstructionSelector.cpp create the `WaveActiveAnyTrue` lowering and map it to `int_spv_WaveActiveAnyTrue` in `SPIRVInstructionSelector::selectIntrinsic`. - [x] Create SPIR-V backend test case in `llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAnyTrue.ll` --------- Co-authored-by: Finn Plummer <[email protected]> Co-authored-by: Greg Roth <[email protected]>
1 parent 75429eb commit 60bbf11

File tree

13 files changed

+146
-51
lines changed

13 files changed

+146
-51
lines changed

clang/include/clang/Basic/Builtins.td

+6
Original file line numberDiff line numberDiff line change
@@ -4744,6 +4744,12 @@ def HLSLAny : LangBuiltin<"HLSL_LANG"> {
47444744
let Prototype = "bool(...)";
47454745
}
47464746

4747+
def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
4748+
let Spellings = ["__builtin_hlsl_wave_active_any_true"];
4749+
let Attributes = [NoThrow, Const];
4750+
let Prototype = "bool(bool)";
4751+
}
4752+
47474753
def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
47484754
let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
47494755
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -19175,6 +19175,16 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1917519175
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
1917619176
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
1917719177
}
19178+
case Builtin::BI__builtin_hlsl_wave_active_any_true: {
19179+
Value *Op = EmitScalarExpr(E->getArg(0));
19180+
llvm::Type *Ty = Op->getType();
19181+
assert(Ty->isIntegerTy(1) &&
19182+
"Intrinsic WaveActiveAnyTrue operand must be a bool");
19183+
19184+
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAnyTrueIntrinsic();
19185+
return EmitRuntimeCall(
19186+
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
19187+
}
1917819188
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
1917919189
Value *OpExpr = EmitScalarExpr(E->getArg(0));
1918019190
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();

clang/lib/CodeGen/CGHLSLRuntime.h

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class CGHLSLRuntime {
9191
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
9292
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
9393
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
94+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
9495
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
9596
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
9697
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)

clang/lib/Headers/hlsl/hlsl_intrinsics.h

+9
Original file line numberDiff line numberDiff line change
@@ -2184,6 +2184,15 @@ float4 trunc(float4);
21842184
// Wave* builtins
21852185
//===----------------------------------------------------------------------===//
21862186

2187+
/// \brief Returns true if the expression is true in any active lane in the
2188+
/// current wave.
2189+
///
2190+
/// \param Val The boolean expression to evaluate.
2191+
/// \return True if the expression is true in any lane.
2192+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2193+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_any_true)
2194+
__attribute__((convergent)) bool WaveActiveAnyTrue(bool Val);
2195+
21872196
/// \brief Counts the number of boolean variables which evaluate to true across
21882197
/// all active lanes in the current wave.
21892198
///
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call for int values.
9+
10+
// CHECK-LABEL: define {{.*}}test
11+
bool test(bool p1) {
12+
// CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
13+
// CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.any(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
14+
// CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.any(i1 %{{[a-zA-Z0-9]+}})
15+
// CHECK: ret i1 %[[RET]]
16+
return WaveActiveAnyTrue(p1);
17+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
bool test_too_few_arg() {
4+
return __builtin_hlsl_wave_active_any_true();
5+
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
6+
}
7+
8+
bool test_too_many_arg(bool p0) {
9+
return __builtin_hlsl_wave_active_any_true(p0, p0);
10+
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
11+
}
12+
13+
struct Foo
14+
{
15+
int a;
16+
};
17+
18+
bool test_type_check(Foo p0) {
19+
return __builtin_hlsl_wave_active_any_true(p0);
20+
// expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
21+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
9292
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
9393
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
9494
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
95+
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
9596
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
9697
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
9798
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ let TargetPrefix = "spv" in {
8686
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
8787
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
8888
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
89+
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
8990
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
9091
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
9192
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;

llvm/lib/Target/DirectX/DXIL.td

+8
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,14 @@ def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
854854
let stages = [Stages<DXIL1_6, [all_stages]>];
855855
}
856856

857+
def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> {
858+
let Doc = "returns true if the expression is true in any of the active lanes in the current wave";
859+
let LLVMIntrinsic = int_dx_wave_any;
860+
let arguments = [Int1Ty];
861+
let result = Int1Ty;
862+
let stages = [Stages<DXIL1_0, [all_stages]>];
863+
}
864+
857865
def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
858866
let Doc = "returns 1 for the first lane in the wave";
859867
let LLVMIntrinsic = int_dx_wave_is_first_lane;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

+32-43
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,12 @@ class SPIRVInstructionSelector : public InstructionSelector {
256256
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
257257
MachineInstr &I) const;
258258

259+
bool selectWaveOpInst(Register ResVReg, const SPIRVType *ResType,
260+
MachineInstr &I, unsigned Opcode) const;
261+
259262
bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
260263
MachineInstr &I) const;
261264

262-
bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
263-
MachineInstr &I) const;
264-
265265
bool selectUnmergeValues(MachineInstr &I) const;
266266

267267
void selectHandleFromBinding(Register &ResVReg, const SPIRVType *ResType,
@@ -1920,24 +1920,36 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
19201920
return Result;
19211921
}
19221922

1923+
bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
1924+
const SPIRVType *ResType,
1925+
MachineInstr &I,
1926+
unsigned Opcode) const {
1927+
MachineBasicBlock &BB = *I.getParent();
1928+
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
1929+
1930+
auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
1931+
.addDef(ResVReg)
1932+
.addUse(GR.getSPIRVTypeID(ResType))
1933+
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
1934+
IntTy, TII));
1935+
1936+
for (unsigned J = 2; J < I.getNumOperands(); J++) {
1937+
BMI.addUse(I.getOperand(J).getReg());
1938+
}
1939+
1940+
return BMI.constrainAllUses(TII, TRI, RBI);
1941+
}
1942+
19231943
bool SPIRVInstructionSelector::selectWaveActiveCountBits(
19241944
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
1925-
assert(I.getNumOperands() == 3);
1926-
assert(I.getOperand(2).isReg());
1927-
MachineBasicBlock &BB = *I.getParent();
19281945

19291946
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
19301947
SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
19311948
Register BallotReg = MRI->createVirtualRegister(GR.getRegClass(BallotType));
1949+
bool Result = selectWaveOpInst(BallotReg, BallotType, I,
1950+
SPIRV::OpGroupNonUniformBallot);
19321951

1933-
bool Result =
1934-
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot))
1935-
.addDef(BallotReg)
1936-
.addUse(GR.getSPIRVTypeID(BallotType))
1937-
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
1938-
.addUse(I.getOperand(2).getReg())
1939-
.constrainAllUses(TII, TRI, RBI);
1940-
1952+
MachineBasicBlock &BB = *I.getParent();
19411953
Result &=
19421954
BuildMI(BB, I, I.getDebugLoc(),
19431955
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
@@ -1951,26 +1963,6 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
19511963
return Result;
19521964
}
19531965

1954-
bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
1955-
const SPIRVType *ResType,
1956-
MachineInstr &I) const {
1957-
assert(I.getNumOperands() == 4);
1958-
assert(I.getOperand(2).isReg());
1959-
assert(I.getOperand(3).isReg());
1960-
MachineBasicBlock &BB = *I.getParent();
1961-
1962-
// IntTy is used to define the execution scope, set to 3 to denote a
1963-
// cross-lane interaction equivalent to a SPIR-V subgroup.
1964-
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
1965-
return BuildMI(BB, I, I.getDebugLoc(),
1966-
TII.get(SPIRV::OpGroupNonUniformShuffle))
1967-
.addDef(ResVReg)
1968-
.addUse(GR.getSPIRVTypeID(ResType))
1969-
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII))
1970-
.addUse(I.getOperand(2).getReg())
1971-
.addUse(I.getOperand(3).getReg());
1972-
}
1973-
19741966
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
19751967
const SPIRVType *ResType,
19761968
MachineInstr &I) const {
@@ -2781,16 +2773,13 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
27812773
return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp);
27822774
case Intrinsic::spv_wave_active_countbits:
27832775
return selectWaveActiveCountBits(ResVReg, ResType, I);
2784-
case Intrinsic::spv_wave_is_first_lane: {
2785-
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
2786-
return BuildMI(BB, I, I.getDebugLoc(),
2787-
TII.get(SPIRV::OpGroupNonUniformElect))
2788-
.addDef(ResVReg)
2789-
.addUse(GR.getSPIRVTypeID(ResType))
2790-
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
2791-
}
2776+
case Intrinsic::spv_wave_any:
2777+
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
2778+
case Intrinsic::spv_wave_is_first_lane:
2779+
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
27922780
case Intrinsic::spv_wave_readlane:
2793-
return selectWaveReadLaneAt(ResVReg, ResType, I);
2781+
return selectWaveOpInst(ResVReg, ResType, I,
2782+
SPIRV::OpGroupNonUniformShuffle);
27942783
case Intrinsic::spv_step:
27952784
return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step);
27962785
case Intrinsic::spv_radians:

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,15 @@ void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
630630
addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
631631
Capability::Int16});
632632

633+
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
634+
addAvailableCaps({Capability::GroupNonUniform,
635+
Capability::GroupNonUniformVote,
636+
Capability::GroupNonUniformArithmetic,
637+
Capability::GroupNonUniformBallot,
638+
Capability::GroupNonUniformClustered,
639+
Capability::GroupNonUniformShuffle,
640+
Capability::GroupNonUniformShuffleRelative});
641+
633642
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
634643
addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
635644
Capability::DotProductInput4x8Bit,
@@ -673,14 +682,6 @@ void RequirementHandler::initAvailableCapabilitiesForOpenCL(
673682
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
674683
ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
675684
addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
676-
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
677-
addAvailableCaps({Capability::GroupNonUniform,
678-
Capability::GroupNonUniformVote,
679-
Capability::GroupNonUniformArithmetic,
680-
Capability::GroupNonUniformBallot,
681-
Capability::GroupNonUniformClustered,
682-
Capability::GroupNonUniformShuffle,
683-
Capability::GroupNonUniformShuffleRelative});
684685
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
685686
addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
686687
Capability::SignedZeroInfNanPreserve,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
define noundef i1 @wave_any_simple(i1 noundef %p1) {
4+
entry:
5+
; CHECK: call i1 @dx.op.waveAnyTrue(i32 113, i1 %p1)
6+
%ret = call i1 @llvm.dx.wave.any(i1 %p1)
7+
ret i1 %ret
8+
}
9+
10+
declare i1 @llvm.dx.wave.any(i1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#bool:]] = OpTypeBool
5+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
6+
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
7+
; CHECK-DAG: OpCapability GroupNonUniformVote
8+
9+
; CHECK-LABEL: Begin function test_wave_any
10+
define i1 @test_wave_any(i1 %p1) #0 {
11+
entry:
12+
; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
13+
; CHECK: %{{.+}} = OpGroupNonUniformAny %[[#bool]] %[[#scope]] %[[#param]]
14+
%0 = call token @llvm.experimental.convergence.entry()
15+
%ret = call i1 @llvm.spv.wave.any(i1 %p1) [ "convergencectrl"(token %0) ]
16+
ret i1 %ret
17+
}
18+
19+
declare i1 @llvm.spv.wave.any(i1) #0
20+
21+
attributes #0 = { convergent }

0 commit comments

Comments
 (0)