Skip to content

[HLSL][DXIL] Implement asdouble intrinsic #114847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 22, 2024
Merged

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Nov 4, 2024

- define intrinsic as builtin in Builtins.td
- link intrinsic in hlsl_intrinsics.h
- add semantic analysis to SemaHLSL.cpp
- lower to `llvm` or a `dx` intrinsic when applicable in CGBuiltin.cpp
- define DXIL intrinsic in IntrinsicsDirectX.td
- add DXIL op and mapping in DXIL.td
- enable scalarization of intrinsic

- add basic sema checking to asdouble-errors.hlsl

Resolves #99081

@inbelic inbelic changed the title Inbelic/as double [HLSL][DXIL] Implement asdouble intrinsic Nov 4, 2024
@inbelic
Copy link
Contributor Author

inbelic commented Nov 4, 2024

Please ignore the first commit when reviewing. It is separately reviewed here and this commit depends on it: #114849

@inbelic inbelic marked this pull request as ready for review November 4, 2024 23:08
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. backend:DirectX HLSL HLSL Language Support llvm:ir llvm:analysis llvm:transforms labels Nov 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2024

@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Finn Plummer (inbelic)

Changes
- define intrinsic as builtin in Builtins.td
- link intrinsic in hlsl_intrinsics.h
- add semantic analysis to SemaHLSL.cpp
- lower to `llvm` or a `dx` intrinsic when applicable in CGBuiltin.cpp
- define DXIL intrinsic in IntrinsicsDirectX.td
- add DXIL op and mapping in DXIL.td
- enable scalarization of intrinsic

- add basic sema checking to asdouble-errors.hlsl

Resolves #99081


Full diff: https://github.com/llvm/llvm-project/pull/114847.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+37)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+18)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+17)
  • (added) clang/test/CodeGenHLSL/builtins/asdouble.hlsl (+37)
  • (added) clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl (+16)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+14)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+6)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+6)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+6)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+10)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+11)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h (+3)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+6-3)
  • (added) llvm/test/CodeGen/DirectX/asdouble.ll (+22)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..37ef0bf7324ffb 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4744,6 +4744,12 @@ def HLSLAny : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool(...)";
 }
 
+def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_asdouble"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 28f28c70b5ae52..85ad203b50c7ab 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18634,6 +18634,41 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getUDotIntrinsic();
 }
 
+Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {
+  assert((E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() &&
+          E->getArg(1)->getType()->hasUnsignedIntegerRepresentation()) &&
+         "asdouble operands types mismatch");
+  Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0));
+  Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1));
+
+  llvm::Type *ResultType = CGF.DoubleTy;
+  int N = 1;
+  if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
+    N = VTy->getNumElements();
+    ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);
+  }
+
+  if (CGF.CGM.getTarget().getTriple().isDXIL())
+    return CGF.Builder.CreateIntrinsic(
+        /*ReturnType=*/ResultType, Intrinsic::dx_asdouble,
+        ArrayRef<Value *>{OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble");
+
+  if (!E->getArg(0)->getType()->isVectorType()) {
+    OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);
+    OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);
+  }
+
+  llvm::SmallVector<int> Mask;
+  for (int i = 0; i < N; i++) {
+    Mask.push_back(i);
+    Mask.push_back(i + N);
+  }
+
+  Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask);
+
+  return CGF.Builder.CreateBitCast(BitVec, ResultType);
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -18655,6 +18690,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
         "hlsl.any");
   }
+  case Builtin::BI__builtin_hlsl_asdouble:
+    return handleAsDoubleBuiltin(*this, E);
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     Value *OpX = EmitScalarExpr(E->getArg(0));
     Value *OpMin = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..7dd9c136d1d3f4 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -361,6 +361,24 @@ bool any(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_any)
 bool any(double4);
 
