Skip to content

[HLSL] Re-implement countbits with the correct return type #113189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 29, 2024

Conversation

spall
Copy link
Contributor

@spall spall commented Oct 21, 2024

Restricts hlsl countbits to always return a uint32.
Implements a lowering from llvm.ctpop which has an overloaded return type to dxil cbits op which always returns uint32.
Closes #112779

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:headers Headers provided by Clang, e.g. for intrinsics backend:DirectX HLSL HLSL Language Support labels Oct 21, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 21, 2024

@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-backend-directx

Author: Sarah Spall (spall)

Changes

Restricts hlsl countbits to always return a uint32.
Implements a lowering from llvm.ctpop which has an overloaded return type to dxil cbits op which always returns uint32.
Closes #112779


Full diff: https://github.com/llvm/llvm-project/pull/113189.diff

6 Files Affected:

  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+75-51)
  • (modified) clang/test/CodeGenHLSL/builtins/countbits.hlsl (+25-17)
  • (modified) clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl (+5-9)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+2-3)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+64)
  • (modified) llvm/test/CodeGen/DirectX/countbits.ll (+31-8)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..2a612c3746076c 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -705,66 +705,90 @@ float4 cosh(float4);
 
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t countbits(int16_t);
+constexpr uint countbits(int16_t x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t2 countbits(int16_t2);
+constexpr uint2 countbits(int16_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t3 countbits(int16_t3);
+constexpr uint3 countbits(int16_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t4 countbits(int16_t4);
+constexpr uint4 countbits(int16_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t countbits(uint16_t);
+constexpr uint countbits(uint16_t x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t2 countbits(uint16_t2);
+constexpr uint2 countbits(uint16_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t3 countbits(uint16_t3);
+constexpr uint3 countbits(uint16_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t4 countbits(uint16_t4);
+constexpr uint4 countbits(uint16_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 #endif
 
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int countbits(int);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int2 countbits(int2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int3 countbits(int3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int4 countbits(int4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint countbits(uint);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint2 countbits(uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint3 countbits(uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint4 countbits(uint4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t countbits(int64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t2 countbits(int64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t3 countbits(int64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t4 countbits(int64_t4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t countbits(uint64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t2 countbits(uint64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t3 countbits(uint64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t4 countbits(uint64_t4);
+constexpr uint countbits(int x) {
+  return __builtin_elementwise_popcount(x);
+}  
+constexpr uint2 countbits(int2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int4 x) {
+  return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint4 x) {
+  return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(int64_t x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(int64_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int64_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int64_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint64_t x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint64_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint64_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint64_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 
 //===----------------------------------------------------------------------===//
 // degrees builtins
diff --git a/clang/test/CodeGenHLSL/builtins/countbits.hlsl b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
index 8dfe977bfae626..aa9ef40d7a0dc8 100644
--- a/clang/test/CodeGenHLSL/builtins/countbits.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
@@ -4,26 +4,30 @@
 
 #ifdef __HLSL_ENABLE_16_BIT
 // CHECK-LABEL: test_countbits_ushort
-// CHECK: call i16 @llvm.ctpop.i16
-uint16_t test_countbits_ushort(uint16_t p0)
+// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
+// CHECK-NEXT: zext i16 [[A]] to i32
+uint test_countbits_ushort(uint16_t p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_ushort2
-// CHECK: call <2 x i16> @llvm.ctpop.v2i16
-uint16_t2 test_countbits_ushort2(uint16_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
+// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
+uint2 test_countbits_ushort2(uint16_t2 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_ushort3
-// CHECK: call <3 x i16> @llvm.ctpop.v3i16
-uint16_t3 test_countbits_ushort3(uint16_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i16> @llvm.ctpop.v3i16
+// CHECK-NEXT: zext <3 x i16> [[A]] to <3 x i32>
+uint3 test_countbits_ushort3(uint16_t3 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_ushort4
-// CHECK: call <4 x i16> @llvm.ctpop.v4i16
-uint16_t4 test_countbits_ushort4(uint16_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i16> @llvm.ctpop.v4i16
+// CHECK-NEXT: zext <4 x i16> [[A]] to <4 x i32>
+uint4 test_countbits_ushort4(uint16_t4 p0)
 {
 	return countbits(p0);
 }
@@ -31,7 +35,7 @@ uint16_t4 test_countbits_ushort4(uint16_t4 p0)
 
 // CHECK-LABEL: test_countbits_uint
 // CHECK: call i32 @llvm.ctpop.i32
-int test_countbits_uint(uint p0)
+uint test_countbits_uint(uint p0)
 {
 	return countbits(p0);
 }
@@ -55,26 +59,30 @@ uint4 test_countbits_uint4(uint4 p0)
 }
 
 // CHECK-LABEL: test_countbits_long
-// CHECK: call i64 @llvm.ctpop.i64
-uint64_t test_countbits_long(uint64_t p0)
+// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
+// CHECK-NEXT: trunc i64 [[A]] to i32
+uint test_countbits_long(uint64_t p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_long2
-// CHECK: call <2 x i64> @llvm.ctpop.v2i64
-uint64_t2 test_countbits_long2(uint64_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
+// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
+uint2 test_countbits_long2(uint64_t2 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_long3
-// CHECK: call <3 x i64> @llvm.ctpop.v3i64
-uint64_t3 test_countbits_long3(uint64_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i64> @llvm.ctpop.v3i64
+// CHECK-NEXT: trunc <3 x i64> [[A]] to <3 x i32>
+uint3 test_countbits_long3(uint64_t3 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_long4
-// CHECK: call <4 x i64> @llvm.ctpop.v4i64
-uint64_t4 test_countbits_long4(uint64_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i64> @llvm.ctpop.v4i64
+// CHECK-NEXT: trunc <4 x i64> [[A]] to <4 x i32>
+uint4 test_countbits_long4(uint64_t4 p0)
 {
 	return countbits(p0);
 }
diff --git a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
index 8d5f0abb2860f8..5704165e1a4505 100644
--- a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
@@ -1,6 +1,4 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
 
 
 double test_int_builtin(double p0) {
@@ -9,13 +7,11 @@ double test_int_builtin(double p0) {
 }
 
 double2 test_int_builtin_2(double2 p0) {
-  return __builtin_elementwise_popcount(p0);
-  // expected-error@-1 {{1st argument must be a vector of integers
-  // (was 'double2' (aka 'vector<double, 2>'))}}
+  return countbits(p0);
+  // expected-error@-1 {{call to 'countbits' is ambiguous}}
 }
 
 double test_int_builtin_3(float p0) {
-  return __builtin_elementwise_popcount(p0);
-  // expected-error@-1 {{1st argument must be a vector of integers
-  // (was 'float')}}
+  return countbits(p0);
+  // expected-error@-1 {{call to 'countbits' is ambiguous}}
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..73636739de0659 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -553,11 +553,10 @@ def Rbits :  DXILOp<30, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def CBits :  DXILOp<31, unary> {
+def CBits :  DXILOp<31, unaryBits> {
   let Doc = "Returns the number of 1 bits in the specified value.";
-  let LLVMIntrinsic = int_ctpop;
   let arguments = [OverloadTy];
-  let result = OverloadTy;
+  let result = Int32Ty;
   let overloads =
       [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 99df4850872078..a0b5df25760206 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -461,6 +461,67 @@ class OpLowerer {
     });
   }
 
+  [[nodiscard]] bool lowerCtpopToCBits(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+    Type *Int32Ty = IRB.getInt32Ty();
+    
+    return replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+      SmallVector<Value *> Args;
+      Args.append(CI->arg_begin(), CI->arg_end());
+
+      Type *RetTy = Int32Ty;
+      Type *FRT = F.getReturnType();
+      if (FRT->isVectorTy()) {
+        VectorType *VT = cast<VectorType>(FRT);
+	RetTy = VectorType::get(RetTy, VT);
+      }
+      
+      Expected<CallInst *> OpCall =
+	OpBuilder.tryCreateOp(dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+      if (Error E = OpCall.takeError())
+	return E;
+
+      // If the result type is 32 bits we can do a direct replacement.
+      if (FRT->isIntOrIntVectorTy(32)) {
+        CI->replaceAllUsesWith(*OpCall);
+	CI->eraseFromParent();
+	return Error::success();
+      }
+
+      unsigned CastOp;
+      if (FRT->isIntOrIntVectorTy(16))
+	CastOp = Instruction::ZExt;
+      else // must be 64 bits
+	CastOp = Instruction::Trunc;
+
+      // It is correct to replace the ctpop with the dxil op and
+      // remove an existing cast iff the cast is the only usage of
+      // the ctpop
+      // can use hasOneUse instead of hasOneUser, because the user
+      // we care about should have one operand
+      if (CI->hasOneUse()) {
+	User *U = CI->user_back();
+	Instruction *I;
+	if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+	    I->getOpcode() == CastOp && I->getType() == RetTy) {
+          I->replaceAllUsesWith(*OpCall);
+	  I->eraseFromParent();
+	  CI->eraseFromParent();
+	  return Error::success();
+	  }
+      }
+
+      // It is always correct to replace a ctpop with the dxil op and
+      // a cast
+      Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(),
+					  "ctpop.cast");
+      CI->replaceAllUsesWith(Cast);
+      CI->eraseFromParent();
+      return Error::success();
+    });
+  }
+
   bool lowerIntrinsics() {
     bool Updated = false;
     bool HasErrors = false;
@@ -489,6 +550,9 @@ class OpLowerer {
       case Intrinsic::dx_typedBufferStore:
         HasErrors |= lowerTypedBufferStore(F);
         break;
+      case Intrinsic::ctpop:
+	HasErrors |= lowerCtpopToCBits(F);
+	break;
       }
       Updated = true;
     }
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index c6bc2b6790948e..91f6f560903f01 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -4,35 +4,58 @@
 
 define noundef i16 @test_countbits_short(i16 noundef %a) {
 entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 31, i16 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i16
+; CHECK-NEXT ret i16 [[B]]
   %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
   ret i16 %elt.ctpop
 }
 
+define noundef i32 @test_countbits_short2(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+  %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
+  %elt.zext = zext i16 %elt.ctpop to i32
+  ret i32 %elt.zext
+}
+
 define noundef i32 @test_countbits_int(i32 noundef %a) {
 entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 31, i32 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
   %elt.ctpop = call i32 @llvm.ctpop.i32(i32 %a)
   ret i32 %elt.ctpop
 }
 
 define noundef i64 @test_countbits_long(i64 noundef %a) {
 entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 31, i64 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = zext i32 [[A]] to i64
+; CHECK-NEXT ret i64 [[B]]
   %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
   ret i64 %elt.ctpop
 }
 
+define noundef i32 @test_countbits_long2(i64 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+  %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
+  %elt.trunc = trunc i64 %elt.ctpop to i32
+  ret i32 %elt.trunc
+}
+
 define noundef <4 x i32> @countbits_vec4_i32(<4 x i32> noundef %a)  {
 entry:
   ; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
-  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee0]])
+  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee0]])
   ; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
-  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee1]])
+  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee1]])
   ; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
-  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee2]])
+  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee2]])
   ; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
-  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee3]])
+  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee3]])
   ; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
@@ -44,4 +67,4 @@ entry:
 declare i16 @llvm.ctpop.i16(i16)
 declare i32 @llvm.ctpop.i32(i32)
 declare i64 @llvm.ctpop.i64(i64)
-declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
+declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
\ No newline at end of file

@llvmbot
Copy link
Member

llvmbot commented Oct 21, 2024

@llvm/pr-subscribers-hlsl

Author: Sarah Spall (spall)

Changes

Restricts hlsl countbits to always return a uint32.
Implements a lowering from llvm.ctpop which has an overloaded return type to dxil cbits op which always returns uint32.
Closes #112779


Full diff: https://github.com/llvm/llvm-project/pull/113189.diff

6 Files Affected:

  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+75-51)
  • (modified) clang/test/CodeGenHLSL/builtins/countbits.hlsl (+25-17)
  • (modified) clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl (+5-9)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+2-3)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+64)
  • (modified) llvm/test/CodeGen/DirectX/countbits.ll (+31-8)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..2a612c3746076c 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -705,66 +705,90 @@ float4 cosh(float4);
 
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t countbits(int16_t);
+constexpr uint countbits(int16_t x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t2 countbits(int16_t2);
+constexpr uint2 countbits(int16_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t3 countbits(int16_t3);
+constexpr uint3 countbits(int16_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t4 countbits(int16_t4);
+constexpr uint4 countbits(int16_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t countbits(uint16_t);
+constexpr uint countbits(uint16_t x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t2 countbits(uint16_t2);
+constexpr uint2 countbits(uint16_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t3 countbits(uint16_t3);
+constexpr uint3 countbits(uint16_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t4 countbits(uint16_t4);
+constexpr uint4 countbits(uint16_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 #endif
 
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int countbits(int);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int2 countbits(int2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int3 countbits(int3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int4 countbits(int4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint countbits(uint);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint2 countbits(uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint3 countbits(uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint4 countbits(uint4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t countbits(int64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t2 countbits(int64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t3 countbits(int64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t4 countbits(int64_t4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t countbits(uint64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t2 countbits(uint64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t3 countbits(uint64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t4 countbits(uint64_t4);
+constexpr uint countbits(int x) {
+  return __builtin_elementwise_popcount(x);
+}  
+constexpr uint2 countbits(int2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int4 x) {
+  return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint4 x) {
+  return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(int64_t x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(int64_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int64_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int64_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint64_t x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint64_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint64_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint64_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 
 //===----------------------------------------------------------------------===//
 // degrees builtins
diff --git a/clang/test/CodeGenHLSL/builtins/countbits.hlsl b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
index 8dfe977bfae626..aa9ef40d7a0dc8 100644
--- a/clang/test/CodeGenHLSL/builtins/countbits.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
@@ -4,26 +4,30 @@
 
 #ifdef __HLSL_ENABLE_16_BIT
 // CHECK-LABEL: test_countbits_ushort
-// CHECK: call i16 @llvm.ctpop.i16
-uint16_t test_countbits_ushort(uint16_t p0)
+// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
+// CHECK-NEXT: zext i16 [[A]] to i32
+uint test_countbits_ushort(uint16_t p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_ushort2
-// CHECK: call <2 x i16> @llvm.ctpop.v2i16
-uint16_t2 test_countbits_ushort2(uint16_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
+// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
+uint2 test_countbits_ushort2(uint16_t2 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_ushort3
-// CHECK: call <3 x i16> @llvm.ctpop.v3i16
-uint16_t3 test_countbits_ushort3(uint16_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i16> @llvm.ctpop.v3i16
+// CHECK-NEXT: zext <3 x i16> [[A]] to <3 x i32>
+uint3 test_countbits_ushort3(uint16_t3 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_ushort4
-// CHECK: call <4 x i16> @llvm.ctpop.v4i16
-uint16_t4 test_countbits_ushort4(uint16_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i16> @llvm.ctpop.v4i16
+// CHECK-NEXT: zext <4 x i16> [[A]] to <4 x i32>
+uint4 test_countbits_ushort4(uint16_t4 p0)
 {
 	return countbits(p0);
 }
@@ -31,7 +35,7 @@ uint16_t4 test_countbits_ushort4(uint16_t4 p0)
 
 // CHECK-LABEL: test_countbits_uint
 // CHECK: call i32 @llvm.ctpop.i32
-int test_countbits_uint(uint p0)
+uint test_countbits_uint(uint p0)
 {
 	return countbits(p0);
 }
@@ -55,26 +59,30 @@ uint4 test_countbits_uint4(uint4 p0)
 }
 
 // CHECK-LABEL: test_countbits_long
-// CHECK: call i64 @llvm.ctpop.i64
-uint64_t test_countbits_long(uint64_t p0)
+// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
+// CHECK-NEXT: trunc i64 [[A]] to i32
+uint test_countbits_long(uint64_t p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_long2
-// CHECK: call <2 x i64> @llvm.ctpop.v2i64
-uint64_t2 test_countbits_long2(uint64_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
+// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
+uint2 test_countbits_long2(uint64_t2 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_long3
-// CHECK: call <3 x i64> @llvm.ctpop.v3i64
-uint64_t3 test_countbits_long3(uint64_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i64> @llvm.ctpop.v3i64
+// CHECK-NEXT: trunc <3 x i64> [[A]] to <3 x i32>
+uint3 test_countbits_long3(uint64_t3 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_long4
-// CHECK: call <4 x i64> @llvm.ctpop.v4i64
-uint64_t4 test_countbits_long4(uint64_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i64> @llvm.ctpop.v4i64
+// CHECK-NEXT: trunc <4 x i64> [[A]] to <4 x i32>
+uint4 test_countbits_long4(uint64_t4 p0)
 {
 	return countbits(p0);
 }
diff --git a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
index 8d5f0abb2860f8..5704165e1a4505 100644
--- a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
@@ -1,6 +1,4 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
 
 
 double test_int_builtin(double p0) {
@@ -9,13 +7,11 @@ double test_int_builtin(double p0) {
 }
 
 double2 test_int_builtin_2(double2 p0) {
-  return __builtin_elementwise_popcount(p0);
-  // expected-error@-1 {{1st argument must be a vector of integers
-  // (was 'double2' (aka 'vector<double, 2>'))}}
+  return countbits(p0);
+  // expected-error@-1 {{call to 'countbits' is ambiguous}}
 }
 
 double test_int_builtin_3(float p0) {
-  return __builtin_elementwise_popcount(p0);
-  // expected-error@-1 {{1st argument must be a vector of integers
-  // (was 'float')}}
+  return countbits(p0);
+  // expected-error@-1 {{call to 'countbits' is ambiguous}}
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..73636739de0659 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -553,11 +553,10 @@ def Rbits :  DXILOp<30, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def CBits :  DXILOp<31, unary> {
+def CBits :  DXILOp<31, unaryBits> {
   let Doc = "Returns the number of 1 bits in the specified value.";
-  let LLVMIntrinsic = int_ctpop;
   let arguments = [OverloadTy];
-  let result = OverloadTy;
+  let result = Int32Ty;
   let overloads =
       [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 99df4850872078..a0b5df25760206 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -461,6 +461,67 @@ class OpLowerer {
     });
   }
 
+  [[nodiscard]] bool lowerCtpopToCBits(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+    Type *Int32Ty = IRB.getInt32Ty();
+    
+    return replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+      SmallVector<Value *> Args;
+      Args.append(CI->arg_begin(), CI->arg_end());
+
+      Type *RetTy = Int32Ty;
+      Type *FRT = F.getReturnType();
+      if (FRT->isVectorTy()) {
+        VectorType *VT = cast<VectorType>(FRT);
+	RetTy = VectorType::get(RetTy, VT);
+      }
+      
+      Expected<CallInst *> OpCall =
+	OpBuilder.tryCreateOp(dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+      if (Error E = OpCall.takeError())
+	return E;
+
+      // If the result type is 32 bits we can do a direct replacement.
+      if (FRT->isIntOrIntVectorTy(32)) {
+        CI->replaceAllUsesWith(*OpCall);
+	CI->eraseFromParent();
+	return Error::success();
+      }
+
+      unsigned CastOp;
+      if (FRT->isIntOrIntVectorTy(16))
+	CastOp = Instruction::ZExt;
+      else // must be 64 bits
+	CastOp = Instruction::Trunc;
+
+      // It is correct to replace the ctpop with the dxil op and
+      // remove an existing cast iff the cast is the only usage of
+      // the ctpop
+      // can use hasOneUse instead of hasOneUser, because the user
+      // we care about should have one operand
+      if (CI->hasOneUse()) {
+	User *U = CI->user_back();
+	Instruction *I;
+	if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+	    I->getOpcode() == CastOp && I->getType() == RetTy) {
+          I->replaceAllUsesWith(*OpCall);
+	  I->eraseFromParent();
+	  CI->eraseFromParent();
+	  return Error::success();
+	  }
+      }
+
+      // It is always correct to replace a ctpop with the dxil op and
+      // a cast
+      Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(),
+					  "ctpop.cast");
+      CI->replaceAllUsesWith(Cast);
+      CI->eraseFromParent();
+      return Error::success();
+    });
+  }
+
   bool lowerIntrinsics() {
     bool Updated = false;
     bool HasErrors = false;
@@ -489,6 +550,9 @@ class OpLowerer {
       case Intrinsic::dx_typedBufferStore:
         HasErrors |= lowerTypedBufferStore(F);
         break;
+      case Intrinsic::ctpop:
+	HasErrors |= lowerCtpopToCBits(F);
+	break;
       }
       Updated = true;
     }
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index c6bc2b6790948e..91f6f560903f01 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -4,35 +4,58 @@
 
 define noundef i16 @test_countbits_short(i16 noundef %a) {
 entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 31, i16 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i16
+; CHECK-NEXT ret i16 [[B]]
   %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
   ret i16 %elt.ctpop
 }
 
+define noundef i32 @test_countbits_short2(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+  %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
+  %elt.zext = zext i16 %elt.ctpop to i32
+  ret i32 %elt.zext
+}
+
 define noundef i32 @test_countbits_int(i32 noundef %a) {
 entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 31, i32 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
   %elt.ctpop = call i32 @llvm.ctpop.i32(i32 %a)
   ret i32 %elt.ctpop
 }
 
 define noundef i64 @test_countbits_long(i64 noundef %a) {
 entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 31, i64 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = zext i32 [[A]] to i64
+; CHECK-NEXT ret i64 [[B]]
   %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
   ret i64 %elt.ctpop
 }
 
+define noundef i32 @test_countbits_long2(i64 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+  %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
+  %elt.trunc = trunc i64 %elt.ctpop to i32
+  ret i32 %elt.trunc
+}
+
 define noundef <4 x i32> @countbits_vec4_i32(<4 x i32> noundef %a)  {
 entry:
   ; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
-  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee0]])
+  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee0]])
   ; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
-  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee1]])
+  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee1]])
   ; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
-  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee2]])
+  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee2]])
   ; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
-  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee3]])
+  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee3]])
   ; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
@@ -44,4 +67,4 @@ entry:
 declare i16 @llvm.ctpop.i16(i16)
 declare i32 @llvm.ctpop.i32(i32)
 declare i64 @llvm.ctpop.i64(i64)
-declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
+declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
\ No newline at end of file

Copy link

github-actions bot commented Oct 21, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@@ -705,66 +705,74 @@ float4 cosh(float4);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int16_t countbits(int16_t);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed the earlier PR where this was added, but I don't know how signed integers work here. We only support unsigned in HLSL. Is that a change for clang?
Missing tests as well, if so.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I asked Justin about this he pointed to this link:
https://github.com/microsoft/DirectXShaderCompiler/blob/main/utils/hct/gen_intrin_main.txt#L114
and suggested it meant both signed and unsigned were supported.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, of course our documentation is wrong. countbits of signed is kind of illogical, but definitely should add tests regardless.


unsigned CastOp;
if (FRT->isIntOrIntVectorTy(16))
CastOp = Instruction::ZExt;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to the signed vs unsigned: this is where things I think get confusing. int16_t would almost always be sign extended, not zero extended. But doing a sign extension here doesn't make sense as you'll count 16 extra bits for any negative int16_t.
I think this is why std::bitset basically ignores the type. It's a forcing function to the author to say "yes, I'm purposely counting a signed value now go away"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry if I'm being confusing - I am not saying you have the wrong code here, I'm just bringing it up that this is where my thought drifted towards the topic because of this zero extension (which I think is correct)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extension should only be on the return value which is always unsigned. Is this code wrong with that in mind?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. This is overriding the cast op below. Is a zero/sign extension needed? I notice the return type is "int16_t".
Although I guess that's another question - should the return types all be 32 bits? i.e., "int_t"

Copy link
Contributor

@bfavela bfavela Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disregard. I see you actually updated that correctly! (the whole point of this PR haha)

Copy link
Member

@farzonl farzonl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of minor style nitpicks, but otherwise this LGTM!

@spall spall merged commit 75e7ba8 into llvm:main Oct 29, 2024
8 checks passed
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
Restricts hlsl countbits to always return a uint32.
Implements a lowering from llvm.ctpop which has an overloaded return
type to dxil cbits op which always returns uint32.
Closes llvm#112779
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX backend:X86 clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

[HLSL] countbits should not have an overloaded return type.
6 participants