-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[HLSL] Re-implement countbits with the correct return type #113189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-x86 @llvm/pr-subscribers-backend-directx Author: Sarah Spall (spall) ChangesRestricts hlsl countbits to always return a uint32. Full diff: https://github.com/llvm/llvm-project/pull/113189.diff 6 Files Affected:
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..2a612c3746076c 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -705,66 +705,90 @@ float4 cosh(float4);
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t countbits(int16_t);
+constexpr uint countbits(int16_t x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t2 countbits(int16_t2);
+constexpr uint2 countbits(int16_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t3 countbits(int16_t3);
+constexpr uint3 countbits(int16_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t4 countbits(int16_t4);
+constexpr uint4 countbits(int16_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t countbits(uint16_t);
+constexpr uint countbits(uint16_t x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t2 countbits(uint16_t2);
+constexpr uint2 countbits(uint16_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t3 countbits(uint16_t3);
+constexpr uint3 countbits(uint16_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t4 countbits(uint16_t4);
+constexpr uint4 countbits(uint16_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
#endif
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int countbits(int);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int2 countbits(int2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int3 countbits(int3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int4 countbits(int4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint countbits(uint);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint2 countbits(uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint3 countbits(uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint4 countbits(uint4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t countbits(int64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t2 countbits(int64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t3 countbits(int64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t4 countbits(int64_t4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t countbits(uint64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t2 countbits(uint64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t3 countbits(uint64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t4 countbits(uint64_t4);
+constexpr uint countbits(int x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(int2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int4 x) {
+ return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint4 x) {
+ return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(int64_t x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(int64_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int64_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int64_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint64_t x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint64_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint64_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint64_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
//===----------------------------------------------------------------------===//
// degrees builtins
diff --git a/clang/test/CodeGenHLSL/builtins/countbits.hlsl b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
index 8dfe977bfae626..aa9ef40d7a0dc8 100644
--- a/clang/test/CodeGenHLSL/builtins/countbits.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
@@ -4,26 +4,30 @@
#ifdef __HLSL_ENABLE_16_BIT
// CHECK-LABEL: test_countbits_ushort
-// CHECK: call i16 @llvm.ctpop.i16
-uint16_t test_countbits_ushort(uint16_t p0)
+// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
+// CHECK-NEXT: zext i16 [[A]] to i32
+uint test_countbits_ushort(uint16_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort2
-// CHECK: call <2 x i16> @llvm.ctpop.v2i16
-uint16_t2 test_countbits_ushort2(uint16_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
+// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
+uint2 test_countbits_ushort2(uint16_t2 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort3
-// CHECK: call <3 x i16> @llvm.ctpop.v3i16
-uint16_t3 test_countbits_ushort3(uint16_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i16> @llvm.ctpop.v3i16
+// CHECK-NEXT: zext <3 x i16> [[A]] to <3 x i32>
+uint3 test_countbits_ushort3(uint16_t3 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort4
-// CHECK: call <4 x i16> @llvm.ctpop.v4i16
-uint16_t4 test_countbits_ushort4(uint16_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i16> @llvm.ctpop.v4i16
+// CHECK-NEXT: zext <4 x i16> [[A]] to <4 x i32>
+uint4 test_countbits_ushort4(uint16_t4 p0)
{
return countbits(p0);
}
@@ -31,7 +35,7 @@ uint16_t4 test_countbits_ushort4(uint16_t4 p0)
// CHECK-LABEL: test_countbits_uint
// CHECK: call i32 @llvm.ctpop.i32
-int test_countbits_uint(uint p0)
+uint test_countbits_uint(uint p0)
{
return countbits(p0);
}
@@ -55,26 +59,30 @@ uint4 test_countbits_uint4(uint4 p0)
}
// CHECK-LABEL: test_countbits_long
-// CHECK: call i64 @llvm.ctpop.i64
-uint64_t test_countbits_long(uint64_t p0)
+// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
+// CHECK-NEXT: trunc i64 [[A]] to i32
+uint test_countbits_long(uint64_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long2
-// CHECK: call <2 x i64> @llvm.ctpop.v2i64
-uint64_t2 test_countbits_long2(uint64_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
+// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
+uint2 test_countbits_long2(uint64_t2 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long3
-// CHECK: call <3 x i64> @llvm.ctpop.v3i64
-uint64_t3 test_countbits_long3(uint64_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i64> @llvm.ctpop.v3i64
+// CHECK-NEXT: trunc <3 x i64> [[A]] to <3 x i32>
+uint3 test_countbits_long3(uint64_t3 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long4
-// CHECK: call <4 x i64> @llvm.ctpop.v4i64
-uint64_t4 test_countbits_long4(uint64_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i64> @llvm.ctpop.v4i64
+// CHECK-NEXT: trunc <4 x i64> [[A]] to <4 x i32>
+uint4 test_countbits_long4(uint64_t4 p0)
{
return countbits(p0);
}
diff --git a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
index 8d5f0abb2860f8..5704165e1a4505 100644
--- a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
@@ -1,6 +1,4 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
double test_int_builtin(double p0) {
@@ -9,13 +7,11 @@ double test_int_builtin(double p0) {
}
double2 test_int_builtin_2(double2 p0) {
- return __builtin_elementwise_popcount(p0);
- // expected-error@-1 {{1st argument must be a vector of integers
- // (was 'double2' (aka 'vector<double, 2>'))}}
+ return countbits(p0);
+ // expected-error@-1 {{call to 'countbits' is ambiguous}}
}
double test_int_builtin_3(float p0) {
- return __builtin_elementwise_popcount(p0);
- // expected-error@-1 {{1st argument must be a vector of integers
- // (was 'float')}}
+ return countbits(p0);
+ // expected-error@-1 {{call to 'countbits' is ambiguous}}
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..73636739de0659 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -553,11 +553,10 @@ def Rbits : DXILOp<30, unary> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def CBits : DXILOp<31, unary> {
+def CBits : DXILOp<31, unaryBits> {
let Doc = "Returns the number of 1 bits in the specified value.";
- let LLVMIntrinsic = int_ctpop;
let arguments = [OverloadTy];
- let result = OverloadTy;
+ let result = Int32Ty;
let overloads =
[Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 99df4850872078..a0b5df25760206 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -461,6 +461,67 @@ class OpLowerer {
});
}
+ [[nodiscard]] bool lowerCtpopToCBits(Function &F) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+ Type *Int32Ty = IRB.getInt32Ty();
+
+ return replaceFunction(F, [&](CallInst *CI) -> Error {
+ IRB.SetInsertPoint(CI);
+ SmallVector<Value *> Args;
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ Type *RetTy = Int32Ty;
+ Type *FRT = F.getReturnType();
+ if (FRT->isVectorTy()) {
+ VectorType *VT = cast<VectorType>(FRT);
+ RetTy = VectorType::get(RetTy, VT);
+ }
+
+ Expected<CallInst *> OpCall =
+ OpBuilder.tryCreateOp(dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+
+ // If the result type is 32 bits we can do a direct replacement.
+ if (FRT->isIntOrIntVectorTy(32)) {
+ CI->replaceAllUsesWith(*OpCall);
+ CI->eraseFromParent();
+ return Error::success();
+ }
+
+ unsigned CastOp;
+ if (FRT->isIntOrIntVectorTy(16))
+ CastOp = Instruction::ZExt;
+ else // must be 64 bits
+ CastOp = Instruction::Trunc;
+
+ // It is correct to replace the ctpop with the dxil op and
+ // remove an existing cast iff the cast is the only usage of
+ // the ctpop
+ // can use hasOneUse instead of hasOneUser, because the user
+ // we care about should have one operand
+ if (CI->hasOneUse()) {
+ User *U = CI->user_back();
+ Instruction *I;
+ if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+ I->getOpcode() == CastOp && I->getType() == RetTy) {
+ I->replaceAllUsesWith(*OpCall);
+ I->eraseFromParent();
+ CI->eraseFromParent();
+ return Error::success();
+ }
+ }
+
+ // It is always correct to replace a ctpop with the dxil op and
+ // a cast
+ Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(),
+ "ctpop.cast");
+ CI->replaceAllUsesWith(Cast);
+ CI->eraseFromParent();
+ return Error::success();
+ });
+ }
+
bool lowerIntrinsics() {
bool Updated = false;
bool HasErrors = false;
@@ -489,6 +550,9 @@ class OpLowerer {
case Intrinsic::dx_typedBufferStore:
HasErrors |= lowerTypedBufferStore(F);
break;
+ case Intrinsic::ctpop:
+ HasErrors |= lowerCtpopToCBits(F);
+ break;
}
Updated = true;
}
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index c6bc2b6790948e..91f6f560903f01 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -4,35 +4,58 @@
define noundef i16 @test_countbits_short(i16 noundef %a) {
entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 31, i16 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i16
+; CHECK-NEXT ret i16 [[B]]
%elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
ret i16 %elt.ctpop
}
+define noundef i32 @test_countbits_short2(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+ %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
+ %elt.zext = zext i16 %elt.ctpop to i32
+ ret i32 %elt.zext
+}
+
define noundef i32 @test_countbits_int(i32 noundef %a) {
entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 31, i32 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
%elt.ctpop = call i32 @llvm.ctpop.i32(i32 %a)
ret i32 %elt.ctpop
}
define noundef i64 @test_countbits_long(i64 noundef %a) {
entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 31, i64 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = zext i32 [[A]] to i64
+; CHECK-NEXT ret i64 [[B]]
%elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
ret i64 %elt.ctpop
}
+define noundef i32 @test_countbits_long2(i64 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+ %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
+ %elt.trunc = trunc i64 %elt.ctpop to i32
+ ret i32 %elt.trunc
+}
+
define noundef <4 x i32> @countbits_vec4_i32(<4 x i32> noundef %a) {
entry:
; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
- ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee0]])
+ ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee0]])
; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
- ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee1]])
+ ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee1]])
; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
- ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee2]])
+ ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee2]])
; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
- ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee3]])
+ ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee3]])
; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
@@ -44,4 +67,4 @@ entry:
declare i16 @llvm.ctpop.i16(i16)
declare i32 @llvm.ctpop.i32(i32)
declare i64 @llvm.ctpop.i64(i64)
-declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
+declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
\ No newline at end of file
|
@llvm/pr-subscribers-hlsl Author: Sarah Spall (spall) ChangesRestricts hlsl countbits to always return a uint32. Full diff: https://github.com/llvm/llvm-project/pull/113189.diff 6 Files Affected:
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..2a612c3746076c 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -705,66 +705,90 @@ float4 cosh(float4);
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t countbits(int16_t);
+constexpr uint countbits(int16_t x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t2 countbits(int16_t2);
+constexpr uint2 countbits(int16_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t3 countbits(int16_t3);
+constexpr uint3 countbits(int16_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t4 countbits(int16_t4);
+constexpr uint4 countbits(int16_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t countbits(uint16_t);
+constexpr uint countbits(uint16_t x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t2 countbits(uint16_t2);
+constexpr uint2 countbits(uint16_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t3 countbits(uint16_t3);
+constexpr uint3 countbits(uint16_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t4 countbits(uint16_t4);
+constexpr uint4 countbits(uint16_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
#endif
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int countbits(int);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int2 countbits(int2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int3 countbits(int3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int4 countbits(int4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint countbits(uint);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint2 countbits(uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint3 countbits(uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint4 countbits(uint4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t countbits(int64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t2 countbits(int64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t3 countbits(int64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t4 countbits(int64_t4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t countbits(uint64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t2 countbits(uint64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t3 countbits(uint64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t4 countbits(uint64_t4);
+constexpr uint countbits(int x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(int2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int4 x) {
+ return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint4 x) {
+ return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(int64_t x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(int64_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int64_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int64_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint64_t x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint64_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint64_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint64_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
//===----------------------------------------------------------------------===//
// degrees builtins
diff --git a/clang/test/CodeGenHLSL/builtins/countbits.hlsl b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
index 8dfe977bfae626..aa9ef40d7a0dc8 100644
--- a/clang/test/CodeGenHLSL/builtins/countbits.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
@@ -4,26 +4,30 @@
#ifdef __HLSL_ENABLE_16_BIT
// CHECK-LABEL: test_countbits_ushort
-// CHECK: call i16 @llvm.ctpop.i16
-uint16_t test_countbits_ushort(uint16_t p0)
+// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
+// CHECK-NEXT: zext i16 [[A]] to i32
+uint test_countbits_ushort(uint16_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort2
-// CHECK: call <2 x i16> @llvm.ctpop.v2i16
-uint16_t2 test_countbits_ushort2(uint16_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
+// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
+uint2 test_countbits_ushort2(uint16_t2 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort3
-// CHECK: call <3 x i16> @llvm.ctpop.v3i16
-uint16_t3 test_countbits_ushort3(uint16_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i16> @llvm.ctpop.v3i16
+// CHECK-NEXT: zext <3 x i16> [[A]] to <3 x i32>
+uint3 test_countbits_ushort3(uint16_t3 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort4
-// CHECK: call <4 x i16> @llvm.ctpop.v4i16
-uint16_t4 test_countbits_ushort4(uint16_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i16> @llvm.ctpop.v4i16
+// CHECK-NEXT: zext <4 x i16> [[A]] to <4 x i32>
+uint4 test_countbits_ushort4(uint16_t4 p0)
{
return countbits(p0);
}
@@ -31,7 +35,7 @@ uint16_t4 test_countbits_ushort4(uint16_t4 p0)
// CHECK-LABEL: test_countbits_uint
// CHECK: call i32 @llvm.ctpop.i32
-int test_countbits_uint(uint p0)
+uint test_countbits_uint(uint p0)
{
return countbits(p0);
}
@@ -55,26 +59,30 @@ uint4 test_countbits_uint4(uint4 p0)
}
// CHECK-LABEL: test_countbits_long
-// CHECK: call i64 @llvm.ctpop.i64
-uint64_t test_countbits_long(uint64_t p0)
+// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
+// CHECK-NEXT: trunc i64 [[A]] to i32
+uint test_countbits_long(uint64_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long2
-// CHECK: call <2 x i64> @llvm.ctpop.v2i64
-uint64_t2 test_countbits_long2(uint64_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
+// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
+uint2 test_countbits_long2(uint64_t2 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long3
-// CHECK: call <3 x i64> @llvm.ctpop.v3i64
-uint64_t3 test_countbits_long3(uint64_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i64> @llvm.ctpop.v3i64
+// CHECK-NEXT: trunc <3 x i64> [[A]] to <3 x i32>
+uint3 test_countbits_long3(uint64_t3 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long4
-// CHECK: call <4 x i64> @llvm.ctpop.v4i64
-uint64_t4 test_countbits_long4(uint64_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i64> @llvm.ctpop.v4i64
+// CHECK-NEXT: trunc <4 x i64> [[A]] to <4 x i32>
+uint4 test_countbits_long4(uint64_t4 p0)
{
return countbits(p0);
}
diff --git a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
index 8d5f0abb2860f8..5704165e1a4505 100644
--- a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
@@ -1,6 +1,4 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
double test_int_builtin(double p0) {
@@ -9,13 +7,11 @@ double test_int_builtin(double p0) {
}
double2 test_int_builtin_2(double2 p0) {
- return __builtin_elementwise_popcount(p0);
- // expected-error@-1 {{1st argument must be a vector of integers
- // (was 'double2' (aka 'vector<double, 2>'))}}
+ return countbits(p0);
+ // expected-error@-1 {{call to 'countbits' is ambiguous}}
}
double test_int_builtin_3(float p0) {
- return __builtin_elementwise_popcount(p0);
- // expected-error@-1 {{1st argument must be a vector of integers
- // (was 'float')}}
+ return countbits(p0);
+ // expected-error@-1 {{call to 'countbits' is ambiguous}}
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..73636739de0659 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -553,11 +553,10 @@ def Rbits : DXILOp<30, unary> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def CBits : DXILOp<31, unary> {
+def CBits : DXILOp<31, unaryBits> {
let Doc = "Returns the number of 1 bits in the specified value.";
- let LLVMIntrinsic = int_ctpop;
let arguments = [OverloadTy];
- let result = OverloadTy;
+ let result = Int32Ty;
let overloads =
[Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 99df4850872078..a0b5df25760206 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -461,6 +461,67 @@ class OpLowerer {
});
}
+ [[nodiscard]] bool lowerCtpopToCBits(Function &F) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+ Type *Int32Ty = IRB.getInt32Ty();
+
+ return replaceFunction(F, [&](CallInst *CI) -> Error {
+ IRB.SetInsertPoint(CI);
+ SmallVector<Value *> Args;
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ Type *RetTy = Int32Ty;
+ Type *FRT = F.getReturnType();
+ if (FRT->isVectorTy()) {
+ VectorType *VT = cast<VectorType>(FRT);
+ RetTy = VectorType::get(RetTy, VT);
+ }
+
+ Expected<CallInst *> OpCall =
+ OpBuilder.tryCreateOp(dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+
+ // If the result type is 32 bits we can do a direct replacement.
+ if (FRT->isIntOrIntVectorTy(32)) {
+ CI->replaceAllUsesWith(*OpCall);
+ CI->eraseFromParent();
+ return Error::success();
+ }
+
+ unsigned CastOp;
+ if (FRT->isIntOrIntVectorTy(16))
+ CastOp = Instruction::ZExt;
+ else // must be 64 bits
+ CastOp = Instruction::Trunc;
+
+ // It is correct to replace the ctpop with the dxil op and
+ // remove an existing cast iff the cast is the only usage of
+ // the ctpop
+ // can use hasOneUse instead of hasOneUser, because the user
+ // we care about should have one operand
+ if (CI->hasOneUse()) {
+ User *U = CI->user_back();
+ Instruction *I;
+ if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+ I->getOpcode() == CastOp && I->getType() == RetTy) {
+ I->replaceAllUsesWith(*OpCall);
+ I->eraseFromParent();
+ CI->eraseFromParent();
+ return Error::success();
+ }
+ }
+
+ // It is always correct to replace a ctpop with the dxil op and
+ // a cast
+ Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(),
+ "ctpop.cast");
+ CI->replaceAllUsesWith(Cast);
+ CI->eraseFromParent();
+ return Error::success();
+ });
+ }
+
bool lowerIntrinsics() {
bool Updated = false;
bool HasErrors = false;
@@ -489,6 +550,9 @@ class OpLowerer {
case Intrinsic::dx_typedBufferStore:
HasErrors |= lowerTypedBufferStore(F);
break;
+ case Intrinsic::ctpop:
+ HasErrors |= lowerCtpopToCBits(F);
+ break;
}
Updated = true;
}
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index c6bc2b6790948e..91f6f560903f01 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -4,35 +4,58 @@
define noundef i16 @test_countbits_short(i16 noundef %a) {
entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 31, i16 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i16
+; CHECK-NEXT ret i16 [[B]]
%elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
ret i16 %elt.ctpop
}
+define noundef i32 @test_countbits_short2(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+ %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
+ %elt.zext = zext i16 %elt.ctpop to i32
+ ret i32 %elt.zext
+}
+
define noundef i32 @test_countbits_int(i32 noundef %a) {
entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 31, i32 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
%elt.ctpop = call i32 @llvm.ctpop.i32(i32 %a)
ret i32 %elt.ctpop
}
define noundef i64 @test_countbits_long(i64 noundef %a) {
entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 31, i64 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = zext i32 [[A]] to i64
+; CHECK-NEXT ret i64 [[B]]
%elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
ret i64 %elt.ctpop
}
+define noundef i32 @test_countbits_long2(i64 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+ %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
+ %elt.trunc = trunc i64 %elt.ctpop to i32
+ ret i32 %elt.trunc
+}
+
define noundef <4 x i32> @countbits_vec4_i32(<4 x i32> noundef %a) {
entry:
; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
- ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee0]])
+ ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee0]])
; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
- ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee1]])
+ ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee1]])
; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
- ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee2]])
+ ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee2]])
; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
- ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee3]])
+ ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee3]])
; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
@@ -44,4 +67,4 @@ entry:
declare i16 @llvm.ctpop.i16(i16)
declare i32 @llvm.ctpop.i32(i32)
declare i64 @llvm.ctpop.i64(i64)
-declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
+declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
\ No newline at end of file
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
@@ -705,66 +705,74 @@ float4 cosh(float4); | |||
|
|||
#ifdef __HLSL_ENABLE_16_BIT | |||
_HLSL_AVAILABILITY(shadermodel, 6.2) | |||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) | |||
int16_t countbits(int16_t); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I missed the earlier PR where this was added, but I don't know how signed integers work here. We only support unsigned in HLSL. Is that a change for clang?
Missing tests as well, if so.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I asked Justin about this he pointed to this link:
https://github.com/microsoft/DirectXShaderCompiler/blob/main/utils/hct/gen_intrin_main.txt#L114
and suggested it meant both signed and unsigned were supported.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, of course our documentation is wrong. countbits of signed is kind of illogical, but definitely should add tests regardless.
|
||
unsigned CastOp; | ||
if (FRT->isIntOrIntVectorTy(16)) | ||
CastOp = Instruction::ZExt; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Related to the signed vs unsigned: this is where things I think get confusing. int16_t would almost always be sign extended, not zero extended. But doing a sign extension here doesn't make sense as you'll count 16 extra bits for any negative int16_t.
I think this is why std::bitset basically ignores the type. It's a forcing function to the author to say "yes, I'm purposely counting a signed value now go away"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry if I'm being confusing - I am not saying you have the wrong code here, I'm just bringing it up that this is where my thought drifted towards the topic because of this zero extension (which I think is correct)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The extension should only be on the return value which is always unsigned. Is this code wrong with that in mind?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see. This is overriding the cast op below. Is a zero/sign extension needed? I notice the return type is "int16_t".
Although I guess that's another question - should the return types all be 32 bits? i.e., "int_t"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Disregard. I see you actually updated that correctly! (the whole point of this PR haha)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of minor style nitpicks, but otherwise this LGTM!
Restricts hlsl countbits to always return a uint32. Implements a lowering from llvm.ctpop which has an overloaded return type to dxil cbits op which always returns uint32. Closes llvm#112779
Restricts hlsl countbits to always return a uint32.
Implements a lowering from llvm.ctpop which has an overloaded return type to dxil cbits op which always returns uint32.
Closes #112779