+//===----------------------------------------------------------------------===//
+// asdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn double asdouble(uint LowBits, uint HighBits)
+/// \brief Reinterprets a cast value (two 32-bit values) into a double.
+/// \param LowBits The low 32-bit pattern of the input value.
+/// \param HighBits The high 32-bit pattern of the input value.
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double asdouble(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double2 asdouble(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double3 asdouble(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double4 asdouble(uint4, uint4);
+
 //===----------------------------------------------------------------------===//
 // asfloat builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index c6627b0e993226..69de0294cb7c7c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1870,6 +1870,23 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_asdouble: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+      return true;
+
+    // Set the return type to be a scalar or vector of same length of double
+    ASTContext &Ctx = SemaRef.getASTContext();
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
+
+    QualType ResultType =
+        VTy ? Ctx.getVectorType(Ctx.DoubleTy, VTy->getNumElements(),
+                                VTy->getVectorKind())
+            : Ctx.DoubleTy;
+    TheCall->setType(ResultType);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/asdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
new file mode 100644
index 00000000000000..f1c31107cdcad6
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
@@ -0,0 +1,37 @@
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPV
+
+// Test lowering of asdouble expansion to shuffle/bitcast and splat when required
+
+// CHECK-LABEL: test_uint
+double test_uint(uint low, uint high) {
+  // CHECK-SPV: %[[LOW_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[LOW_SHUFFLE:.*]] = shufflevector <1 x i32> %[[LOW_INSERT]], {{.*}} zeroinitializer
+  // CHECK-SPV: %[[HIGH_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[HIGH_SHUFFLE:.*]] = shufflevector <1 x i32> %[[HIGH_INSERT]], {{.*}} zeroinitializer
+
+  // CHECK-SPV:      %[[SHUFFLE0:.*]] = shufflevector <1 x i32> %[[LOW_SHUFFLE]], <1 x i32> %[[HIGH_SHUFFLE]],
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 1>
+  // CHECK-SPV:      bitcast <2 x i32> %[[SHUFFLE0]] to double
+
+  // CHECK-DXIL: call double @llvm.dx.asdouble.i32
+  return asdouble(low, high);
+}
+
+// CHECK-DXIL: declare double @llvm.dx.asdouble.i32
+
+// CHECK-LABEL: test_vuint
+double3 test_vuint(uint3 low, uint3 high) {
+  // CHECK-SPV:      %[[SHUFFLE1:.*]] = shufflevector
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
+  // CHECK-SPV:      bitcast <6 x i32> %[[SHUFFLE1]] to <3 x double>
+
+  // CHECK-DXIL: call <3 x double> @llvm.dx.asdouble.v3i32
+  return asdouble(low, high);
+}
+
+// CHECK-DXIL: declare <3 x double> @llvm.dx.asdouble.v3i32
diff --git a/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
new file mode 100644
index 00000000000000..c6b57c76a1e2b3
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
@@ -0,0 +1,16 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+double test_too_few_arg() {
+  return __builtin_hlsl_asdouble();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+double test_too_few_arg_1(uint p0) {
+  return __builtin_hlsl_asdouble(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+double test_too_many_arg(uint p0) {
+  return __builtin_hlsl_asdouble(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..796b4011d71c0c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -896,6 +896,10 @@ class TargetTransformInfo {
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const;
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
@@ -1969,6 +1973,9 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
+  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      unsigned ScalarOpdIdx,
+                                                      bool Default) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2530,6 +2537,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) override {
+    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                       Default);
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..42d1082cf4d9eb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -392,6 +392,12 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index db3b5cddd7c1c3..b2841e778947dd 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -798,6 +798,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..904607a98aa86e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -40,6 +40,7 @@ def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
 
 def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
+def int_dx_asdouble : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [llvm_anyint_ty, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..bf9733f971fdac 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -612,6 +612,12 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
+bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) const {
+  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                         Default);
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..6a1edb9f6debe5 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -779,6 +779,16 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
+
+def MakeDouble :  DXILOp<101, makeDouble> {
+  let Doc = "creates a double value";
+  let LLVMIntrinsic = int_dx_asdouble;
+  let arguments = [Int32Ty, Int32Ty];
+  let result = DoubleTy;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
   let arguments = [HandleTy, ResPropsTy];
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..a115a664209445 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,12 +25,23 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
+  switch (ID) {
+  case Intrinsic::dx_asdouble:
+    return ScalarOpdIdx == 0;
+  default:
+    return Default;
+  }
+}
+
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
   case Intrinsic::dx_wave_readlane:
+  case Intrinsic::dx_asdouble:
   case Intrinsic::dx_splitdouble:
     return true;
   default:
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 30b57ed97d6370..ff82b7404ca58a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -37,6 +37,9 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx);
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default);
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 772f4c6c35ddec..719dce704872ae 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -727,7 +727,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+  if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+          ID, -1, isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +768,15 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I))) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I)))
         Tys.push_back(OpI->getType());
     }
   }
diff --git a/llvm/test/CodeGen/DirectX/asdouble.ll b/llvm/test/CodeGen/DirectX/asdouble.ll
new file mode 100644
index 00000000000000..6a581d69eb7e9d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/asdouble.ll
@@ -0,0 +1,22 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; Test that for scalar and vector inputs, asdouble maps down to the makeDouble
+; DirectX op
+
+define noundef double @asdouble_scalar(i32 noundef %low, i32 noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low, i32 %high)
+  %ret = call double @llvm.dx.asdouble.i32(i32 %low, i32 %high)
+  ret double %ret
+}
+
+declare double @llvm.dx.asdouble.i32(i32, i32)
+
+define noundef <3 x double> @asdouble_vec(<3 x i32> noundef %low, <3 x i32> noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i0, i32 %high.i0)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i1, i32 %high.i1)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i2, i32 %high.i2)
+  %ret = call <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32> %low, <3 x i32> %high)
+  ret <3 x double> %ret
+}
+
+declare <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32>, <3 x i32>)

