-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[HLSL][DXIL][SPIRV] Create llvm dot intrinsic and use for HLSL #102872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Per https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294 `dot` should be an LLVM intrinsic. This adds the llvm intrinsics and updates HLSL builtin codegen to emit them. Removed some stale comments that gave the obsolete impression that type conversions should be expected to match overloads. With dot moving into an LLVM intrinsic, the lowering to dx-specific operations doesn't take place until DXIL intrinsic expansion. This moves the introduction of arity-specific DX opcodes to DXIL intrinsic expansion. Part of llvm#88056
The new LLVM integer intrinsics replace the previous dx intrinsics. This updates the DXIL intrinsic expansion code and tests to use and expect the new integer intrinsics and the flattened DX floating vector size variants only after op lowering. Part of llvm#88056
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-clang Author: Greg Roth (pow2clk) ChangesPer https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294 Removed some stale comments that gave the obsolete impression that With dot moving into an LLVM intrinsic, the lowering to dx-specific The new LLVM integer intrinsics replace the previous dx intrinsics. Use the new LLVM dot intrinsics to build SPIRV instructions. New tests for generating SPIRV float and integer dot intrinsics are Fixes #88056 Patch is 52.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102872.diff 17 Files Affected:
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe80b0cbdfbfa..67148e32014ed2 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18470,22 +18470,14 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}
-Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
- if (QT->hasFloatingRepresentation()) {
- switch (elementCount) {
- case 2:
- return Intrinsic::dx_dot2;
- case 3:
- return Intrinsic::dx_dot3;
- case 4:
- return Intrinsic::dx_dot4;
- }
- }
- if (QT->hasSignedIntegerRepresentation())
- return Intrinsic::dx_sdot;
-
- assert(QT->hasUnsignedIntegerRepresentation());
- return Intrinsic::dx_udot;
+// Return dot product intrinsic that corresponds to the QT scalar type
+Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+ if (QT->isFloatingType())
+ return Intrinsic::fdot;
+ if (QT->isSignedIntegerType())
+ return Intrinsic::sdot;
+ assert(QT->isUnsignedIntegerType());
+ return Intrinsic::udot;
}
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18528,37 +18520,38 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *Op1 = EmitScalarExpr(E->getArg(1));
llvm::Type *T0 = Op0->getType();
llvm::Type *T1 = Op1->getType();
+
+ // If the arguments are scalars, just emit a multiply
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy())
- return Builder.CreateFMul(Op0, Op1, "dx.dot");
+ return Builder.CreateFMul(Op0, Op1, "dot");
if (T0->isIntegerTy())
- return Builder.CreateMul(Op0, Op1, "dx.dot");
+ return Builder.CreateMul(Op0, Op1, "dot");
- // Bools should have been promoted
llvm_unreachable(
"Scalar dot product is only supported on ints and floats.");
}
+ // For vectors, validate types and emit the appropriate intrinsic
+
// A VectorSplat should have happened
assert(T0->isVectorTy() && T1->isVectorTy() &&
"Dot product of vector and scalar is not supported.");
- // A vector sext or sitofp should have happened
- assert(T0->getScalarType() == T1->getScalarType() &&
- "Dot product of vectors need the same element types.");
-
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *VecTy1 =
E->getArg(1)->getType()->getAs<VectorType>();
- // A HLSLVectorTruncation should have happend
+
+ assert(VecTy0->getElementType() == VecTy1->getElementType() &&
+ "Dot product of vectors need the same element types.");
+
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");
return Builder.CreateIntrinsic(
/*ReturnType=*/T0->getScalarType(),
- getDotProductIntrinsic(E->getArg(0)->getType(),
- VecTy0->getNumElements()),
- ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+ getDotProductIntrinsic(VecTy0->getElementType()),
+ ArrayRef<Value *>{Op0, Op1}, nullptr, "dot");
} break;
case Builtin::BI__builtin_hlsl_lerp: {
Value *X = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
index b0b95074c972d5..6036f9430db4f0 100644
--- a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
@@ -2,8 +2,8 @@
// CHECK-LABEL: builtin_bool_to_float_type_promotion
// CHECK: %conv1 = uitofp i1 %loadedv to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
// CHECK: ret float %conv2
float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
@@ -12,8 +12,8 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
// CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion
// CHECK: %conv = uitofp i1 %loadedv to double
// CHECK: %conv1 = fpext float %1 to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
// CHECK: ret float %conv2
float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
@@ -22,8 +22,8 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
// CHECK-LABEL: builtin_dot_int_to_float_promotion
// CHECK: %conv = fpext float %0 to double
// CHECK: %conv1 = sitofp i32 %1 to double
-// CHECK: dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
// CHECK: ret float %conv2
float builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..b9486f433cced1 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -7,155 +7,155 @@
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
#ifdef __HLSL_ENABLE_16_BIT
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
#endif
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
int test_dot_int(int p0, int p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = fmul half %0, %1
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = fmul float %0, %1
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = fmul half %0, %1
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = fmul float %0, %1
+// NO_HALF: ret float %dot
half test_dot_half(half p0, half p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dot
half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dot
half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dot
half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = fmul float %0, %1
-// CHECK: ret float %dx.dot
+// CHECK: %dot = fmul float %0, %1
+// CHECK: ret float %dot
float test_dot_float(float p0, float p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %dot
float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %dot
float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %dot
float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %dot
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %dot
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %dot
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
+// CHECK: %dot = fmul double %0, %1
+// CHECK: ret double %dot
double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index b4e758136b39fb..815da809d28a73 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1045,6 +1045,15 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable, IntrWillReturn] in {
def int_nearbyint : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_round : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_roundeven : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+ def int_udot : Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+ def int_sdot : Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+ def int_fdot : Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
// Truncate a floating point number with a specific rounding mode
def int_fptrunc_round : DefaultAttrsIntrinsic<[ llvm_anyfloat_ty ],
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..8ce79eb7cbaafa 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -25,26 +25,18 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
-def int_dx_dot2 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot2 :
+ Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot3 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot3 :
+ Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot4 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot4 :
+ Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_sdot :
- Intrinsic<[LLVMVectorElementType<0>],
- [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_udot :
- Intrinsic<[LLVMVectorElementType<0>],
- [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index 9fb6de49fb2055..0808fd9d77be82 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -814,6 +814,15 @@ HANDLE_TARGET_OPCODE(G_FSINH)
/// Floating point hyperbolic tangent.
HANDLE_TARGET_OPCODE(G_FTANH)
+/// Floating point vector dot product
+HANDLE_TARGET_OPCODE(G_FDOTPROD)
+
+/// Unsigned integer vector dot product
+HANDLE_TARGET_OPCODE(G_UDOTPROD)
+
+/// Signed integer vector dot product
+HANDLE_TARGET_OPCODE(G_SDOTPROD)
+
/// Floating point square root.
HANDLE_TARGET_OPCODE(G_FSQRT)
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index 36a0a087ba457c..648671f627d649 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1057,6 +1057,27 @@ def G_FTANH : GenericInstruction {
let hasSideEffects = false;
}
+/// Floating point vector dot product
+def G_FDOTPROD : GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let InOperandList = (ins type0:$src1, type0:$src2);
+ let hasSideEffects = false;
+}
+
+/// Signed integer vector dot product
+def G_SDOTPROD : GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let ...
[truncated]
|
@llvm/pr-subscribers-backend-directx Author: Greg Roth (pow2clk) ChangesPer https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294 Removed some stale comments that gave the obsolete impression that With dot moving into an LLVM intrinsic, the lowering to dx-specific The new LLVM integer intrinsics replace the previous dx intrinsics. Use the new LLVM dot intrinsics to build SPIRV instructions. New tests for generating SPIRV float and integer dot intrinsics are Fixes #88056 Patch is 52.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102872.diff 17 Files Affected:
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe80b0cbdfbfa..67148e32014ed2 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18470,22 +18470,14 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}
-Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
- if (QT->hasFloatingRepresentation()) {
- switch (elementCount) {
- case 2:
- return Intrinsic::dx_dot2;
- case 3:
- return Intrinsic::dx_dot3;
- case 4:
- return Intrinsic::dx_dot4;
- }
- }
- if (QT->hasSignedIntegerRepresentation())
- return Intrinsic::dx_sdot;
-
- assert(QT->hasUnsignedIntegerRepresentation());
- return Intrinsic::dx_udot;
+// Return dot product intrinsic that corresponds to the QT scalar type
+Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+ if (QT->isFloatingType())
+ return Intrinsic::fdot;
+ if (QT->isSignedIntegerType())
+ return Intrinsic::sdot;
+ assert(QT->isUnsignedIntegerType());
+ return Intrinsic::udot;
}
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18528,37 +18520,38 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *Op1 = EmitScalarExpr(E->getArg(1));
llvm::Type *T0 = Op0->getType();
llvm::Type *T1 = Op1->getType();
+
+ // If the arguments are scalars, just emit a multiply
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy())
- return Builder.CreateFMul(Op0, Op1, "dx.dot");
+ return Builder.CreateFMul(Op0, Op1, "dot");
if (T0->isIntegerTy())
- return Builder.CreateMul(Op0, Op1, "dx.dot");
+ return Builder.CreateMul(Op0, Op1, "dot");
- // Bools should have been promoted
llvm_unreachable(
"Scalar dot product is only supported on ints and floats.");
}
+ // For vectors, validate types and emit the appropriate intrinsic
+
// A VectorSplat should have happened
assert(T0->isVectorTy() && T1->isVectorTy() &&
"Dot product of vector and scalar is not supported.");
- // A vector sext or sitofp should have happened
- assert(T0->getScalarType() == T1->getScalarType() &&
- "Dot product of vectors need the same element types.");
-
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *VecTy1 =
E->getArg(1)->getType()->getAs<VectorType>();
- // A HLSLVectorTruncation should have happend
+
+ assert(VecTy0->getElementType() == VecTy1->getElementType() &&
+ "Dot product of vectors need the same element types.");
+
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");
return Builder.CreateIntrinsic(
/*ReturnType=*/T0->getScalarType(),
- getDotProductIntrinsic(E->getArg(0)->getType(),
- VecTy0->getNumElements()),
- ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+ getDotProductIntrinsic(VecTy0->getElementType()),
+ ArrayRef<Value *>{Op0, Op1}, nullptr, "dot");
} break;
case Builtin::BI__builtin_hlsl_lerp: {
Value *X = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
index b0b95074c972d5..6036f9430db4f0 100644
--- a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
@@ -2,8 +2,8 @@
// CHECK-LABEL: builtin_bool_to_float_type_promotion
// CHECK: %conv1 = uitofp i1 %loadedv to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
// CHECK: ret float %conv2
float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
@@ -12,8 +12,8 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
// CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion
// CHECK: %conv = uitofp i1 %loadedv to double
// CHECK: %conv1 = fpext float %1 to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
// CHECK: ret float %conv2
float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
@@ -22,8 +22,8 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
// CHECK-LABEL: builtin_dot_int_to_float_promotion
// CHECK: %conv = fpext float %0 to double
// CHECK: %conv1 = sitofp i32 %1 to double
-// CHECK: dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
// CHECK: ret float %conv2
float builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..b9486f433cced1 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -7,155 +7,155 @@
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
#ifdef __HLSL_ENABLE_16_BIT
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
#endif
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
int test_dot_int(int p0, int p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = fmul half %0, %1
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = fmul float %0, %1
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = fmul half %0, %1
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = fmul float %0, %1
+// NO_HALF: ret float %dot
half test_dot_half(half p0, half p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dot
half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dot
half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dot
half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = fmul float %0, %1
-// CHECK: ret float %dx.dot
+// CHECK: %dot = fmul float %0, %1
+// CHECK: ret float %dot
float test_dot_float(float p0, float p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %dot
float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %dot
float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %dot
float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %dot
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %dot
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %dot
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
+// CHECK: %dot = fmul double %0, %1
+// CHECK: ret double %dot
double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index b4e758136b39fb..815da809d28a73 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1045,6 +1045,15 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable, IntrWillReturn] in {
def int_nearbyint : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_round : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_roundeven : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+ def int_udot : Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+ def int_sdot : Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+ def int_fdot : Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
// Truncate a floating point number with a specific rounding mode
def int_fptrunc_round : DefaultAttrsIntrinsic<[ llvm_anyfloat_ty ],
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..8ce79eb7cbaafa 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -25,26 +25,18 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
-def int_dx_dot2 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot2 :
+ Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot3 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot3 :
+ Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot4 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot4 :
+ Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_sdot :
- Intrinsic<[LLVMVectorElementType<0>],
- [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_udot :
- Intrinsic<[LLVMVectorElementType<0>],
- [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index 9fb6de49fb2055..0808fd9d77be82 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -814,6 +814,15 @@ HANDLE_TARGET_OPCODE(G_FSINH)
/// Floating point hyperbolic tangent.
HANDLE_TARGET_OPCODE(G_FTANH)
+/// Floating point vector dot product
+HANDLE_TARGET_OPCODE(G_FDOTPROD)
+
+/// Unsigned integer vector dot product
+HANDLE_TARGET_OPCODE(G_UDOTPROD)
+
+/// Signed integer vector dot product
+HANDLE_TARGET_OPCODE(G_SDOTPROD)
+
/// Floating point square root.
HANDLE_TARGET_OPCODE(G_FSQRT)
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index 36a0a087ba457c..648671f627d649 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1057,6 +1057,27 @@ def G_FTANH : GenericInstruction {
let hasSideEffects = false;
}
+/// Floating point vector dot product
+def G_FDOTPROD : GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let InOperandList = (ins type0:$src1, type0:$src2);
+ let hasSideEffects = false;
+}
+
+/// Signed integer vector dot product
+def G_SDOTPROD : GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let ...
[truncated]
|
@llvm/pr-subscribers-llvm-globalisel Author: Greg Roth (pow2clk) ChangesPer https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294 Removed some stale comments that gave the obsolete impression that With dot moving into an LLVM intrinsic, the lowering to dx-specific The new LLVM integer intrinsics replace the previous dx intrinsics. Use the new LLVM dot intrinsics to build SPIRV instructions. New tests for generating SPIRV float and integer dot intrinsics are Fixes #88056 Patch is 52.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102872.diff 17 Files Affected:
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe80b0cbdfbfa..67148e32014ed2 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18470,22 +18470,14 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}
-Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
- if (QT->hasFloatingRepresentation()) {
- switch (elementCount) {
- case 2:
- return Intrinsic::dx_dot2;
- case 3:
- return Intrinsic::dx_dot3;
- case 4:
- return Intrinsic::dx_dot4;
- }
- }
- if (QT->hasSignedIntegerRepresentation())
- return Intrinsic::dx_sdot;
-
- assert(QT->hasUnsignedIntegerRepresentation());
- return Intrinsic::dx_udot;
+// Return dot product intrinsic that corresponds to the QT scalar type
+Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+ if (QT->isFloatingType())
+ return Intrinsic::fdot;
+ if (QT->isSignedIntegerType())
+ return Intrinsic::sdot;
+ assert(QT->isUnsignedIntegerType());
+ return Intrinsic::udot;
}
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18528,37 +18520,38 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *Op1 = EmitScalarExpr(E->getArg(1));
llvm::Type *T0 = Op0->getType();
llvm::Type *T1 = Op1->getType();
+
+ // If the arguments are scalars, just emit a multiply
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy())
- return Builder.CreateFMul(Op0, Op1, "dx.dot");
+ return Builder.CreateFMul(Op0, Op1, "dot");
if (T0->isIntegerTy())
- return Builder.CreateMul(Op0, Op1, "dx.dot");
+ return Builder.CreateMul(Op0, Op1, "dot");
- // Bools should have been promoted
llvm_unreachable(
"Scalar dot product is only supported on ints and floats.");
}
+ // For vectors, validate types and emit the appropriate intrinsic
+
// A VectorSplat should have happened
assert(T0->isVectorTy() && T1->isVectorTy() &&
"Dot product of vector and scalar is not supported.");
- // A vector sext or sitofp should have happened
- assert(T0->getScalarType() == T1->getScalarType() &&
- "Dot product of vectors need the same element types.");
-
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *VecTy1 =
E->getArg(1)->getType()->getAs<VectorType>();
- // A HLSLVectorTruncation should have happend
+
+ assert(VecTy0->getElementType() == VecTy1->getElementType() &&
+ "Dot product of vectors need the same element types.");
+
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");
return Builder.CreateIntrinsic(
/*ReturnType=*/T0->getScalarType(),
- getDotProductIntrinsic(E->getArg(0)->getType(),
- VecTy0->getNumElements()),
- ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+ getDotProductIntrinsic(VecTy0->getElementType()),
+ ArrayRef<Value *>{Op0, Op1}, nullptr, "dot");
} break;
case Builtin::BI__builtin_hlsl_lerp: {
Value *X = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
index b0b95074c972d5..6036f9430db4f0 100644
--- a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
@@ -2,8 +2,8 @@
// CHECK-LABEL: builtin_bool_to_float_type_promotion
// CHECK: %conv1 = uitofp i1 %loadedv to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
// CHECK: ret float %conv2
float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
@@ -12,8 +12,8 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
// CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion
// CHECK: %conv = uitofp i1 %loadedv to double
// CHECK: %conv1 = fpext float %1 to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
// CHECK: ret float %conv2
float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
@@ -22,8 +22,8 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
// CHECK-LABEL: builtin_dot_int_to_float_promotion
// CHECK: %conv = fpext float %0 to double
// CHECK: %conv1 = sitofp i32 %1 to double
-// CHECK: dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
// CHECK: ret float %conv2
float builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..b9486f433cced1 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -7,155 +7,155 @@
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
#ifdef __HLSL_ENABLE_16_BIT
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
#endif
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
int test_dot_int(int p0, int p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = fmul half %0, %1
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = fmul float %0, %1
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = fmul half %0, %1
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = fmul float %0, %1
+// NO_HALF: ret float %dot
half test_dot_half(half p0, half p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dot
half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dot
half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dot
half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = fmul float %0, %1
-// CHECK: ret float %dx.dot
+// CHECK: %dot = fmul float %0, %1
+// CHECK: ret float %dot
float test_dot_float(float p0, float p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %dot
float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %dot
float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %dot
float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %dot
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %dot
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %dot
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
+// CHECK: %dot = fmul double %0, %1
+// CHECK: ret double %dot
double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index b4e758136b39fb..815da809d28a73 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1045,6 +1045,15 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable, IntrWillReturn] in {
def int_nearbyint : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_round : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_roundeven : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+ def int_udot : Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+ def int_sdot : Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+ def int_fdot : Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
// Truncate a floating point number with a specific rounding mode
def int_fptrunc_round : DefaultAttrsIntrinsic<[ llvm_anyfloat_ty ],
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..8ce79eb7cbaafa 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -25,26 +25,18 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
-def int_dx_dot2 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot2 :
+ Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot3 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot3 :
+ Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot4 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot4 :
+ Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_sdot :
- Intrinsic<[LLVMVectorElementType<0>],
- [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_udot :
- Intrinsic<[LLVMVectorElementType<0>],
- [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index 9fb6de49fb2055..0808fd9d77be82 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -814,6 +814,15 @@ HANDLE_TARGET_OPCODE(G_FSINH)
/// Floating point hyperbolic tangent.
HANDLE_TARGET_OPCODE(G_FTANH)
+/// Floating point vector dot product
+HANDLE_TARGET_OPCODE(G_FDOTPROD)
+
+/// Unsigned integer vector dot product
+HANDLE_TARGET_OPCODE(G_UDOTPROD)
+
+/// Signed integer vector dot product
+HANDLE_TARGET_OPCODE(G_SDOTPROD)
+
/// Floating point square root.
HANDLE_TARGET_OPCODE(G_FSQRT)
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index 36a0a087ba457c..648671f627d649 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1057,6 +1057,27 @@ def G_FTANH : GenericInstruction {
let hasSideEffects = false;
}
+/// Floating point vector dot product
+def G_FDOTPROD : GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let InOperandList = (ins type0:$src1, type0:$src2);
+ let hasSideEffects = false;
+}
+
+/// Signed integer vector dot product
+def G_SDOTPROD : GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let ...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The three commits are independently committable, but this is the grouping @farzonl and I agreed on. Reviewing them individually still might make this easier:
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>(); | ||
[[maybe_unused]] auto *VecTy1 = | ||
E->getArg(1)->getType()->getAs<VectorType>(); | ||
// A HLSLVectorTruncation should have happend | ||
|
||
assert(VecTy0->getElementType() == VecTy1->getElementType() && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Switched to clang types to match signedness of integers
} | ||
|
||
/// Unsigned integer vector dot product | ||
def G_UDOTPROD : GenericInstruction { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The unweildy names are because G_UDOT and G_SDOT clashed with existing AArch64 intrinsics that take three arguments as one is an accumulated inout parameter.
def G_UDOT : AArch64GenericInstruction { |
@@ -659,7 +659,7 @@ def Dot3 : DXILOp<55, dot3> { | |||
|
|||
def Dot4 : DXILOp<56, dot4> { | |||
let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " | |||
"a[n]*b[n] where n is between 0 and 3"; | |||
"a[n]*b[n] where n is 0 to 3 inclusive"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just something incidental as I found these descriptions misleading since the only numbers "between" 0 and 3 are 1 and 2.
// NATIVE_HALF: %dot = call half @llvm.fdot.v2f16(<2 x half> %0, <2 x half> %1) | ||
// NATIVE_HALF: ret half %dot | ||
// NO_HALF: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1) | ||
// NO_HALF: ret float %dot |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before this, the only change was in the temp names since they are no longer dx-exclusive ops. Here and hereafter, for floating-point values, we no longer lower to the vector-size-specific ops until DXIL intrinsic expansion, so these have a more generic form here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
while it is no longer a dx exclusive op we are still emitting the intrinsic in an hlsl specific code section. so instead of using %dot = ...
we should use %hlsl.dot
.
.addUse(I.getOperand(1).getReg()) | ||
.addUse(I.getOperand(2).getReg()) | ||
.constrainAllUses(TII, TRI, RBI); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a similar implementation here:
static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure myself. This might be a question for @Keenuts, @sudonatalie, or @VyacheslavLevytskyy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think trying to merge those two SPIRV::OpDot
's would not simplify or improve anything, let's keep as is.
Use the new LLVM dot intrinsics to build SPIRV instructions. This involves generating multiply and add operations for integers and the existing OpDot operation for floating point. This includes adding some generic opcodes for signed, unsigned and floats. These require updating an existing test for all such opcodes. New tests for generating SPIRV float and integer dot intrinsics are added as well. Fixes llvm#88056
490c0c0
to
edbb80c
Compare
Please advertise your achievements: https://github.com/llvm/llvm-project/blob/main/llvm/docs/GlobalISel/GenericOpcode.rst |
clang/lib/CodeGen/CGBuiltin.cpp
Outdated
if (!T0->isVectorTy() && !T1->isVectorTy()) { | ||
if (T0->isFloatingPointTy()) | ||
return Builder.CreateFMul(Op0, Op1, "dx.dot"); | ||
return Builder.CreateFMul(Op0, Op1, "dot"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hlsl.dot
clang/lib/CodeGen/CGBuiltin.cpp
Outdated
|
||
if (T0->isIntegerTy()) | ||
return Builder.CreateMul(Op0, Op1, "dx.dot"); | ||
return Builder.CreateMul(Op0, Op1, "dot"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hlsl.dot
if (QT->isFloatingType()) | ||
return Intrinsic::fdot; | ||
if (QT->isSignedIntegerType()) | ||
return Intrinsic::sdot; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we want to do an llvm dot intrinsic for integers. I'm pretty sure the RFC just covered the float case. I would instead do CGM.getHLSLRuntime().getSDotIntrinsic()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm rereading kparzysz posts in the rfc and seeing aarch64 wants integer dot seems like my comment here isn't correct. I still feel weird about it because our usages of it don't lower to a specific opcode in either the SPIRV or DXIL backends.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Justin's proposal didn't say explicitly, but linked to the HLSL documentation , which explicitly includes integers. In the discussion, there was an explicit request for integer versions that no one disapproved of there.
if (QT->isSignedIntegerType()) | ||
return Intrinsic::sdot; | ||
assert(QT->isUnsignedIntegerType()); | ||
return Intrinsic::udot; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CGM.getHLSLRuntime().getUDotIntrinsic()
clang/lib/CodeGen/CGBuiltin.cpp
Outdated
VecTy0->getNumElements()), | ||
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot"); | ||
getDotProductIntrinsic(VecTy0->getElementType()), | ||
ArrayRef<Value *>{Op0, Op1}, nullptr, "dot"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hlsl.dot
llvm/include/llvm/IR/Intrinsics.td
Outdated
def int_sdot : Intrinsic<[LLVMVectorElementType<0>], | ||
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], | ||
[IntrNoMem, IntrWillReturn, Commutative] >; | ||
def int_fdot : Intrinsic<[LLVMVectorElementType<0>], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had a cleanup I wanted to do to swith to DefaultAttrsIntrinsic
I don't remember what the full list of default attributes are at the moment. but i'm pretty sure only one we would need to pass in if we switch is Commutative
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default properties seem to be
IntrNoCallBack
IntrNoSync
IntrNoFree
IntrWillReturn
It seems we need to set IntrNoMem
too then.
Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot | ||
? Intrinsic::dx_imad | ||
: Intrinsic::dx_umad; | ||
static bool expandDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for simplicity I would keep a seperate expandIntegerDot
and expandFloatDot
. Doing it this way is a little weird considering you already have a conditional to seperate behavior via the switch cases. Then you merge back only to seperate out the behaviors again via contional if (EltTy->isIntegerTy()) {
@@ -1366,6 +1383,67 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg, | |||
.constrainAllUses(TII, TRI, RBI); | |||
} | |||
|
|||
// Since there is no integer dot implementation, expand by piecewise multiplying |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually I see a OpSDot
, OpUDot
, and OpSUDot
(for mixes signedness) in the spec. Maybe we just don't use these in DXC's spirv generation. I'm fine with sticking with what DXC does we should at least mark the comment accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those are fairly recent SPIRV extensions. I didn't think incorporating them was within scope.
AArch64 has a udot and sdot instruction (and a usdot instruction). They perform a "partial" reduction though, producing a v4i32 from two v16i8 inputs. We would like to use those from the vectorizer and have recently added a partial-reduction intrinsic, but doing it with a higher level intrinsic might be a little nicer. It would seem like a "udot" can be represented already as @SamTebbs33 @NickGuy-Arm FYI. |
We haven't done it yet, but our plan here is to create a default expansion in
There was a whole discussion on dot in https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294/13 check out
|
Missing LangRef changes. |
Rename dot-related ops to hlsl.dot Add documentation for new Genereic opcodes and llvm intrinsics Use DefaultAttrsIntrinsic to define new llvm intrinsics Split DXIL instruction expansion for integer and float dot products
You can test this locally with the following command:git-clang-format --diff 3c3df1bef84bd509bdd2b6033bc9bb3653826388 c08c0153cbde1f43bebdbb8b50b74e77cdfc40bb --extensions cpp -- clang/lib/CodeGen/CGBuiltin.cpp llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp View the diff from clang-format here.diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 9b7ea0950d..7578752780 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -100,15 +100,17 @@ static bool expandFloatDotIntrinsic(CallInst *Orig) {
default:
llvm_unreachable("dot product with vector outside 2-4 range");
}
- Value *Result = Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
- ArrayRef<Value *>{A, B}, nullptr, "dot");
+ Value *Result =
+ Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
+ ArrayRef<Value *>{A, B}, nullptr, "dot");
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
}
// Expand integer dot product to multiply and add ops
-static bool expandIntegerDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+static bool expandIntegerDotIntrinsic(CallInst *Orig,
+ Intrinsic::ID DotIntrinsic) {
assert(DotIntrinsic == Intrinsic::sdot || DotIntrinsic == Intrinsic::udot);
Value *A = Orig->getOperand(0);
Value *B = Orig->getOperand(1);
@@ -124,9 +126,8 @@ static bool expandIntegerDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic
assert(ATy->getScalarType()->isIntegerTy());
Value *Result;
- Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::sdot
- ? Intrinsic::dx_imad
- : Intrinsic::dx_umad;
+ Intrinsic::ID MadIntrinsic =
+ DotIntrinsic == Intrinsic::sdot ? Intrinsic::dx_imad : Intrinsic::dx_umad;
Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
Result = Builder.CreateMul(Elt0, Elt1);
@@ -134,8 +135,8 @@ static bool expandIntegerDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic
Elt0 = Builder.CreateExtractElement(A, i);
Elt1 = Builder.CreateExtractElement(B, i);
Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
- ArrayRef<Value *>{Elt0, Elt1, Result},
- nullptr, "dx.mad");
+ ArrayRef<Value *>{Elt0, Elt1, Result},
+ nullptr, "dx.mad");
}
Orig->replaceAllUsesWith(Result);
|
Why would using the partial reduction intrinsic stop you from using hardware-specific dot product lowerings for GPUs? The lowering is quite trivial, see here. I think it would be best to not introduce another way of doing the same thing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please create a separate RFC for these intrinsics. I don't think there is a consensus on these intrinsics, and https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294 covers way too much disparate ground (something like "tan" and something like "dot" are entirely separate categories, especially if you also want to include integer intrinsics).
@nikic @SamTebbs33 For us using an intrisnic better for the float case because you can specify rounding modes. I don't believe Also we are coming at this from different use cases. We want to support language features for now that will be HLSL, but maybe later that includes c++ 20s |
def G_FDOTPROD : GenericInstruction { | ||
let OutOperandList = (outs type0:$dst); | ||
let InOperandList = (ins type0:$src1, type0:$src2); | ||
let hasSideEffects = false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dot products take two vectors and return a scalar.
f G_FDOTPROD : GenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins type1:$src1, type1:$src2);
let hasSideEffects = false;
You have to change the sources to type1 and adapt your legalizer accordingly.
Please update the https://github.com/llvm/llvm-project/blob/main/llvm/lib/CodeGen/MachineVerifier.cpp to protect against misuse. |
…ller All expansions end with replacing the previous inrinsic with the new expansion and erasing the old one. By moving this operation to the caller, these expansion functions can be called in more contexts and a small amount of duplicated code is consolidated.
if (Result) { | ||
Orig->replaceAllUsesWith(Result); | ||
Orig->eraseFromParent(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a return true to the if block. just keep it simple and leave the return false;
instead of !!Result
@@ -303,6 +303,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { | |||
getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct( | |||
allFloatScalarsAndVectors, allIntScalarsAndVectors); | |||
|
|||
getActionDefinitionsBuilder(G_FDOTPROD) | |||
.legalForCartesianProduct(allFloatScalarsAndVectors, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not convinced that this is correct. Are you ruling out vectors for the sum type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getActionDefinitionBuilder(G_FDOTPROD)
.legalFor({{s32, s32vector}, {s64, s64vector}, ....});
There is some discussion in the RFC, but I don't see a consensus on the "dot" intrinsic in particular. I personally haven't found the arguments in favor of it particularly compelling. This really needs an RFC specific to that intrinsic (class), which includes a clear definition of the semantics of the intrinsics, why it needs to be a intrinsic and how it maps to different hardware (in particular, whether the chosen definition is actually sufficiently portable). Just looking at the LangRef wordings in this PR, your formulation is not appropriate for target-independent intrinsics. You cannot have a target-independent intrinsic "on a 2-4 element vector of 16-bit or 32-bit floating-point types". The description is also very vague for a floating-point operation. "The arguments are vectors to be elementwise multiplied and then summed" does not tell me how the summation occurs. Does it happen in order? As a tree reduction? In unspecified order? Does it use FMA? Your definitions for sdot and udot also don't really make sense to me -- addition and multiplication are signedness-independent operations, so your sdot and udot are in fact equivalent. Separating these intrinsics only makes sense if they also involve a zero or sign extension, which your formulation does not -- but which would be necessary to support other targets, as pointed out above. I'll also add that if you are adding target-independent intrinsics, the baseline expectation is that you need to provide full SDAG legalization support for them. I think we've been burned by this enough times that we're not going to accept new target-independent intrinsics that fail to do this. |
@pow2clk @farzonl and I discussed this a bit offline and we've come to agree with you here - while the RFC certainly does point out that there's demand for a generic dot intrinsic I think we were a bit overzealous to try to take it on in this context. What we need in HLSL is a fairly specific subset of the dot operation and it's probably more appropriate for us to just handle that with DirectX and SPIR-V specific intrinsics that do that rather than try to take on defining a fully generic dot operation here. I'll update the RFC to drop the dot and rsqrt intrinsics, and if we later want to revisit those in a generic way they can be handled on their own merits and pitfalls. |
Closing in light of the above. A new PR will capture the DXIL and SPIRV-specific work |
Here's the new PR for anyone who wants to keep following along in its altered state: #104656 |
Per https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294
dot
should be an LLVM intrinsic. This adds the llvm intrinsicsand updates HLSL builtin codegen to emit them.
Removed some stale comments that gave the obsolete impression that
type conversions should be expected to match overloads.
With dot moving into an LLVM intrinsic, the lowering to dx-specific
operations doesn't take place until DXIL intrinsic expansion. This
moves the introduction of arity-specific DX opcodes to DXIL
intrinsic expansion.
The new LLVM integer intrinsics replace the previous dx intrinsics.
This updates the DXIL intrinsic expansion code and tests to use and
expect the new integer intrinsics and the flattened DX floating
vector size variants only after op lowering.
Use the new LLVM dot intrinsics to build SPIRV instructions.
This involves generating multiply and add operations for integers
and the existing OpDot operation for floating point. This includes
adding some generic opcodes for signed, unsigned and floats.
These require updating an existing test for all such opcodes.
New tests for generating SPIRV float and integer dot intrinsics are
added as well.
Fixes #88056