Skip to content

Commit 2d47a0b

Browse files
authored
Add step builtins and step HLSL function to DirectX and SPIR-V backend (#106471)
This PR adds the step intrinsic and an HLSL function that uses it. The SPIRV backend is also implemented. Used #102683 as a reference. Fixes #99157
1 parent 9e2bb41 commit 2d47a0b

File tree

13 files changed

+340
-2
lines changed

13 files changed

+340
-2
lines changed

clang/include/clang/Basic/Builtins.td

+7
Original file line numberDiff line numberDiff line change
@@ -4763,6 +4763,7 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
47634763
let Prototype = "void(...)";
47644764
}
47654765

4766+
47664767
def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
47674768
let Spellings = ["__builtin_hlsl_select"];
47684769
let Attributes = [NoThrow, Const];
@@ -4775,6 +4776,12 @@ def HLSLSign : LangBuiltin<"HLSL_LANG"> {
47754776
let Prototype = "void(...)";
47764777
}
47774778

4779+
def HLSLStep: LangBuiltin<"HLSL_LANG"> {
4780+
let Spellings = ["__builtin_hlsl_step"];
4781+
let Attributes = [NoThrow, Const];
4782+
let Prototype = "void(...)";
4783+
}
4784+
47784785
// Builtins for XRay.
47794786
def XRayCustomEvent : Builtin {
47804787
let Spellings = ["__xray_customevent"];

clang/lib/CodeGen/CGBuiltin.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -18861,6 +18861,16 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1886118861

1886218862
return SelectVal;
1886318863
}
18864+
case Builtin::BI__builtin_hlsl_step: {
18865+
Value *Op0 = EmitScalarExpr(E->getArg(0));
18866+
Value *Op1 = EmitScalarExpr(E->getArg(1));
18867+
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
18868+
E->getArg(1)->getType()->hasFloatingRepresentation() &&
18869+
"step operands must have a float representation");
18870+
return Builder.CreateIntrinsic(
18871+
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
18872+
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
18873+
}
1886418874
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
1886518875
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
1886618876
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",

clang/lib/CodeGen/CGHLSLRuntime.h

+1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class CGHLSLRuntime {
8181
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
8282
GENERATE_HLSL_INTRINSIC_FUNCTION(Saturate, saturate)
8383
GENERATE_HLSL_INTRINSIC_FUNCTION(Sign, sign)
84+
GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step)
8485
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
8586
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
8687
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)

clang/lib/Headers/hlsl/hlsl_intrinsics.h

+33
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,39 @@ float3 sqrt(float3);
17171717
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
17181718
float4 sqrt(float4);
17191719