@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Finn Plummer (inbelic)

Changes
- define intrinsic as builtin in Builtins.td
- link intrinsic in hlsl_intrinsics.h
- add semantic analysis to SemaHLSL.cpp
- lower to `llvm` or a `dx` intrinsic when applicable in CGBuiltin.cpp
- define DXIL intrinsic in IntrinsicsDirectX.td
- add DXIL op and mapping in DXIL.td
- enable scalarization of intrinsic

- add basic sema checking to asdouble-errors.hlsl

Resolves #99081


Full diff: https://github.com/llvm/llvm-project/pull/114847.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+37)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+18)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+17)
  • (added) clang/test/CodeGenHLSL/builtins/asdouble.hlsl (+37)
  • (added) clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl (+16)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+14)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+6)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+6)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+6)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+10)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+11)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h (+3)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+6-3)
  • (added) llvm/test/CodeGen/DirectX/asdouble.ll (+22)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..37ef0bf7324ffb 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4744,6 +4744,12 @@ def HLSLAny : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool(...)";
 }
 
+def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_asdouble"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 28f28c70b5ae52..85ad203b50c7ab 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18634,6 +18634,41 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getUDotIntrinsic();
 }
 
+Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {
+  assert((E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() &&
+          E->getArg(1)->getType()->hasUnsignedIntegerRepresentation()) &&
+         "asdouble operands types mismatch");
+  Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0));
+  Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1));
+
+  llvm::Type *ResultType = CGF.DoubleTy;
+  int N = 1;
+  if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
+    N = VTy->getNumElements();
+    ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);
+  }
+
+  if (CGF.CGM.getTarget().getTriple().isDXIL())
+    return CGF.Builder.CreateIntrinsic(
+        /*ReturnType=*/ResultType, Intrinsic::dx_asdouble,
+        ArrayRef<Value *>{OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble");
+
+  if (!E->getArg(0)->getType()->isVectorType()) {
+    OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);
+    OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);
+  }
+
+  llvm::SmallVector<int> Mask;
+  for (int i = 0; i < N; i++) {
+    Mask.push_back(i);
+    Mask.push_back(i + N);
+  }
+
+  Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask);
+
+  return CGF.Builder.CreateBitCast(BitVec, ResultType);
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -18655,6 +18690,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
         "hlsl.any");
   }
+  case Builtin::BI__builtin_hlsl_asdouble:
+    return handleAsDoubleBuiltin(*this, E);
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     Value *OpX = EmitScalarExpr(E->getArg(0));
     Value *OpMin = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..7dd9c136d1d3f4 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -361,6 +361,24 @@ bool any(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_any)
 bool any(double4);
 
