[clang][HLSL] Add WaveIsFirstLane() intrinsic #103299

Merged
merged 1 commit into from
Sep 4, 2024

Conversation

Keenuts
Contributor

@Keenuts Keenuts commented Aug 13, 2024

This commit adds the WaveIsFirstLane() HLSL intrinsic. It uses the convergence intrinsics for the SPIR-V backend. On the DXIL side, I'm not sure what the strategy is for convergence, so I implemented it the way DXC does: as a normal builtin function.
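
For readers unfamiliar with the intrinsic, here is a minimal C++ model of its semantics. It is not code from this PR. It assumes a wave of at most 32 lanes represented as a bitmask of active lanes, and `waveIsFirstLane` and `peelIterations` are hypothetical helper names. The second helper mirrors the `while (true) { if (WaveIsFirstLane()) break; }` loop used in the new Clang test: if the set of active lanes is tracked per iteration, which is roughly what the loop convergence token expresses, exactly one lane leaves per iteration.

```cpp
#include <cstdint>

// Hypothetical model: bit i of activeMask is set when lane i is active.
// Returns true only for the active lane with the smallest index.
bool waveIsFirstLane(std::uint32_t activeMask, unsigned lane) {
  if (lane >= 32 || ((activeMask >> lane) & 1u) == 0)
    return false; // Out-of-range or inactive lanes are never elected.
  const std::uint32_t lowerLanes = activeMask & ((1u << lane) - 1u);
  return lowerLanes == 0; // No active lane has a smaller index.
}

// Models `while (true) { if (WaveIsFirstLane()) break; }`: each iteration,
// the elected lane leaves the loop and the remaining lanes go around again.
unsigned peelIterations(std::uint32_t activeMask) {
  unsigned iterations = 0;
  while (activeMask != 0) {
    for (unsigned lane = 0; lane < 32; ++lane) {
      if (waveIsFirstLane(activeMask, lane)) {
        activeMask &= ~(1u << lane); // The elected lane breaks out.
        break;
      }
    }
    ++iterations;
  }
  return iterations;
}
```

With lanes 0, 1 and 3 active (mask `0xB`), only lane 0 is elected, and the modelled loop needs three iterations before every lane has left.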

@llvmbot llvmbot added the clang, backend:X86, clang:frontend, clang:headers, clang:codegen, backend:DirectX, HLSL, backend:SPIR-V, and llvm:ir labels Aug 13, 2024
@llvmbot
Member

llvmbot commented Aug 13, 2024

@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-backend-spir-v
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-backend-x86

Author: Nathan Gauër (Keenuts)

Changes

This commit adds the WaveIsFirstLane() HLSL intrinsic. It uses the convergence intrinsics for the SPIR-V backend. On the DXIL side, I'm not sure what the strategy is (DXC didn't use convergence intrinsics for DXIL).
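
The CGHLSLRuntime.h change registers a getter through GENERATE_HLSL_INTRINSIC_FUNCTION, and the Clang test expects a different intrinsic for each target. A rough standalone model of that selection follows, assuming the macro simply switches on the target architecture. The enum and function are hypothetical stand-ins, not LLVM or Clang types; only the two intrinsic names come from the tests in this PR.

```cpp
#include <string>

// Hypothetical stand-in for the target architecture Clang compiles for.
enum class ShaderTarget { DXIL, SPIRV };

// Returns the IR-level name of the wave_is_first_lane intrinsic that the
// tests in this PR expect for each target.
std::string waveIsFirstLaneIntrinsicName(ShaderTarget target) {
  switch (target) {
  case ShaderTarget::DXIL:
    return "llvm.dx.wave.is.first.lane";
  case ShaderTarget::SPIRV:
    return "llvm.spv.wave.is.first.lane";
  }
  return ""; // Unreachable for valid enum values.
}
```

Keeping the choice behind one getter means the builtin's codegen case does not need to know which backend is active.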


Full diff: https://github.com/llvm/llvm-project/pull/103299.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+4)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+4)
  • (added) clang/test/CodeGenHLSL/builtins/wave_is_first_lane.hlsl (+34)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+2)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+2)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+9)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp (+31-22)
  • (added) llvm/test/CodeGen/DirectX/wave_is_first_lane.ll (+13)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll (+27)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..b047669ff3c53f 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4677,6 +4677,12 @@ def HLSLWaveGetLaneIndex : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "unsigned int()";
 }
 
+def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_is_first_lane"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "bool()";
+}
+
 def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_clamp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe80b0cbdfbfa..0b96fe9d29b595 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18660,6 +18660,10 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
   }
+  case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
+    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
+    return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 527e73a0e21fc4..d856b03debc063 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -79,6 +79,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..d7b5d8c40a0889 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1725,5 +1725,9 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_get_lane_index)
 __attribute__((convergent)) uint WaveGetLaneIndex();
 
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
+__attribute__((convergent)) bool WaveIsFirstLane();
+
 } // namespace hlsl
 #endif //_HLSL_HLSL_INTRINSICS_H_
diff --git a/clang/test/CodeGenHLSL/builtins/wave_is_first_lane.hlsl b/clang/test/CodeGenHLSL/builtins/wave_is_first_lane.hlsl
new file mode 100644
index 00000000000000..18860c321eb912
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/wave_is_first_lane.hlsl
@@ -0,0 +1,34 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+
+[numthreads(1, 1, 1)]
+void main() {
+// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+// CHECK-SPIRV: %[[#loop_tok:]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %[[#entry_tok]]) ]
+  while (true) {
+
+// CHECK-DXIL:  %[[#]] = call i1 @llvm.dx.wave.is.first.lane()
+// CHECK-SPIRV: %[[#]] = call i1 @llvm.spv.wave.is.first.lane()
+// CHECK-SPIRV-SAME: [ "convergencectrl"(token %[[#loop_tok]]) ]
+    if (WaveIsFirstLane()) {
+      break;
+    }
+  }
+
+// CHECK-DXIL:  %[[#]] = call i1 @llvm.dx.wave.is.first.lane()
+// CHECK-SPIRV: %[[#]] = call i1 @llvm.spv.wave.is.first.lane()
+// CHECK-SPIRV-SAME: [ "convergencectrl"(token %[[#entry_tok]]) ]
+  if (WaveIsFirstLane()) {
+    return;
+  }
+}
+
+// CHECK-DXIL:  i1 @llvm.dx.wave.is.first.lane() #[[#attr:]]
+// CHECK-SPIRV: i1 @llvm.spv.wave.is.first.lane() #[[#attr:]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..1eea4d25c0ac50 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -60,4 +60,6 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_rcp  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+
+def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..ea8b58caa6b193 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -65,4 +65,6 @@ let TargetPrefix = "spv" in {
     [IntrNoMem, IntrWillReturn] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
+
+  def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 67015cff78a79a..a4e7aeae883fbc 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -703,3 +703,12 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
+
+def WaveIsFirstLane :  DXILOp<110, waveIsFirstLane> {
+  let Doc = "returns 1 for the first lane in the wave";
+  let LLVMIntrinsic = int_dx_wave_is_first_lane;
+  let arguments = [];
+  let result = i1Ty;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index c55235a04a607f..d014d90e31fb9d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2132,6 +2132,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
       Size = 0;
     BuildMI(BB, I, I.getDebugLoc(), TII.get(Op)).addUse(PtrReg).addImm(Size);
   } break;
+  case Intrinsic::spv_wave_is_first_lane: {
+    SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+    return BuildMI(BB, I, I.getDebugLoc(),
+                   TII.get(SPIRV::OpGroupNonUniformElect))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
+  }
   default: {
     std::string DiagMsg;
     raw_string_ostream OS(DiagMsg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp
index dca30535acfa1a..b632d784977678 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp
@@ -41,31 +41,40 @@ class SPIRVStripConvergentIntrinsics : public FunctionPass {
   virtual bool runOnFunction(Function &F) override {
     DenseSet<Instruction *> ToRemove;
 
+    // If the instruction is a convergent intrinsic, add it to the kill-list
+    // and return true. Return false otherwise.
+    auto CleanupIntrinsic = [&](IntrinsicInst *II) {
+      if (II->getIntrinsicID() != Intrinsic::experimental_convergence_entry &&
+          II->getIntrinsicID() != Intrinsic::experimental_convergence_loop &&
+          II->getIntrinsicID() != Intrinsic::experimental_convergence_anchor)
+        return false;
+
+      II->replaceAllUsesWith(UndefValue::get(II->getType()));
+      ToRemove.insert(II);
+      return true;
+    };
+
+    // Replace the given CallInst by a similar CallInst with no convergencectrl
+    // operand bundle.
+    auto CleanupCall = [&](CallInst *CI) {
+      auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
+      if (!OB.has_value())
+        return;
+
+      auto *NewCall = CallBase::removeOperandBundle(
+          CI, LLVMContext::OB_convergencectrl, CI);
+      NewCall->copyMetadata(*CI);
+      CI->replaceAllUsesWith(NewCall);
+      ToRemove.insert(CI);
+    };
+
     for (BasicBlock &BB : F) {
       for (Instruction &I : BB) {
-        if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
-          if (II->getIntrinsicID() !=
-                  Intrinsic::experimental_convergence_entry &&
-              II->getIntrinsicID() !=
-                  Intrinsic::experimental_convergence_loop &&
-              II->getIntrinsicID() !=
-                  Intrinsic::experimental_convergence_anchor) {
+        if (auto *II = dyn_cast<IntrinsicInst>(&I))
+          if (CleanupIntrinsic(II))
             continue;
-          }
-
-          II->replaceAllUsesWith(UndefValue::get(II->getType()));
-          ToRemove.insert(II);
-        } else if (auto *CI = dyn_cast<CallInst>(&I)) {
-          auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
-          if (!OB.has_value())
-            continue;
-
-          auto *NewCall = CallBase::removeOperandBundle(
-              CI, LLVMContext::OB_convergencectrl, CI);
-          NewCall->copyMetadata(*CI);
-          CI->replaceAllUsesWith(NewCall);
-          ToRemove.insert(CI);
-        }
+        if (auto *CI = dyn_cast<CallInst>(&I))
+          CleanupCall(CI);
       }
     }
 
diff --git a/llvm/test/CodeGen/DirectX/wave_is_first_lane.ll b/llvm/test/CodeGen/DirectX/wave_is_first_lane.ll
new file mode 100644
index 00000000000000..b9a63bb0f14722
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wave_is_first_lane.ll
@@ -0,0 +1,13 @@
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define void @main() #0 {
+entry:
+; CHECK: call i1 @dx.op.waveIsFirstLane.i1(i32 110)
+  %0 = call i1 @llvm.dx.wave.is.first.lane()
+  ret void
+}
+
+declare i1 @llvm.dx.wave.is.first.lane() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent nocallback nofree nosync nounwind willreturn }
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll
new file mode 100644
index 00000000000000..94597b37cc7eb1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll
@@ -0,0 +1,27 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spirv-unknown-vulkan-compute"
+
+; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint_3:]] = OpConstant %[[#uint]] 3
+; CHECK-DAG:   %[[#bool:]] = OpTypeBool
+
+define spir_func void @main() #0 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+; CHECK:   %[[#]] = OpGroupNonUniformElect %[[#bool]] %[[#uint_3]]
+  %1 = call i1 @llvm.spv.wave.is.first.lane() [ "convergencectrl"(token %0) ]
+  ret void
+}
+
+declare i32 @__hlsl_wave_get_lane_index() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
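
The constant 3 passed as the operand of OpGroupNonUniformElect in SPIRVInstructionSelector.cpp, and checked as `%[[#uint_3]]` in the SPIR-V test, is the SPIR-V Scope value for Subgroup. The sketch below spells that out with a local enum mirroring the core Scope values from the SPIR-V specification. The enum and the helper name are illustrative assumptions and are not part of the LLVM SPIR-V backend's API.

```cpp
#include <cstdint>

// Local copy of the core SPIR-V Scope values (from the SPIR-V specification).
enum class Scope : std::uint32_t {
  CrossDevice = 0,
  Device = 1,
  Workgroup = 2,
  Subgroup = 3,
  Invocation = 4,
};

// Hypothetical helper: the execution-scope operand used when lowering
// WaveIsFirstLane() to OpGroupNonUniformElect.
constexpr std::uint32_t electExecutionScope() {
  return static_cast<std::uint32_t>(Scope::Subgroup);
}

static_assert(electExecutionScope() == 3,
              "Subgroup scope must match the literal used in the patch");
```

A named constant of this kind would make the literal self-describing, though the patch as merged uses the bare value.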

@farzonl
Member

farzonl commented Aug 13, 2024

We have this work tracked here: #99158

There should be some DXIL-specific tasks.

@Keenuts
Contributor Author

Keenuts commented Aug 14, 2024

We have this work tracked here: #99158

There should be some DXIL-specific tasks.

Seems like most boxes would be checked by this PR, except Sema checks:

  • What kind of Sema checks would be required for this one?
    Also, the intrinsic name in the issue uses camel case, whereas this PR uses snake case. Existing intrinsics such as thread_id use snake case (the same holds on the SPIR-V backend), so shouldn't we remain consistent?
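
On the naming question: the snake-case record names and the dotted IR names seen in this PR are related by TableGen's default naming rule for intrinsics, where the `int_` prefix becomes `llvm.` and underscores become dots. A small sketch of that rule follows, with a hypothetical helper name and ignoring records that set an explicit name.

```cpp
#include <string>

// Hypothetical helper: derives the default IR name from a TableGen intrinsic
// record name, e.g. int_dx_wave_is_first_lane -> llvm.dx.wave.is.first.lane.
std::string defaultIntrinsicIRName(const std::string &recordName) {
  const std::string prefix = "int_";
  if (recordName.compare(0, prefix.size(), prefix) != 0)
    return recordName; // Not an intrinsic record name; leave it untouched.
  std::string irName = "llvm." + recordName.substr(prefix.size());
  for (char &c : irName)
    if (c == '_')
      c = '.'; // Every underscore becomes a dot in the IR name.
  return irName;
}
```

Under this rule a camel-case record name would carry its capitals into the IR name, which is one reason to keep new records in the same snake case as the existing ones.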

@farzonl
Member

farzonl commented Aug 19, 2024

closes #99158

@farzonl farzonl linked an issue Aug 19, 2024 that may be closed by this pull request
12 tasks
@farzonl
Member

farzonl commented Aug 19, 2024

We have this work tracked here: #99158
There should be some DXIL-specific tasks.

Seems like most boxes would be checked by this PR, except Sema checks:

  • What kind of Sema checks would be required for this one?
    Also, the intrinsic name in the issue uses camel case, whereas this PR uses snake case. Existing intrinsics such as thread_id use snake case (the same holds on the SPIR-V backend), so shouldn't we remain consistent?

Good point. The only thing I can think of adding would be TheCall->setType(BoolTy); but I'm not sure if it's worth it.

@Keenuts
Contributor Author

Keenuts commented Aug 22, 2024

  • what kind of Sema checks would be required for this one?

Good point. The only thing I can think of adding would be TheCall->setType(BoolTy); but I'm not sure if it's worth it.

Not sure I follow: do you mean a Sema test checking that the call's type in the AST is bool, or setting the type after EmitRuntimeCall?

This commit adds the WaveIsFirstLane() HLSL intrinsic.
This intrinsic uses the convergence intrinsics for the SPIR-V backend.
On the DXIL side, I'm not sure what the strategy is. (DXC didn't use
convergence intrinsics for DXIL.)

Signed-off-by: Nathan Gauër <[email protected]>
@Keenuts
Contributor Author

Keenuts commented Sep 4, 2024

Merging to unblock the structurizer work.
Let me know if you had a specific Sema check in mind to add after all!

@Keenuts Keenuts merged commit afb6daf into llvm:main Sep 4, 2024
9 checks passed
@Keenuts Keenuts deleted the waveisfirstlane branch September 4, 2024 09:27
Labels
backend:DirectX, backend:SPIR-V, backend:X86, clang:codegen, clang:frontend, clang:headers, clang, HLSL, llvm:ir
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

Implement the WaveIsFirstLane HLSL Function
5 participants