1720+
//===----------------------------------------------------------------------===//
1721+
// step builtins
1722+
//===----------------------------------------------------------------------===//
1723+
1724+
/// \fn T step(T x, T y)
1725+
/// \brief Returns 1 if the x parameter is greater than or equal to the y
1726+
/// parameter; otherwise, 0. vector. \param x [in] The first floating-point
1727+
/// value to compare. \param y [in] The first floating-point value to compare.
1728+
///
1729+
/// Step is based on the following formula: (x >= y) ? 1 : 0
1730+
1731+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1732+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
1733+
half step(half, half);
1734+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1735+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
1736+
half2 step(half2, half2);
1737+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1738+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
1739+
half3 step(half3, half3);
1740+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1741+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
1742+
half4 step(half4, half4);
1743+
1744+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
1745+
float step(float, float);
1746+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
1747+
float2 step(float2, float2);
1748+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
1749+
float3 step(float3, float3);
1750+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
1751+
float4 step(float4, float4);
1752+
17201753
//===----------------------------------------------------------------------===//
17211754
// tan builtins
17221755
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -1747,6 +1747,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
17471747
SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
17481748
break;
17491749
}
1750+
case Builtin::BI__builtin_hlsl_step: {
1751+
if (SemaRef.checkArgCount(TheCall, 2))
1752+
return true;
1753+
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
1754+
return true;
1755+
1756+
ExprResult A = TheCall->getArg(0);
1757+
QualType ArgTyA = A.get()->getType();
1758+
// return type is the same as the input type
1759+
TheCall->setType(ArgTyA);
1760+
break;
1761+
}
17501762
// Note these are llvm builtins that we want to catch invalid intrinsic
17511763
// generation. Normal handling of these builitns will occur elsewhere.
17521764
case Builtin::BI__builtin_elementwise_bitreverse: {
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
3+
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
4+
// RUN: --check-prefixes=CHECK,NATIVE_HALF \
5+
// RUN: -DFNATTRS=noundef -DTARGET=dx
6+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
7+
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
8+
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
9+
// RUN: -DFNATTRS=noundef -DTARGET=dx
10+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
11+
// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
12+
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
13+
// RUN: --check-prefixes=CHECK,NATIVE_HALF \
14+
// RUN: -DFNATTRS="spir_func noundef" -DTARGET=spv
15+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
16+
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
17+
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
18+
// RUN: -DFNATTRS="spir_func noundef" -DTARGET=spv
19+
20+
// NATIVE_HALF: define [[FNATTRS]] half @
21+
// NATIVE_HALF: call half @llvm.[[TARGET]].step.f16(half
22+
// NO_HALF: call float @llvm.[[TARGET]].step.f32(float
23+
// NATIVE_HALF: ret half
24+
// NO_HALF: ret float
25+
half test_step_half(half p0, half p1)
26+
{
27+
return step(p0, p1);
28+
}
29+
// NATIVE_HALF: define [[FNATTRS]] <2 x half> @
30+
// NATIVE_HALF: call <2 x half> @llvm.[[TARGET]].step.v2f16(<2 x half>
31+
// NO_HALF: call <2 x float> @llvm.[[TARGET]].step.v2f32(<2 x float>
32+
// NATIVE_HALF: ret <2 x half> %hlsl.step
33+
// NO_HALF: ret <2 x float> %hlsl.step
34+
half2 test_step_half2(half2 p0, half2 p1)
35+
{
36+
return step(p0, p1);
37+
}
38+
// NATIVE_HALF: define [[FNATTRS]] <3 x half> @
39+
// NATIVE_HALF: call <3 x half> @llvm.[[TARGET]].step.v3f16(<3 x half>
40+
// NO_HALF: call <3 x float> @llvm.[[TARGET]].step.v3f32(<3 x float>
41+
// NATIVE_HALF: ret <3 x half> %hlsl.step
42+
// NO_HALF: ret <3 x float> %hlsl.step
43+
half3 test_step_half3(half3 p0, half3 p1)
44+
{
45+
return step(p0, p1);
46+
}
47+
// NATIVE_HALF: define [[FNATTRS]] <4 x half> @
48+
// NATIVE_HALF: call <4 x half> @llvm.[[TARGET]].step.v4f16(<4 x half>
49+
// NO_HALF: call <4 x float> @llvm.[[TARGET]].step.v4f32(<4 x float>
50+
// NATIVE_HALF: ret <4 x half> %hlsl.step
51+
// NO_HALF: ret <4 x float> %hlsl.step
52+
half4 test_step_half4(half4 p0, half4 p1)
53+
{
54+
return step(p0, p1);
55+
}
56+
57+
// CHECK: define [[FNATTRS]] float @
58+
// CHECK: call float @llvm.[[TARGET]].step.f32(float
59+
// CHECK: ret float
60+
float test_step_float(float p0, float p1)
61+
{
62+
return step(p0, p1);
63+
}
64+
// CHECK: define [[FNATTRS]] <2 x float> @
65+
// CHECK: %hlsl.step = call <2 x float> @llvm.[[TARGET]].step.v2f32(
66+
// CHECK: ret <2 x float> %hlsl.step
67+
float2 test_step_float2(float2 p0, float2 p1)
68+
{
69+
return step(p0, p1);
70+
}
71+
// CHECK: define [[FNATTRS]] <3 x float> @
72+
// CHECK: %hlsl.step = call <3 x float> @llvm.[[TARGET]].step.v3f32(
73+
// CHECK: ret <3 x float> %hlsl.step
74+
float3 test_step_float3(float3 p0, float3 p1)
75+
{
76+
return step(p0, p1);
77+
}
78+
// CHECK: define [[FNATTRS]] <4 x float> @
79+
// CHECK: %hlsl.step = call <4 x float> @llvm.[[TARGET]].step.v4f32(
80+
// CHECK: ret <4 x float> %hlsl.step
81+
float4 test_step_float4(float4 p0, float4 p1)
82+
{
83+
return step(p0, p1);
84+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -disable-llvm-passes -verify -verify-ignore-unexpected
2+
3+
void test_too_few_arg()
4+
{
5+
return __builtin_hlsl_step();
6+
// expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
7+
}
8+
9+
void test_too_many_arg(float2 p0)
10+
{
11+
return __builtin_hlsl_step(p0, p0, p0);
12+
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
13+
}
14+
15+
bool builtin_bool_to_float_type_promotion(bool p1)
16+
{
17+
return __builtin_hlsl_step(p1, p1);
18+
// expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
19+
}
20+
21+
bool builtin_step_int_to_float_promotion(int p1)
22+
{
23+
return __builtin_hlsl_step(p1, p1);
24+
// expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
25+
}
26+
27+
bool2 builtin_step_int2_to_float2_promotion(int2 p1)
28+
{
29+
return __builtin_hlsl_step(p1, p1);
30+
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
31+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
8787
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
8888
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
8989
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
90-
9190
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
9291
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
92+
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>]>;
9393
}

llvm/include/llvm/IR/IntrinsicsSPIRV.td

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ let TargetPrefix = "spv" in {
6767
def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
6868
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
6969
def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
70+
def int_spv_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [LLVMMatchType<0>, llvm_anyfloat_ty]>;
7071
def int_spv_fdot :
7172
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
7273
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

+25-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ static bool isIntrinsicExpansion(Function &F) {
5050
case Intrinsic::dx_sdot:
5151
case Intrinsic::dx_udot:
5252
case Intrinsic::dx_sign:
53+
case Intrinsic::dx_step:
5354
return true;
5455
}
5556
return false;
@@ -322,6 +323,28 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
322323
return Exp2Call;
323324
}
324325

326+
static Value *expandStepIntrinsic(CallInst *Orig) {
327+
328+
Value *X = Orig->getOperand(0);
329+
Value *Y = Orig->getOperand(1);
330+
Type *Ty = X->getType();
331+
IRBuilder<> Builder(Orig);
332+
333+
Constant *One = ConstantFP::get(Ty->getScalarType(), 1.0);
334+
Constant *Zero = ConstantFP::get(Ty->getScalarType(), 0.0);
335+
Value *Cond = Builder.CreateFCmpOLT(Y, X);
336+
337+
if (Ty != Ty->getScalarType()) {
338+
auto *XVec = dyn_cast<FixedVectorType>(Ty);
339+
One = ConstantVector::getSplat(
340+
ElementCount::getFixed(XVec->getNumElements()), One);
341+
Zero = ConstantVector::getSplat(
342+
ElementCount::getFixed(XVec->getNumElements()), Zero);
343+
}
344+
345+
return Builder.CreateSelect(Cond, Zero, One);
346+
}
347+
325348
static Intrinsic::ID getMaxForClamp(Type *ElemTy,
326349
Intrinsic::ID ClampIntrinsic) {
327350
if (ClampIntrinsic == Intrinsic::dx_uclamp)
@@ -433,8 +456,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
433456
case Intrinsic::dx_sign:
434457
Result = expandSignIntrinsic(Orig);
435458
break;
459+
case Intrinsic::dx_step:
460+
Result = expandStepIntrinsic(Orig);
436461
}
437-
438462
if (Result) {
439463
Orig->replaceAllUsesWith(Result);
440464
Orig->eraseFromParent();

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
263263
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
264264
MachineInstr &I) const;
265265

266+
bool selectStep(Register ResVReg, const SPIRVType *ResType,
267+
MachineInstr &I) const;
268+
266269
bool selectUnmergeValues(MachineInstr &I) const;
267270

268271
Register buildI32Constant(uint32_t Val, MachineInstr &I,
@@ -1710,6 +1713,25 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
17101713
return Result;
17111714
}
17121715

1716+
bool SPIRVInstructionSelector::selectStep(Register ResVReg,
1717+
const SPIRVType *ResType,
1718+
MachineInstr &I) const {
1719+
1720+
assert(I.getNumOperands() == 4);
1721+
assert(I.getOperand(2).isReg());
1722+
assert(I.getOperand(3).isReg());
1723+
MachineBasicBlock &BB = *I.getParent();
1724+
1725+
return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
1726+
.addDef(ResVReg)
1727+
.addUse(GR.getSPIRVTypeID(ResType))
1728+
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
1729+
.addImm(GL::Step)
1730+
.addUse(I.getOperand(2).getReg())
1731+
.addUse(I.getOperand(3).getReg())
1732+
.constrainAllUses(TII, TRI, RBI);
1733+
}
1734+
17131735
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
17141736
const SPIRVType *ResType,
17151737
MachineInstr &I) const {
@@ -2468,6 +2490,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
24682490
.addUse(GR.getSPIRVTypeID(ResType))
24692491
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
24702492
}
2493+
case Intrinsic::spv_step:
2494+
return selectStep(ResVReg, ResType, I);
24712495
default: {
24722496
std::string DiagMsg;
24732497
raw_string_ostream OS(DiagMsg);

0 commit comments

Comments
 (0)