+//===----------------------------------------------------------------------===//
+// asdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn double asdouble(uint LowBits, uint HighBits)
+/// \brief Reinterprets a cast value (two 32-bit values) into a double.
+/// \param LowBits The low 32-bit pattern of the input value.
+/// \param HighBits The high 32-bit pattern of the input value.
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double asdouble(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double2 asdouble(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double3 asdouble(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double4 asdouble(uint4, uint4);
+
 //===----------------------------------------------------------------------===//
 // asfloat builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index c6627b0e993226..69de0294cb7c7c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1870,6 +1870,23 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_asdouble: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+      return true;
+
+    // Set the return type to be a scalar or vector of same length of double
+    ASTContext &Ctx = SemaRef.getASTContext();
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
+
+    QualType ResultType =
+        VTy ? Ctx.getVectorType(Ctx.DoubleTy, VTy->getNumElements(),
+                                VTy->getVectorKind())
+            : Ctx.DoubleTy;
+    TheCall->setType(ResultType);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/asdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
new file mode 100644
index 00000000000000..f1c31107cdcad6
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
@@ -0,0 +1,37 @@
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPV
+
+// Test lowering of asdouble expansion to shuffle/bitcast and splat when required
+
+// CHECK-LABEL: test_uint
+double test_uint(uint low, uint high) {
+  // CHECK-SPV: %[[LOW_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[LOW_SHUFFLE:.*]] = shufflevector <1 x i32> %[[LOW_INSERT]], {{.*}} zeroinitializer
+  // CHECK-SPV: %[[HIGH_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[HIGH_SHUFFLE:.*]] = shufflevector <1 x i32> %[[HIGH_INSERT]], {{.*}} zeroinitializer
+
+  // CHECK-SPV:      %[[SHUFFLE0:.*]] = shufflevector <1 x i32> %[[LOW_SHUFFLE]], <1 x i32> %[[HIGH_SHUFFLE]],
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 1>
+  // CHECK-SPV:      bitcast <2 x i32> %[[SHUFFLE0]] to double
+
+  // CHECK-DXIL: call double @llvm.dx.asdouble.i32
+  return asdouble(low, high);
+}
+
+// CHECK-DXIL: declare double @llvm.dx.asdouble.i32
+
+// CHECK-LABEL: test_vuint
+double3 test_vuint(uint3 low, uint3 high) {
+  // CHECK-SPV:      %[[SHUFFLE1:.*]] = shufflevector
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
+  // CHECK-SPV:      bitcast <6 x i32> %[[SHUFFLE1]] to <3 x double>
+
+  // CHECK-DXIL: call <3 x double> @llvm.dx.asdouble.v3i32
+  return asdouble(low, high);
+}
+
+// CHECK-DXIL: declare <3 x double> @llvm.dx.asdouble.v3i32
diff --git a/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
new file mode 100644
index 00000000000000..c6b57c76a1e2b3
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
@@ -0,0 +1,16 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+double test_too_few_arg() {
+  return __builtin_hlsl_asdouble();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+double test_too_few_arg_1(uint p0) {
+  return __builtin_hlsl_asdouble(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+double test_too_many_arg(uint p0) {
+  return __builtin_hlsl_asdouble(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..796b4011d71c0c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -896,6 +896,10 @@ class TargetTransformInfo {
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const;
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
@@ -1969,6 +1973,9 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
+  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      unsigned ScalarOpdIdx,
+                                                      bool Default) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2530,6 +2537,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) override {
+    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                       Default);
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..42d1082cf4d9eb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -392,6 +392,12 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index db3b5cddd7c1c3..b2841e778947dd 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -798,6 +798,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..904607a98aa86e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -40,6 +40,7 @@ def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
 
 def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
+def int_dx_asdouble : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [llvm_anyint_ty, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..bf9733f971fdac 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -612,6 +612,12 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
+bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) const {
+  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                         Default);
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..6a1edb9f6debe5 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -779,6 +779,16 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
+
+def MakeDouble :  DXILOp<101, makeDouble> {
+  let Doc = "creates a double value";
+  let LLVMIntrinsic = int_dx_asdouble;
+  let arguments = [Int32Ty, Int32Ty];
+  let result = DoubleTy;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
   let arguments = [HandleTy, ResPropsTy];
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..a115a664209445 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,12 +25,23 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
+  switch (ID) {
+  case Intrinsic::dx_asdouble:
+    return ScalarOpdIdx == 0;
+  default:
+    return Default;
+  }
+}
+
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
   case Intrinsic::dx_wave_readlane:
+  case Intrinsic::dx_asdouble:
   case Intrinsic::dx_splitdouble:
     return true;
   default:
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 30b57ed97d6370..ff82b7404ca58a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -37,6 +37,9 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx);
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default);
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 772f4c6c35ddec..719dce704872ae 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -727,7 +727,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+  if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+          ID, -1, isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +768,15 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I))) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I)))
         Tys.push_back(OpI->getType());
     }
   }
diff --git a/llvm/test/CodeGen/DirectX/asdouble.ll b/llvm/test/CodeGen/DirectX/asdouble.ll
new file mode 100644
index 00000000000000..6a581d69eb7e9d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/asdouble.ll
@@ -0,0 +1,22 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; Test that for scalar and vector inputs, asdouble maps down to the makeDouble
+; DirectX op
+
+define noundef double @asdouble_scalar(i32 noundef %low, i32 noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low, i32 %high)
+  %ret = call double @llvm.dx.asdouble.i32(i32 %low, i32 %high)
+  ret double %ret
+}
+
+declare double @llvm.dx.asdouble.i32(i32, i32)
+
+define noundef <3 x double> @asdouble_vec(<3 x i32> noundef %low, <3 x i32> noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i0, i32 %high.i0)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i1, i32 %high.i1)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i2, i32 %high.i2)
+  %ret = call <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32> %low, <3 x i32> %high)
+  ret <3 x double> %ret
+}
+
+declare <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32>, <3 x i32>)

@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2024

@llvm/pr-subscribers-backend-directx

Author: Finn Plummer (inbelic)

