Skip to content

[HLSL] implement elementwise firstbithigh hlsl builtin #111082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 6, 2024
Merged

Conversation

spall
Copy link
Contributor

@spall spall commented Oct 4, 2024

Implements elementwise firstbithigh hlsl builtin.
Implements firstbituhigh intrinsic for spirv and directx, which handles unsigned integers
Implements firstbitshigh intrinsic for spirv and directx, which handles signed integers.
Fixes #113486
Closes #99115

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. backend:DirectX HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Oct 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-backend-x86
@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-backend-spir-v
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang-codegen

Author: Sarah Spall (spall)

Changes

Implements elementwise firstbithigh hlsl builtin.
Implements firstbituhigh intrinsic for spirv and directx, which handles unsigned integers
Implements firstbitshigh intrinsic for spirv and directx, which handles signed integers.
Closes #99115


Patch is 25.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111082.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+17)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+2)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+72)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+18)
  • (added) clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl (+153)
  • (added) clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl (+28)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+2)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+2)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+24)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+14)
  • (added) llvm/test/CodeGen/DirectX/firstbithigh.ll (+91)
  • (added) llvm/test/CodeGen/DirectX/firstbitshigh_error.ll (+10)
  • (added) llvm/test/CodeGen/DirectX/firstbituhigh_error.ll (+10)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll (+37)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index a726a0ef4b4bd2..cc4630c281052d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4751,6 +4751,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLFirstBitHigh : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_elementwise_firstbithigh"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLFrac : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_frac"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 43700ea9dd3cfd..6033dcffb3a6da 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18631,6 +18631,14 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getUDotIntrinsic();
 }
 
+Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
+  if (QT->hasSignedIntegerRepresentation()) {
+    return RT.getFirstBitSHighIntrinsic();
+  }
+
+  return RT.getFirstBitUHighIntrinsic();
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -18720,6 +18728,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
   } break;
