Skip to content

Commit 75e7ba8

Browse files
authored
[HLSL] Re-implement countbits with the correct return type (#113189)
Restricts HLSL `countbits` to always return a uint32. Implements a lowering from `llvm.ctpop`, which has an overloaded return type, to the DXIL `CountBits` op, which always returns uint32. Closes #112779
1 parent e268398 commit 75e7ba8

File tree

6 files changed

+234
-87
lines changed

6 files changed

+234
-87
lines changed

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 73 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -723,66 +723,88 @@ float4 cosh(float4);
723723

724724
#ifdef __HLSL_ENABLE_16_BIT
725725
_HLSL_AVAILABILITY(shadermodel, 6.2)
726-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
727-
int16_t countbits(int16_t);
726+
const inline uint countbits(int16_t x) {
727+
return __builtin_elementwise_popcount(x);
728+
}
728729
_HLSL_AVAILABILITY(shadermodel, 6.2)
729-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
730-
int16_t2 countbits(int16_t2);
730+
const inline uint2 countbits(int16_t2 x) {
731+
return __builtin_elementwise_popcount(x);
732+
}
731733
_HLSL_AVAILABILITY(shadermodel, 6.2)
732-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
733-
int16_t3 countbits(int16_t3);
734+
const inline uint3 countbits(int16_t3 x) {
735+
return __builtin_elementwise_popcount(x);
736+
}
734737
_HLSL_AVAILABILITY(shadermodel, 6.2)
735-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
736-
int16_t4 countbits(int16_t4);
738+
const inline uint4 countbits(int16_t4 x) {
739+
return __builtin_elementwise_popcount(x);
740+
}
737741
_HLSL_AVAILABILITY(shadermodel, 6.2)
738-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
739-
uint16_t countbits(uint16_t);
742+
const inline uint countbits(uint16_t x) {
743+
return __builtin_elementwise_popcount(x);
744+
}
740745
_HLSL_AVAILABILITY(shadermodel, 6.2)
741-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
742-
uint16_t2 countbits(uint16_t2);
746+
const inline uint2 countbits(uint16_t2 x) {
747+
return __builtin_elementwise_popcount(x);
748+
}
743749
_HLSL_AVAILABILITY(shadermodel, 6.2)
744-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
745-
uint16_t3 countbits(uint16_t3);
750+
const inline uint3 countbits(uint16_t3 x) {
751+
return __builtin_elementwise_popcount(x);
752+
}
746753
_HLSL_AVAILABILITY(shadermodel, 6.2)
747-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
748-
uint16_t4 countbits(uint16_t4);
754+
const inline uint4 countbits(uint16_t4 x) {
755+
return __builtin_elementwise_popcount(x);
756+
}
749757
#endif
750758

751-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
752-
int countbits(int);
753-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
754-
int2 countbits(int2);
755-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
756-
int3 countbits(int3);
757-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
758-
int4 countbits(int4);
759-
760-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
761-
uint countbits(uint);
762-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
763-
uint2 countbits(uint2);
764-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
765-
uint3 countbits(uint3);
766-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
767-
uint4 countbits(uint4);
768-
769-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
770-
int64_t countbits(int64_t);
771-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
772-
int64_t2 countbits(int64_t2);
773-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
774-
int64_t3 countbits(int64_t3);
775-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
776-
int64_t4 countbits(int64_t4);
777-
778-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
779-
uint64_t countbits(uint64_t);
780-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
781-
uint64_t2 countbits(uint64_t2);
782-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
783-
uint64_t3 countbits(uint64_t3);
784-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
785-
uint64_t4 countbits(uint64_t4);
759+
const inline uint countbits(int x) { return __builtin_elementwise_popcount(x); }
760+
const inline uint2 countbits(int2 x) {
761+
return __builtin_elementwise_popcount(x);
762+
}
763+
const inline uint3 countbits(int3 x) {
764+
return __builtin_elementwise_popcount(x);
765+
}
766+
const inline uint4 countbits(int4 x) {
767+
return __builtin_elementwise_popcount(x);
768+
}
769+
770+
const inline uint countbits(uint x) {
771+
return __builtin_elementwise_popcount(x);
772+
}
773+
const inline uint2 countbits(uint2 x) {
774+
return __builtin_elementwise_popcount(x);
775+
}
776+
const inline uint3 countbits(uint3 x) {
777+
return __builtin_elementwise_popcount(x);
778+
}
779+
const inline uint4 countbits(uint4 x) {
780+
return __builtin_elementwise_popcount(x);
781+
}
782+
783+
const inline uint countbits(int64_t x) {
784+
return __builtin_elementwise_popcount(x);
785+
}
786+
const inline uint2 countbits(int64_t2 x) {
787+
return __builtin_elementwise_popcount(x);
788+
}
789+
const inline uint3 countbits(int64_t3 x) {
790+
return __builtin_elementwise_popcount(x);
791+
}
792+
const inline uint4 countbits(int64_t4 x) {
793+
return __builtin_elementwise_popcount(x);
794+
}
795+
796+
const inline uint countbits(uint64_t x) {
797+
return __builtin_elementwise_popcount(x);
798+
}
799+
const inline uint2 countbits(uint64_t2 x) {
800+
return __builtin_elementwise_popcount(x);
801+
}
802+
const inline uint3 countbits(uint64_t3 x) {
803+
return __builtin_elementwise_popcount(x);
804+
}
805+
const inline uint4 countbits(uint64_t4 x) {
806+
return __builtin_elementwise_popcount(x);
807+
}
786808

787809
//===----------------------------------------------------------------------===//
788810
// degrees builtins

clang/test/CodeGenHLSL/builtins/countbits.hlsl

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,51 @@
44

55
#ifdef __HLSL_ENABLE_16_BIT
66
// CHECK-LABEL: test_countbits_ushort
7-
// CHECK: call i16 @llvm.ctpop.i16
8-
uint16_t test_countbits_ushort(uint16_t p0)
7+
// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
8+
// CHECK-NEXT: zext i16 [[A]] to i32
9+
uint test_countbits_ushort(uint16_t p0)
10+
{
11+
return countbits(p0);
12+
}
13+
// CHECK-LABEL: test_countbits_short
14+
// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
15+
// CHECK-NEXT: sext i16 [[A]] to i32
16+
uint test_countbits_short(int16_t p0)
917
{
1018
return countbits(p0);
1119
}
1220
// CHECK-LABEL: test_countbits_ushort2
13-
// CHECK: call <2 x i16> @llvm.ctpop.v2i16
14-
uint16_t2 test_countbits_ushort2(uint16_t2 p0)
21+
// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
22+
// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
23+
uint2 test_countbits_ushort2(uint16_t2 p0)
1524
{
1625
return countbits(p0);
1726
}
1827
// CHECK-LABEL: test_countbits_ushort3
19-
// CHECK: call <3 x i16> @llvm.ctpop.v3i16
20-
uint16_t3 test_countbits_ushort3(uint16_t3 p0)
28+
// CHECK: [[A:%.*]] = call <3 x i16> @llvm.ctpop.v3i16
29+
// CHECK-NEXT: zext <3 x i16> [[A]] to <3 x i32>
30+
uint3 test_countbits_ushort3(uint16_t3 p0)
2131
{
2232
return countbits(p0);
2333
}
2434
// CHECK-LABEL: test_countbits_ushort4
25-
// CHECK: call <4 x i16> @llvm.ctpop.v4i16
26-
uint16_t4 test_countbits_ushort4(uint16_t4 p0)
35+
// CHECK: [[A:%.*]] = call <4 x i16> @llvm.ctpop.v4i16
36+
// CHECK-NEXT: zext <4 x i16> [[A]] to <4 x i32>
37+
uint4 test_countbits_ushort4(uint16_t4 p0)
2738
{
2839
return countbits(p0);
2940
}
3041
#endif
3142

3243
// CHECK-LABEL: test_countbits_uint
3344
// CHECK: call i32 @llvm.ctpop.i32
34-
int test_countbits_uint(uint p0)
45+
uint test_countbits_uint(uint p0)
46+
{
47+
return countbits(p0);
48+
}
49+
// CHECK-LABEL: test_countbits_int
50+
// CHECK: call i32 @llvm.ctpop.i32
51+
uint test_countbits_int(int p0)
3552
{
3653
return countbits(p0);
3754
}
@@ -55,26 +72,37 @@ uint4 test_countbits_uint4(uint4 p0)
5572
}
5673