Changes
- define intrinsic as builtin in Builtins.td
- link intrinsic in hlsl_intrinsics.h
- add semantic analysis to SemaHLSL.cpp
- lower to `llvm` or a `dx` intrinsic when applicable in CGBuiltin.cpp
- define DXIL intrinsic in IntrinsicsDirectX.td
- add DXIL op and mapping in DXIL.td
- enable scalarization of intrinsic

- add basic sema checking to asdouble-errors.hlsl

Resolves #99081


Full diff: https://github.com/llvm/llvm-project/pull/114847.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+37)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+18)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+17)
  • (added) clang/test/CodeGenHLSL/builtins/asdouble.hlsl (+37)
  • (added) clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl (+16)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+14)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+6)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+6)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+6)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+10)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+11)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h (+3)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+6-3)
  • (added) llvm/test/CodeGen/DirectX/asdouble.ll (+22)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..37ef0bf7324ffb 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4744,6 +4744,12 @@ def HLSLAny : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool(...)";
 }
 
+def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_asdouble"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 28f28c70b5ae52..85ad203b50c7ab 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18634,6 +18634,41 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getUDotIntrinsic();
 }
 
+Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {
+  assert((E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() &&
+          E->getArg(1)->getType()->hasUnsignedIntegerRepresentation()) &&
+         "asdouble operands types mismatch");
+  Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0));
+  Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1));
+
+  llvm::Type *ResultType = CGF.DoubleTy;
+  int N = 1;
+  if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
+    N = VTy->getNumElements();
+    ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);
+  }
+
+  if (CGF.CGM.getTarget().getTriple().isDXIL())
+    return CGF.Builder.CreateIntrinsic(
+        /*ReturnType=*/ResultType, Intrinsic::dx_asdouble,
+        ArrayRef<Value *>{OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble");
+
+  if (!E->getArg(0)->getType()->isVectorType()) {
+    OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);
+    OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);
+  }
+
+  llvm::SmallVector<int> Mask;
+  for (int i = 0; i < N; i++) {
+    Mask.push_back(i);
+    Mask.push_back(i + N);
+  }
+
+  Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask);
+
+  return CGF.Builder.CreateBitCast(BitVec, ResultType);
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -18655,6 +18690,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
         "hlsl.any");
   }
+  case Builtin::BI__builtin_hlsl_asdouble:
+    return handleAsDoubleBuiltin(*this, E);
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     Value *OpX = EmitScalarExpr(E->getArg(0));
     Value *OpMin = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..7dd9c136d1d3f4 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -361,6 +361,24 @@ bool any(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_any)
 bool any(double4);
 