+  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
+    
+    Value *X = EmitScalarExpr(E->getArg(0));
+    
+    return Builder.CreateIntrinsic(
+	/*ReturnType=*/X->getType(),
+	getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
+	ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
+  }
   case Builtin::BI__builtin_hlsl_lerp: {
     Value *X = EmitScalarExpr(E->getArg(0));
     Value *Y = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 6722d2c7c50a2b..03ddef95800e1a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -88,6 +88,8 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index c3ecfc7c90d433..afb7bd33374111 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -928,6 +928,78 @@ float3 exp2(float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_exp2)
 float4 exp2(float4);
 
+//===----------------------------------------------------------------------===//
+// firstbithigh builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T firstbithigh(T Val)
+/// \brief Returns the location of the first set bit starting from the highest
+/// order bit and working downward, per component.
+/// \param Val the input value.
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t firstbithigh(int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t2 firstbithigh(int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t3 firstbithigh(int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t4 firstbithigh(int16_t4);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t firstbithigh(uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t2 firstbithigh(uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t3 firstbithigh(uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t4 firstbithigh(uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int firstbithigh(int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int2 firstbithigh(int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int3 firstbithigh(int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int4 firstbithigh(int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint firstbithigh(uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint2 firstbithigh(uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint3 firstbithigh(uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint4 firstbithigh(uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t firstbithigh(int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t2 firstbithigh(int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t3 firstbithigh(int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t4 firstbithigh(int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t firstbithigh(uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t2 firstbithigh(uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t3 firstbithigh(uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t4 firstbithigh(uint64_t4);
+  
 //===----------------------------------------------------------------------===//
 // floor builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index fbcba201a351a6..48ea1d4aab7a18 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1874,6 +1874,24 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
+    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
+      return true;
+
+    const Expr *Arg = TheCall->getArg(0);
+    QualType ArgTy = Arg->getType();
+    QualType EltTy = ArgTy;
+
+    if (auto *VecTy = EltTy->getAs<VectorType>())
+      EltTy = VecTy->getElementType();
+
+    if (!EltTy->isIntegerType()) {
+      Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+          << 1 << /* integer ty */ 6 << ArgTy;
+      return true;
+    }
+    break;
+  }
   case Builtin::BI__builtin_hlsl_select: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
new file mode 100644
index 00000000000000..9821b308e63521
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
@@ -0,0 +1,153 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s -DTARGET=dx
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s -DTARGET=spv
+
+#ifdef __HLSL_ENABLE_16_BIT
+// CHECK-LABEL: test_firstbithigh_ushort
+// CHECK: call i16 @llvm.[[TARGET]].firstbituhigh.i16
+int test_firstbithigh_ushort(uint16_t p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ushort2
+// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbituhigh.v2i16
+uint16_t2 test_firstbithigh_ushort2(uint16_t2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ushort3
+// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbituhigh.v3i16
+uint16_t3 test_firstbithigh_ushort3(uint16_t3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ushort4
+// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbituhigh.v4i16
+uint16_t4 test_firstbithigh_ushort4(uint16_t4 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short
+// CHECK: call i16 @llvm.[[TARGET]].firstbitshigh.i16
+int16_t test_firstbithigh_short(int16_t p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short2
+// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbitshigh.v2i16
+int16_t2 test_firstbithigh_short2(int16_t2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short3
+// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbitshigh.v3i16
+int16_t3 test_firstbithigh_short3(int16_t3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short4
+// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbitshigh.v4i16
+int16_t4 test_firstbithigh_short4(int16_t4 p0) {
+  return firstbithigh(p0);
+}
+#endif // __HLSL_ENABLE_16_BIT
+
+// CHECK-LABEL: test_firstbithigh_uint
+// CHECK: call i32 @llvm.[[TARGET]].firstbituhigh.i32
+uint test_firstbithigh_uint(uint p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_uint2
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbituhigh.v2i32
+uint2 test_firstbithigh_uint2(uint2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_uint3
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbituhigh.v3i32
+uint3 test_firstbithigh_uint3(uint3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_uint4
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbituhigh.v4i32
+uint4 test_firstbithigh_uint4(uint4 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong
+// CHECK: call i64 @llvm.[[TARGET]].firstbituhigh.i64
+uint64_t test_firstbithigh_ulong(uint64_t p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong2
+// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbituhigh.v2i64
+uint64_t2 test_firstbithigh_ulong2(uint64_t2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong3
+// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbituhigh.v3i64
+uint64_t3 test_firstbithigh_ulong3(uint64_t3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong4
+// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbituhigh.v4i64
+uint64_t4 test_firstbithigh_ulong4(uint64_t4 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int
+// CHECK: call i32 @llvm.[[TARGET]].firstbitshigh.i32
+int test_firstbithigh_int(int p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int2
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbitshigh.v2i32
+int2 test_firstbithigh_int2(int2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int3
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbitshigh.v3i32
+int3 test_firstbithigh_int3(int3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int4
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i32
+int4 test_firstbithigh_int4(int4 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long
+// CHECK: call i64 @llvm.[[TARGET]].firstbitshigh.i64
+int64_t test_firstbithigh_long(int64_t p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long2
+// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbitshigh.v2i64
+int64_t2 test_firstbithigh_long2(int64_t2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long3
+// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbitshigh.v3i64
+int64_t3 test_firstbithigh_long3(int64_t3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long4
+// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbitshigh.v4i64
+int64_t4 test_firstbithigh_long4(int64_t4 p0) {
+  return firstbithigh(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl
new file mode 100644
index 00000000000000..1912ab3ae806b3
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+int test_too_few_arg() {
+  return firstbithigh();
+  // expected-error@-1 {{no matching function for call to 'firstbithigh'}}
+}
+
+int test_too_many_arg(int p0) {
+  return firstbithigh(p0, p0);
+  // expected-error@-1 {{no matching function for call to 'firstbithigh'}}
+}
+
+double test_int_builtin(double p0) {
+  return firstbithigh(p0);
+  // expected-error@-1 {{call to 'firstbithigh' is ambiguous}}
+}
+
+double2 test_int_builtin_2(double2 p0) {
+  return __builtin_hlsl_elementwise_firstbithigh(p0);
+  // expected-error@-1 {{1st argument must be a vector of integers
+  // (was 'double2' (aka 'vector<double, 2>'))}}
+}
+
+float test_int_builtin_3(float p0) {
+  return __builtin_hlsl_elementwise_firstbithigh(p0);
+  // expected-error@-1 {{1st argument must be a vector of integers
+  // (was 'float')}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 1adfcdc9a1ed2f..ed46d88645f043 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -86,4 +86,6 @@ def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 0567efd8a5d7af..3ca5e0e2046289 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -84,4 +84,6 @@ let TargetPrefix = "spv" in {
     [IntrNoMem, Commutative] >;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
+  def int_spv_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+  def int_spv_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9aa0af3e3a6b17..4b315d53894174 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -564,6 +564,30 @@ def CBits :  DXILOp<31, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
+def FBH :  DXILOp<33, unary> {
+  let Doc = "Returns the location of the first set bit starting from "
+            "the highest order bit and working downward.";
+  let LLVMIntrinsic = int_dx_firstbituhigh;
+  let arguments = [OverloadTy];
+  let result = OverloadTy;
+  let overloads =
+      [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
+def FBSH :  DXILOp<34, unary> {
+  let Doc = "Returns the location of the first set bit from "
+            "the highest order bit based on the sign.";
+  let LLVMIntrinsic = int_dx_firstbitshigh;
+  let arguments = [OverloadTy];
+  let result = OverloadTy;
+  let overloads =
+      [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def FMax :  DXILOp<35, binary> {
   let Doc = "Float maximum. FMax(a,b) = a > b ? a : b";
   let LLVMIntrinsic = int_maxnum;
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 1a59f04b214042..21a4daf2cde6d8 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -18,6 +18,8 @@ bool llvm::DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
   switch (ID) {
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
+  case Intrinsic::dx_firstbituhigh:
+  case Intrinsic::dx_firstbitshigh:
     return true;
   default:
     return false;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3917ad180b87fc..85602a8b4f6f7a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -213,6 +213,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectPhi(Register ResVReg, const SPIRVType *ResType,
                  MachineInstr &I) const;
 
+  bool selectExtInst(Register ResVReg, const SPIRVType *RestType,
+		     MachineInstr &I, GL::GLSLExtInst GLInst) const;
   bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
                      MachineInstr &I, CL::OpenCLExtInst CLInst) const;
   bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
@@ -751,6 +753,14 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   }
 }
 
+bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
+					     const SPIRVType *ResType,
+					     MachineInstr &I,
+					     GL::GLSLExtInst GLInst) const {
+  return selectExtInst(ResVReg, ResType, I,
+		       {{SPIRV::InstructionSet::GLSL_std_450, GLInst}});
+}
+
 bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
                                              const SPIRVType *ResType,
                                              MachineInstr &I,
@@ -2515,6 +2525,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectExtInst(ResVReg, ResType, I, CL::rsqrt, GL::InverseSqrt);
   case Intrinsic::spv_sign:
     return selectSign(ResVReg, ResType, I);
+  case Intrinsic::spv_firstbituhigh:
+    return selectExtInst(ResVReg, ResType, I, GL::FindUMsb);
+  case Intrinsic::spv_firstbitshigh:
+    return selectExtInst(ResVReg, ResType, I, GL::FindSMsb);
   case Intrinsic::spv_lifetime_start:
   case Intrinsic::spv_lifetime_end: {
     unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
diff --git a/llvm/test/CodeGen/DirectX/firstbithigh.ll b/llvm/test/CodeGen/DirectX/firstbithigh.ll
new file mode 100644
index 00000000000000..4a97ad6226149f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/firstbithigh.ll
@@ -0,0 +1,91 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; Make sure dxil operation function calls for firstbithigh are generated for all integer types.
+
+define noundef i16 @test_firstbithigh_ushort(i16 noundef %a) {
+entry:
+; CHECK: call i16 @dx.op.unary.i16(i32 33, i16 %{{.*}})
+  %elt.firstbithigh = call i16 @llvm.dx.firstbituhigh.i16(i16 %a)
+  ret i16 %elt.firstbithigh
+}
+
+define noundef i16 @test_firstbithigh_short(i16 noundef %a) {
+entry:
+; CHECK: call i16 @dx.op.unary.i16(i32 34, i16 %{{.*}})
+  %elt.firstbithigh = call i16 @llvm.dx.firstbitshigh.i16(i16 %a)
+  ret i16 %elt.firstbithigh
+}
+
+define noundef i32 @test_firstbithigh_uint(i32 noundef %a) {
+entry:
+; CHECK: call i32 @dx.op.unary.i32(i32 33, i32 %{{.*}})
+  %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i32(i32 %a)
+  ret i32 %elt.firstbithigh
+}
+
+define noundef i32 @test_firstbithigh_int(i32 noundef %a) {
+entry:
+; CHECK: call i32 @dx.op.unary.i32(i32 34, i32 %{{.*}})
+  %elt.firs...
[truncated]

Copy link

github-actions bot commented Oct 4, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@bogner
Copy link
Contributor

bogner commented Oct 5, 2024

It is not correct to limit firstbithigh to 32 bit integers. There are a couple of reasons that might make one think that it is, but we do in fact want/need to support int16 and int64 here.

  1. The HLSL docs only mention integer. Unfortunately that's true for most of these functions, as many of these docs were not updated as 64 and 16 bit types gained support in the language.
  2. DXC's SPIR-V implementation doesn't handle 16- and 64-bit for these functions. As noted in Incorrect SPIR-V generated when firstbithigh is used on a uint64_t microsoft/DirectXShaderCompiler#4702 we decided not to implement those overloads in DXC, but should reconsider this for clang.

Despite all of this, DXC does indeed support 16- and 64-bit overloads, as seen here: https://hlsl.godbolt.org/z/qbc17xz35

Note that the return type of the operation is not overloaded - all of the overloads of this function return uint.

@spall
Copy link
Contributor Author

spall commented Oct 5, 2024

It is not correct to limit firstbithigh to 32 bit integers. There are a couple of reasons that might make one think that it is, but we do in fact want/need to support int16 and int64 here.

2. DXC's SPIR-V implementation doesn't handle 16- and 64-bit for these functions. As noted in [Incorrect SPIR-V generated when `firstbithigh` is used on a `uint64_t` microsoft/DirectXShaderCompiler#4702](https://github.com/microsoft/DirectXShaderCompiler/issues/4702) we decided not to implement those overloads in DXC, but should reconsider this for clang.

Despite all of this, DXC does indeed support 16- and 64-bit overloads, as seen here: https://hlsl.godbolt.org/z/qbc17xz35

Note that the return type of the operation is not overloaded - all of the overloads of this function return uint.

https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#firstbithi
FirstBitHi does explicitly mention only 32 bit integers.

return true;
}

TheCall->setType(ArgTy);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If @bogner is right this line is a bit more complicated. We will have to construct the call type as a scalar or vector of the same size as ArgTy but with the element type of int32.

@spall
Copy link
Contributor Author

spall commented Oct 7, 2024

It is not correct to limit firstbithigh to 32 bit integers. There are a couple of reasons that might make one think that it is, but we do in fact want/need to support int16 and int64 here.

1. The [HLSL docs](https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/firstbithigh) only mention integer. Unfortunately that's true for most of these functions, as many of these docs were not updated as 64 and 16 bit types gained support in the language.

2. DXC's SPIR-V implementation doesn't handle 16- and 64-bit for these functions. As noted in [Incorrect SPIR-V generated when `firstbithigh` is used on a `uint64_t` microsoft/DirectXShaderCompiler#4702](https://github.com/microsoft/DirectXShaderCompiler/issues/4702) we decided not to implement those overloads in DXC, but should reconsider this for clang.

Despite all of this, DXC does indeed support 16- and 64-bit overloads, as seen here: https://hlsl.godbolt.org/z/qbc17xz35

Note that the return type of the operation is not overloaded - all of the overloads of this function return uint.

Why is the return type not overloaded? Isn't it overloaded for countbits which is a similar function.

@bogner
Copy link
Contributor

bogner commented Oct 10, 2024

Despite all of this, DXC does indeed support 16- and 64-bit overloads, as seen here: https://hlsl.godbolt.org/z/qbc17xz35
Note that the return type of the operation is not overloaded - all of the overloads of this function return uint.

Why is the return type not overloaded? Isn't it overloaded for countbits which is a similar function.

The return type for countbits is also not overloaded: https://hlsl.godbolt.org/z/oGYhsjEGc - in fact none of the operations that use the unaryBits operation class are overloaded on the return type.

The code we're generating with clang for countbits right now actually just looks like it's wrong. It doesn't match what DXC does at all: https://hlsl.godbolt.org/z/K8n3j5o3K

@spall
Copy link
Contributor Author

spall commented Oct 21, 2024

Despite all of this, DXC does indeed support 16- and 64-bit overloads, as seen here: https://hlsl.godbolt.org/z/qbc17xz35
Note that the return type of the operation is not overloaded - all of the overloads of this function return uint.

Why is the return type not overloaded? Isn't it overloaded for countbits which is a similar function.

The return type for countbits is also not overloaded: https://hlsl.godbolt.org/z/oGYhsjEGc - in fact none of the operations that use the unaryBits operation class are overloaded on the return type.

The code we're generating with clang for countbits right now actually just looks like it's wrong. It doesn't match what DXC does at all: https://hlsl.godbolt.org/z/K8n3j5o3K

Addressed here:
#113189

@@ -424,7 +424,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
LLT LLTy = LLT::scalar(64);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
CurMF->getRegInfo().setRegClass(SpvVecConst, getRegClass(SpvType));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Since this was a bug fix. Would it have been better to do this as a seperate commit to capture why this was necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to put it in its own PR if that is what people want. I just left it here since it is such a small change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I imagine there are other places in the SPIRV backend that might require a fix like this and want to give a future person a hint at how to fix it.

Alternatively create a second issue which refrences the problem you found. Add it to the pr description as fixed by this pr. drop a link to the issue in this comment section. That way if anyone searches github issues for something similar to this they will find your issue and see how you fixed it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created issue #113486 and mentioned it in this PR.

@farzonl
Copy link
Member

farzonl commented Oct 22, 2024

This looks really good. I had a few minor things and questions about either pulling things out of this pr and putting them into their own changes or future clean up to consolidate things. One more iteration and I'll be happy to approve.

Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me once @farzonl's feedback is taken into account.

Comment on lines 2740 to 2922
unsigned i;
for (i = 0; i < count * 2; i += 2) {
MIB.addImm(i);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to declare the variable in the for-loop scope so it's clear we aren't planning on using it again later. Also it's a bit confusing to have two variables named i and I in scope here - might be best to rename the MachineInstr to MI and/or to use some other placeholder like J here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is actually re-used in the for-loop below this one. I will rename one of them.

// check if the high bits are == -1;
Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
Register NegOneReg =
GR.getOrCreateConstScalarOrVector(-1, I, ResType, TII, ZeroAsNull);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we just have NegOneReg be determined by isScalarRes Seems like we keep having to check the type sprinkled throughout when it would make more sense to do the check once and have guarantees as what the type is for the entire block. Doing it the current way makes me think a future implementer could unintendedly add buggy behavior related to vector when it is scalar and vice versa. Same for line 2779 isScalarRes ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure If I do that I can probably delete the helper function I added in SPIRVGlobalRegistry. I thought this style would be less ugly than a big block.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see the potential readability benefits. That said so much of the spirv backend is splitting scalar vs vector behaviors into their own little playgrounds that it feels like we are going against the grain here. Thats why we have OpSelectSISCond and OpSelectVIVCond when the SPIRV spec makes no function difference in the opcode.

Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
bool Result =
selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(), Opcode);
return Result & selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
Copy link
Member

@farzonl farzonl Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this works, don't use a bitwise & with a bool do &&.

@spall
Copy link
Contributor Author

spall commented Nov 4, 2024

Waiting on #114482

@spall spall merged commit fb90733 into llvm:main Nov 6, 2024
7 of 9 checks passed
spall pushed a commit that referenced this pull request Nov 21, 2024
As a follow up request from
#111082 (comment)
the following non functional changes have been make
- `selectNAryOpWithSrcs` has been renamed to `selectOpWithSrcs`
- Calls to `selectUnOpWithSrc` have been replaced with
`selectOpWithSrcs`
- `selectUnOpWithSrc`  has been deleted
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX backend:SPIR-V backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Archived in project
6 participants