5774
// CHECK-LABEL: test_countbits_long
58-
// CHECK: call i64 @llvm.ctpop.i64
59-
uint64_t test_countbits_long(uint64_t p0)
75+
// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
76+
// CHECK-NEXT: trunc i64 [[A]] to i32
77+
uint test_countbits_long(uint64_t p0)
78+
{
79+
return countbits(p0);
80+
}
81+
// CHECK-LABEL: test_countbits_slong
82+
// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
83+
// CHECK-NEXT: trunc i64 [[A]] to i32
84+
uint test_countbits_slong(int64_t p0)
6085
{
6186
return countbits(p0);
6287
}
6388
// CHECK-LABEL: test_countbits_long2
64-
// CHECK: call <2 x i64> @llvm.ctpop.v2i64
65-
uint64_t2 test_countbits_long2(uint64_t2 p0)
89+
// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
90+
// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
91+
uint2 test_countbits_long2(uint64_t2 p0)
6692
{
6793
return countbits(p0);
6894
}
6995
// CHECK-LABEL: test_countbits_long3
70-
// CHECK: call <3 x i64> @llvm.ctpop.v3i64
71-
uint64_t3 test_countbits_long3(uint64_t3 p0)
96+
// CHECK: [[A:%.*]] = call <3 x i64> @llvm.ctpop.v3i64
97+
// CHECK-NEXT: trunc <3 x i64> [[A]] to <3 x i32>
98+
uint3 test_countbits_long3(uint64_t3 p0)
7299
{
73100
return countbits(p0);
74101
}
75102
// CHECK-LABEL: test_countbits_long4
76-
// CHECK: call <4 x i64> @llvm.ctpop.v4i64
77-
uint64_t4 test_countbits_long4(uint64_t4 p0)
103+
// CHECK: [[A:%.*]] = call <4 x i64> @llvm.ctpop.v4i64
104+
// CHECK-NEXT: trunc <4 x i64> [[A]] to <4 x i32>
105+
uint4 test_countbits_long4(uint64_t4 p0)
78106
{
79107
return countbits(p0);
80108
}
Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
// RUN: %clang_cc1 -finclude-default-header
2-
// -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only
3-
// -disable-llvm-passes -verify -verify-ignore-unexpected
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
42