+//===----------------------------------------------------------------------===//
+// asdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn double asdouble(uint LowBits, uint HighBits)
+/// \brief Reinterprets a cast value (two 32-bit values) into a double.
+/// \param LowBits The low 32-bit pattern of the input value.
+/// \param HighBits The high 32-bit pattern of the input value.
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double asdouble(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double2 asdouble(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double3 asdouble(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double4 asdouble(uint4, uint4);
+
 //===----------------------------------------------------------------------===//
 // asfloat builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index c6627b0e993226..69de0294cb7c7c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1870,6 +1870,23 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_asdouble: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+      return true;
+
+    // Set the return type to be a scalar or vector of same length of double
+    ASTContext &Ctx = SemaRef.getASTContext();
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
+
+    QualType ResultType =
+        VTy ? Ctx.getVectorType(Ctx.DoubleTy, VTy->getNumElements(),
+                                VTy->getVectorKind())
+            : Ctx.DoubleTy;
+    TheCall->setType(ResultType);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/asdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
new file mode 100644
index 00000000000000..f1c31107cdcad6
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
@@ -0,0 +1,37 @@
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPV
+
+// Test lowering of asdouble expansion to shuffle/bitcast and splat when required
+
+// CHECK-LABEL: test_uint
+double test_uint(uint low, uint high) {
+  // CHECK-SPV: %[[LOW_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[LOW_SHUFFLE:.*]] = shufflevector <1 x i32> %[[LOW_INSERT]], {{.*}} zeroinitializer
+  // CHECK-SPV: %[[HIGH_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[HIGH_SHUFFLE:.*]] = shufflevector <1 x i32> %[[HIGH_INSERT]], {{.*}} zeroinitializer
+
+  // CHECK-SPV:      %[[SHUFFLE0:.*]] = shufflevector <1 x i32> %[[LOW_SHUFFLE]], <1 x i32> %[[HIGH_SHUFFLE]],
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 1>
+  // CHECK-SPV:      bitcast <2 x i32> %[[SHUFFLE0]] to double
+
+  // CHECK-DXIL: call double @llvm.dx.asdouble.i32
+  return asdouble(low, high);
+}
+
+// CHECK-DXIL: declare double @llvm.dx.asdouble.i32
+
+// CHECK-LABEL: test_vuint
+double3 test_vuint(uint3 low, uint3 high) {
+  // CHECK-SPV:      %[[SHUFFLE1:.*]] = shufflevector
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
+  // CHECK-SPV:      bitcast <6 x i32> %[[SHUFFLE1]] to <3 x double>
+
+  // CHECK-DXIL: call <3 x double> @llvm.dx.asdouble.v3i32
+  return asdouble(low, high);
+}
+
+// CHECK-DXIL: declare <3 x double> @llvm.dx.asdouble.v3i32
diff --git a/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
new file mode 100644
index 00000000000000..c6b57c76a1e2b3
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
@@ -0,0 +1,16 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+double test_too_few_arg() {
+  return __builtin_hlsl_asdouble();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+double test_too_few_arg_1(uint p0) {
+  return __builtin_hlsl_asdouble(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+double test_too_many_arg(uint p0) {
+  return __builtin_hlsl_asdouble(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..796b4011d71c0c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -896,6 +896,10 @@ class TargetTransformInfo {
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const;
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
@@ -1969,6 +1973,9 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
+  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      unsigned ScalarOpdIdx,
+                                                      bool Default) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2530,6 +2537,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) override {
+    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                       Default);
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..42d1082cf4d9eb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -392,6 +392,12 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index db3b5cddd7c1c3..b2841e778947dd 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -798,6 +798,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..904607a98aa86e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -40,6 +40,7 @@ def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
 
 def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
+def int_dx_asdouble : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [llvm_anyint_ty, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..bf9733f971fdac 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -612,6 +612,12 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
+bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) const {
+  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                         Default);
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..6a1edb9f6debe5 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -779,6 +779,16 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
+
+def MakeDouble :  DXILOp<101, makeDouble> {
+  let Doc = "creates a double value";
+  let LLVMIntrinsic = int_dx_asdouble;
+  let arguments = [Int32Ty, Int32Ty];
+  let result = DoubleTy;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
   let arguments = [HandleTy, ResPropsTy];
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..a115a664209445 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,12 +25,23 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
+  switch (ID) {
+  case Intrinsic::dx_asdouble:
+    return ScalarOpdIdx == 0;
+  default:
+    return Default;
+  }
+}
+
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
   case Intrinsic::dx_wave_readlane:
+  case Intrinsic::dx_asdouble:
   case Intrinsic::dx_splitdouble:
     return true;
   default:
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 30b57ed97d6370..ff82b7404ca58a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -37,6 +37,9 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx);
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default);
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 772f4c6c35ddec..719dce704872ae 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -727,7 +727,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+  if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+          ID, -1, isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +768,15 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I))) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I)))
         Tys.push_back(OpI->getType());
     }
   }
diff --git a/llvm/test/CodeGen/DirectX/asdouble.ll b/llvm/test/CodeGen/DirectX/asdouble.ll
new file mode 100644
index 00000000000000..6a581d69eb7e9d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/asdouble.ll
@@ -0,0 +1,22 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; Test that for scalar and vector inputs, asdouble maps down to the makeDouble
+; DirectX op
+
+define noundef double @asdouble_scalar(i32 noundef %low, i32 noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low, i32 %high)
+  %ret = call double @llvm.dx.asdouble.i32(i32 %low, i32 %high)
+  ret double %ret
+}
+
+declare double @llvm.dx.asdouble.i32(i32, i32)
+
+define noundef <3 x double> @asdouble_vec(<3 x i32> noundef %low, <3 x i32> noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i0, i32 %high.i0)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i1, i32 %high.i1)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i2, i32 %high.i2)
+  %ret = call <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32> %low, <3 x i32> %high)
+  ret <3 x double> %ret
+}
+
+declare <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32>, <3 x i32>)

@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2024

@llvm/pr-subscribers-backend-x86

Author: Finn Plummer (inbelic)

Changes
- define intrinsic as builtin in Builtins.td
- link intrinsic in hlsl_intrinsics.h
- add semantic analysis to SemaHLSL.cpp
- lower to `llvm` or a `dx` intrinsic when applicable in CGBuiltin.cpp
- define DXIL intrinsic in IntrinsicsDirectX.td
- add DXIL op and mapping in DXIL.td
- enable scalarization of intrinsic

- add basic sema checking to asdouble-errors.hlsl

Resolves #99081


Full diff: https://github.com/llvm/llvm-project/pull/114847.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+37)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+18)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+17)
  • (added) clang/test/CodeGenHLSL/builtins/asdouble.hlsl (+37)
  • (added) clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl (+16)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+14)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+6)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+6)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+6)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+10)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+11)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h (+3)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+6-3)
  • (added) llvm/test/CodeGen/DirectX/asdouble.ll (+22)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..37ef0bf7324ffb 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4744,6 +4744,12 @@ def HLSLAny : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool(...)";
 }
 
+def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_asdouble"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 28f28c70b5ae52..85ad203b50c7ab 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18634,6 +18634,41 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getUDotIntrinsic();
 }
 
+Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {
+  assert((E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() &&
+          E->getArg(1)->getType()->hasUnsignedIntegerRepresentation()) &&
+         "asdouble operands types mismatch");
+  Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0));
+  Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1));
+
+  llvm::Type *ResultType = CGF.DoubleTy;
+  int N = 1;
+  if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
+    N = VTy->getNumElements();
+    ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);
+  }
+
+  if (CGF.CGM.getTarget().getTriple().isDXIL())
+    return CGF.Builder.CreateIntrinsic(
+        /*ReturnType=*/ResultType, Intrinsic::dx_asdouble,
+        ArrayRef<Value *>{OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble");
+
+  if (!E->getArg(0)->getType()->isVectorType()) {
+    OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);
+    OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);
+  }
+
+  llvm::SmallVector<int> Mask;
+  for (int i = 0; i < N; i++) {
+    Mask.push_back(i);
+    Mask.push_back(i + N);
+  }
+
+  Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask);
+
+  return CGF.Builder.CreateBitCast(BitVec, ResultType);
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -18655,6 +18690,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
         "hlsl.any");
   }
+  case Builtin::BI__builtin_hlsl_asdouble:
+    return handleAsDoubleBuiltin(*this, E);
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     Value *OpX = EmitScalarExpr(E->getArg(0));
     Value *OpMin = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..7dd9c136d1d3f4 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -361,6 +361,24 @@ bool any(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_any)
 bool any(double4);
 
+//===----------------------------------------------------------------------===//
+// asdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn double asdouble(uint LowBits, uint HighBits)
+/// \brief Reinterprets a cast value (two 32-bit values) into a double.
+/// \param LowBits The low 32-bit pattern of the input value.
+/// \param HighBits The high 32-bit pattern of the input value.
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double asdouble(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double2 asdouble(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double3 asdouble(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double4 asdouble(uint4, uint4);
+
 //===----------------------------------------------------------------------===//
 // asfloat builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index c6627b0e993226..69de0294cb7c7c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1870,6 +1870,23 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_asdouble: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+      return true;
+
+    // Set the return type to be a scalar or vector of same length of double
+    ASTContext &Ctx = SemaRef.getASTContext();
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
+
+    QualType ResultType =
+        VTy ? Ctx.getVectorType(Ctx.DoubleTy, VTy->getNumElements(),
+                                VTy->getVectorKind())
+            : Ctx.DoubleTy;
+    TheCall->setType(ResultType);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/asdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
new file mode 100644
index 00000000000000..f1c31107cdcad6
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
@@ -0,0 +1,37 @@
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPV
+
+// Test lowering of asdouble expansion to shuffle/bitcast and splat when required
+
+// CHECK-LABEL: test_uint
+double test_uint(uint low, uint high) {
+  // CHECK-SPV: %[[LOW_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[LOW_SHUFFLE:.*]] = shufflevector <1 x i32> %[[LOW_INSERT]], {{.*}} zeroinitializer
+  // CHECK-SPV: %[[HIGH_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[HIGH_SHUFFLE:.*]] = shufflevector <1 x i32> %[[HIGH_INSERT]], {{.*}} zeroinitializer
+
+  // CHECK-SPV:      %[[SHUFFLE0:.*]] = shufflevector <1 x i32> %[[LOW_SHUFFLE]], <1 x i32> %[[HIGH_SHUFFLE]],
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 1>
+  // CHECK-SPV:      bitcast <2 x i32> %[[SHUFFLE0]] to double
+
+  // CHECK-DXIL: call double @llvm.dx.asdouble.i32
+  return asdouble(low, high);
+}
+
+// CHECK-DXIL: declare double @llvm.dx.asdouble.i32
+
+// CHECK-LABEL: test_vuint
+double3 test_vuint(uint3 low, uint3 high) {
+  // CHECK-SPV:      %[[SHUFFLE1:.*]] = shufflevector
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
+  // CHECK-SPV:      bitcast <6 x i32> %[[SHUFFLE1]] to <3 x double>
+
+  // CHECK-DXIL: call <3 x double> @llvm.dx.asdouble.v3i32
+  return asdouble(low, high);
+}
+
+// CHECK-DXIL: declare <3 x double> @llvm.dx.asdouble.v3i32
diff --git a/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
new file mode 100644
index 00000000000000..c6b57c76a1e2b3
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
@@ -0,0 +1,16 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+double test_too_few_arg() {
+  return __builtin_hlsl_asdouble();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+double test_too_few_arg_1(uint p0) {
+  return __builtin_hlsl_asdouble(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+double test_too_many_arg(uint p0) {
+  return __builtin_hlsl_asdouble(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..796b4011d71c0c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -896,6 +896,10 @@ class TargetTransformInfo {
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const;
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
@@ -1969,6 +1973,9 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
+  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      unsigned ScalarOpdIdx,
+                                                      bool Default) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2530,6 +2537,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) override {
+    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                       Default);
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..42d1082cf4d9eb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -392,6 +392,12 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index db3b5cddd7c1c3..b2841e778947dd 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -798,6 +798,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..904607a98aa86e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -40,6 +40,7 @@ def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
 
 def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
+def int_dx_asdouble : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [llvm_anyint_ty, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..bf9733f971fdac 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -612,6 +612,12 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
+bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) const {
+  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                         Default);
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..6a1edb9f6debe5 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -779,6 +779,16 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
+
+def MakeDouble :  DXILOp<101, makeDouble> {
+  let Doc = "creates a double value";
+  let LLVMIntrinsic = int_dx_asdouble;
+  let arguments = [Int32Ty, Int32Ty];
+  let result = DoubleTy;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
   let arguments = [HandleTy, ResPropsTy];
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..a115a664209445 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,12 +25,23 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
+  switch (ID) {
+  case Intrinsic::dx_asdouble:
+    return ScalarOpdIdx == 0;
+  default:
+    return Default;
+  }
+}
+
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
   case Intrinsic::dx_wave_readlane:
+  case Intrinsic::dx_asdouble:
   case Intrinsic::dx_splitdouble:
     return true;
   default:
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 30b57ed97d6370..ff82b7404ca58a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -37,6 +37,9 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx);
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default);
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 772f4c6c35ddec..719dce704872ae 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -727,7 +727,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+  if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+          ID, -1, isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +768,15 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I))) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I)))
         Tys.push_back(OpI->getType());
     }
   }