53

64
double test_int_builtin(double p0) {
@@ -9,13 +7,11 @@ double test_int_builtin(double p0) {
97
}
108

119
double2 test_int_builtin_2(double2 p0) {
12-
return __builtin_elementwise_popcount(p0);
13-
// expected-error@-1 {{1st argument must be a vector of integers
14-
// (was 'double2' (aka 'vector<double, 2>'))}}
10+
return countbits(p0);
11+
// expected-error@-1 {{call to 'countbits' is ambiguous}}
1512
}
1613

1714
double test_int_builtin_3(float p0) {
18-
return __builtin_elementwise_popcount(p0);
19-
// expected-error@-1 {{1st argument must be a vector of integers
20-
// (was 'float')}}
15+
return countbits(p0);
16+
// expected-error@-1 {{call to 'countbits' is ambiguous}}
2117
}

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -554,11 +554,10 @@ def Rbits : DXILOp<30, unary> {
554554
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
555555
}
556556

557-
def CBits : DXILOp<31, unary> {
557+
def CountBits : DXILOp<31, unaryBits> {
558558
let Doc = "Returns the number of 1 bits in the specified value.";
559-
let LLVMIntrinsic = int_ctpop;
560559
let arguments = [OverloadTy];
561-
let result = OverloadTy;
560+
let result = Int32Ty;
562561
let overloads =
563562
[Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
564563
let stages = [Stages<DXIL1_0, [all_stages]>];

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,73 @@ class OpLowerer {
505505
});
506506
}
507507

508+
/// Lowers calls to the llvm.ctpop intrinsic \p F to the DXIL CountBits op.
///
/// llvm.ctpop returns the same type as its operand (i16/i32/i64 or a vector
/// of those), while the CountBits op is created here with an i32 result (or
/// a vector of i32 with the same shape). Each call handed to the callback is
/// rewritten as follows:
///   * 32-bit ctpop: replaced directly by the CountBits call.
///   * 16-bit ctpop: users that zext/sext the result to i32 are replaced by
///     the CountBits call itself.
///   * 64-bit ctpop: users that trunc the result to i32 are replaced by the
///     CountBits call itself.
///   * Any other user receives a zext/trunc of the CountBits result back to
///     the original ctpop type.
///
/// \returns the result of replaceFunction; the caller ORs it into its
/// HasErrors flag.
[[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
  IRBuilder<> &IRB = OpBuilder.getIRB();
  Type *Int32Ty = IRB.getInt32Ty();

  return replaceFunction(F, [&](CallInst *CI) -> Error {
    IRB.SetInsertPoint(CI);
    SmallVector<Value *> Args;
    Args.append(CI->arg_begin(), CI->arg_end());

    // The DXIL op result is i32, widened to a vector of i32 when ctpop
    // itself returns a vector.
    Type *FRT = F.getReturnType();
    Type *RetTy = Int32Ty;
    if (const auto *VT = dyn_cast<VectorType>(FRT))
      RetTy = VectorType::get(RetTy, VT);

    Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
        dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
    if (Error E = OpCall.takeError())
      return E;

    // If the result type is 32 bits we can do a direct replacement.
    if (FRT->isIntOrIntVectorTy(32)) {
      CI->replaceAllUsesWith(*OpCall);
      CI->eraseFromParent();
      return Error::success();
    }

    const bool Is16Bit = FRT->isIntOrIntVectorTy(16);
    assert((Is16Bit || FRT->isIntOrIntVectorTy(64)) &&
           "Currently only lowering 16, 32, or 64 bit ctpop to CountBits "
           "is supported.");

    // A user is a redundant cast when it converts the ctpop result to the
    // i32-based type the DXIL op already produces: zext/sext for a 16-bit
    // ctpop, trunc for a 64-bit one. A population count is non-negative and
    // at most 64, so all three conversions preserve the value.
    auto IsRedundantCast = [&](const Instruction &I) {
      if (I.getType() != RetTy)
        return false;
      const unsigned Opcode = I.getOpcode();
      if (Is16Bit)
        return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
      return Opcode == Instruction::Trunc;
    };

    // It is correct to replace the ctpop with the DXIL op and remove all
    // casts to i32. Early-inc iteration is required because the matching
    // cast users are erased while walking the use list.
    bool NeedsCast = false;
    for (User *U : make_early_inc_range(CI->users())) {
      auto *I = dyn_cast<Instruction>(U);
      if (I && IsRedundantCast(*I)) {
        I->replaceAllUsesWith(*OpCall);
        I->eraseFromParent();
      } else {
        NeedsCast = true;
      }
    }

    // Any remaining user still expects the original ctpop type, so give it
    // the DXIL result converted back from i32 (trunc for 16-bit, zext for
    // 64-bit).
    if (NeedsCast) {
      Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, FRT, "ctpop.cast");
      CI->replaceAllUsesWith(Cast);
    }

    CI->eraseFromParent();
    return Error::success();
  });
}
574+
508575
bool lowerIntrinsics() {
509576
bool Updated = false;
510577
bool HasErrors = false;
@@ -543,6 +610,9 @@ class OpLowerer {
543610
return replaceSplitDoubleCallUsages(CI, Op);
544611
});
545612
break;
613+
case Intrinsic::ctpop:
614+
HasErrors |= lowerCtpopToCountBits(F);
615+
break;
546616
}
547617
Updated = true;
548618
}

0 commit comments

Comments
 (0)