diff --git a/llvm/test/CodeGen/DirectX/asdouble.ll b/llvm/test/CodeGen/DirectX/asdouble.ll
new file mode 100644
index 00000000000000..6a581d69eb7e9d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/asdouble.ll
@@ -0,0 +1,22 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; Test that for scalar and vector inputs, asdouble maps down to the makeDouble
+; DirectX op
+
+define noundef double @asdouble_scalar(i32 noundef %low, i32 noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low, i32 %high)
+  %ret = call double @llvm.dx.asdouble.i32(i32 %low, i32 %high)
+  ret double %ret
+}
+
+declare double @llvm.dx.asdouble.i32(i32, i32)
+
+define noundef <3 x double> @asdouble_vec(<3 x i32> noundef %low, <3 x i32> noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i0, i32 %high.i0)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i1, i32 %high.i1)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i2, i32 %high.i2)
+  %ret = call <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32> %low, <3 x i32> %high)
+  ret <3 x double> %ret
+}
+
+declare <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32>, <3 x i32>)

double test_too_many_arg(uint p0) {
return __builtin_hlsl_asdouble(p0, p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like adding some tests to check for type mismatch would be useful

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a type that you have in mind? Most types will be implicitly converted here, I added the custom struct type to illustrate that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any type really, what we really care about here is making sure the error is meaningful to the user. I usually write such tests with half, float or bool if the former are implicitly cast.

Copy link
Contributor

@spall spall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@farzonl
Copy link
Member

farzonl commented Nov 12, 2024

LGTM, but won't sign off until the TargetTransformInfo issue on the precursor pr is resolved.

inbelic added a commit that referenced this pull request Nov 21, 2024
…rloadTypeAtArg` api (#114849)

This changes allows target intrinsics to specify and overwrite overloaded types.

- Updates `ReplaceWithVecLib` to not provide TTI as there most probably won't be a use-case
- Updates `SLPVectorizer` to use available TTI
- Updates `VPTransformState` to pass down TTI
- Updates `VPlanRecipe` to use passed-down TTI

This change will let us add scalarization for `asdouble`:  #114847
- define intrinsic as builtin in Builtins.td
  - link intrinsic in hlsl_intrinsics.h
  - add semantic analysis to SemaHLSL.cpp

  - add basic sema checking to asdouble-errors.hlsl
- use already declared function
@inbelic inbelic merged commit a5f501e into llvm:main Nov 22, 2024
8 checks passed
@inbelic inbelic deleted the inbelic/as-double branch November 25, 2024 21:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:analysis llvm:ir llvm:transforms
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

Implement the asdouble HLSL Function
5 participants