diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 7fe80b0cbdfbf..c14b7f41a3750 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -18470,22 +18470,14 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments, return Arg; } -Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) { - if (QT->hasFloatingRepresentation()) { - switch (elementCount) { - case 2: - return Intrinsic::dx_dot2; - case 3: - return Intrinsic::dx_dot3; - case 4: - return Intrinsic::dx_dot4; - } - } - if (QT->hasSignedIntegerRepresentation()) - return Intrinsic::dx_sdot; - - assert(QT->hasUnsignedIntegerRepresentation()); - return Intrinsic::dx_udot; +// Return dot product intrinsic that corresponds to the QT scalar type +Intrinsic::ID getDotProductIntrinsic(QualType QT) { + if (QT->isFloatingType()) + return Intrinsic::fdot; + if (QT->isSignedIntegerType()) + return Intrinsic::sdot; + assert(QT->isUnsignedIntegerType()); + return Intrinsic::udot; } Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, @@ -18528,37 +18520,38 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, Value *Op1 = EmitScalarExpr(E->getArg(1)); llvm::Type *T0 = Op0->getType(); llvm::Type *T1 = Op1->getType(); + + // If the arguments are scalars, just emit a multiply if (!T0->isVectorTy() && !T1->isVectorTy()) { if (T0->isFloatingPointTy()) - return Builder.CreateFMul(Op0, Op1, "dx.dot"); + return Builder.CreateFMul(Op0, Op1, "hlsl.dot"); if (T0->isIntegerTy()) - return Builder.CreateMul(Op0, Op1, "dx.dot"); + return Builder.CreateMul(Op0, Op1, "hlsl.dot"); - // Bools should have been promoted llvm_unreachable( "Scalar dot product is only supported on ints and floats."); } + // For vectors, validate types and emit the appropriate intrinsic + // A VectorSplat should have happened assert(T0->isVectorTy() && T1->isVectorTy() && "Dot product of vector and scalar is not supported."); - // A vector sext or sitofp should have happened - assert(T0->getScalarType() == T1->getScalarType() && - "Dot product of vectors need the same element types."); - auto *VecTy0 = E->getArg(0)->getType()->getAs(); [[maybe_unused]] auto *VecTy1 = E->getArg(1)->getType()->getAs(); - // A HLSLVectorTruncation should have happend + + assert(VecTy0->getElementType() == VecTy1->getElementType() && + "Dot product of vectors need the same element types."); + assert(VecTy0->getNumElements() == VecTy1->getNumElements() && "Dot product requires vectors to be of the same size."); return Builder.CreateIntrinsic( /*ReturnType=*/T0->getScalarType(), - getDotProductIntrinsic(E->getArg(0)->getType(), - VecTy0->getNumElements()), - ArrayRef{Op0, Op1}, nullptr, "dx.dot"); + getDotProductIntrinsic(VecTy0->getElementType()), + ArrayRef{Op0, Op1}, nullptr, "hlsl.dot"); } break; case Builtin::BI__builtin_hlsl_lerp: { Value *X = EmitScalarExpr(E->getArg(0)); diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl index b0b95074c972d..482f089d4770f 100644 --- a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl +++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl @@ -2,8 +2,8 @@ // CHECK-LABEL: builtin_bool_to_float_type_promotion // CHECK: %conv1 = uitofp i1 %loadedv to double -// CHECK: %dx.dot = fmul double %conv, %conv1 -// CHECK: %conv2 = fptrunc double %dx.dot to float +// CHECK: %hlsl.dot = fmul double %conv, %conv1 +// CHECK: %conv2 = fptrunc double %hlsl.dot to float // CHECK: ret float %conv2 float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) { return __builtin_hlsl_dot ( p0, p1 ); @@ -12,8 +12,8 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) { // CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion // CHECK: %conv = uitofp i1 %loadedv to double // CHECK: %conv1 = fpext float %1 to double -// CHECK: %dx.dot = fmul double %conv, %conv1 -// CHECK: %conv2 = fptrunc double %dx.dot to float +// CHECK: %hlsl.dot = fmul double %conv, %conv1 +// CHECK: %conv2 = fptrunc double %hlsl.dot to float // CHECK: ret float %conv2 float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) { return __builtin_hlsl_dot ( p0, p1 ); @@ -22,8 +22,8 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) { // CHECK-LABEL: builtin_dot_int_to_float_promotion // CHECK: %conv = fpext float %0 to double // CHECK: %conv1 = sitofp i32 %1 to double -// CHECK: dx.dot = fmul double %conv, %conv1 -// CHECK: %conv2 = fptrunc double %dx.dot to float +// CHECK: dot = fmul double %conv, %conv1 +// CHECK: %conv2 = fptrunc double %hlsl.dot to float // CHECK: ret float %conv2 float builtin_dot_int_to_float_promotion ( float p0, int p1 ) { return __builtin_hlsl_dot ( p0, p1 ); diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl index ae6e45c3f9482..05233be916526 100644 --- a/clang/test/CodeGenHLSL/builtins/dot.hlsl +++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl @@ -7,155 +7,155 @@ // RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF #ifdef __HLSL_ENABLE_16_BIT -// NATIVE_HALF: %dx.dot = mul i16 %0, %1 -// NATIVE_HALF: ret i16 %dx.dot +// NATIVE_HALF: %hlsl.dot = mul i16 %0, %1 +// NATIVE_HALF: ret i16 %hlsl.dot int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1) -// NATIVE_HALF: ret i16 %dx.dot +// NATIVE_HALF: %hlsl.dot = call i16 @llvm.sdot.v2i16(<2 x i16> %0, <2 x i16> %1) +// NATIVE_HALF: ret i16 %hlsl.dot int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1) -// NATIVE_HALF: ret i16 %dx.dot +// NATIVE_HALF: %hlsl.dot = call i16 @llvm.sdot.v3i16(<3 x i16> %0, <3 x i16> %1) +// NATIVE_HALF: ret i16 %hlsl.dot int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1) -// NATIVE_HALF: ret i16 %dx.dot +// NATIVE_HALF: %hlsl.dot = call i16 @llvm.sdot.v4i16(<4 x i16> %0, <4 x i16> %1) +// NATIVE_HALF: ret i16 %hlsl.dot int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = mul i16 %0, %1 -// NATIVE_HALF: ret i16 %dx.dot +// NATIVE_HALF: %hlsl.dot = mul i16 %0, %1 +// NATIVE_HALF: ret i16 %hlsl.dot uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1) -// NATIVE_HALF: ret i16 %dx.dot +// NATIVE_HALF: %hlsl.dot = call i16 @llvm.udot.v2i16(<2 x i16> %0, <2 x i16> %1) +// NATIVE_HALF: ret i16 %hlsl.dot uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1) -// NATIVE_HALF: ret i16 %dx.dot +// NATIVE_HALF: %hlsl.dot = call i16 @llvm.udot.v3i16(<3 x i16> %0, <3 x i16> %1) +// NATIVE_HALF: ret i16 %hlsl.dot uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1) -// NATIVE_HALF: ret i16 %dx.dot +// NATIVE_HALF: %hlsl.dot = call i16 @llvm.udot.v4i16(<4 x i16> %0, <4 x i16> %1) +// NATIVE_HALF: ret i16 %hlsl.dot uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); } #endif -// CHECK: %dx.dot = mul i32 %0, %1 -// CHECK: ret i32 %dx.dot +// CHECK: %hlsl.dot = mul i32 %0, %1 +// CHECK: ret i32 %hlsl.dot int test_dot_int(int p0, int p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1) -// CHECK: ret i32 %dx.dot +// CHECK: %hlsl.dot = call i32 @llvm.sdot.v2i32(<2 x i32> %0, <2 x i32> %1) +// CHECK: ret i32 %hlsl.dot int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1) -// CHECK: ret i32 %dx.dot +// CHECK: %hlsl.dot = call i32 @llvm.sdot.v3i32(<3 x i32> %0, <3 x i32> %1) +// CHECK: ret i32 %hlsl.dot int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1) -// CHECK: ret i32 %dx.dot +// CHECK: %hlsl.dot = call i32 @llvm.sdot.v4i32(<4 x i32> %0, <4 x i32> %1) +// CHECK: ret i32 %hlsl.dot int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = mul i32 %0, %1 -// CHECK: ret i32 %dx.dot +// CHECK: %hlsl.dot = mul i32 %0, %1 +// CHECK: ret i32 %hlsl.dot uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1) -// CHECK: ret i32 %dx.dot +// CHECK: %hlsl.dot = call i32 @llvm.udot.v2i32(<2 x i32> %0, <2 x i32> %1) +// CHECK: ret i32 %hlsl.dot uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1) -// CHECK: ret i32 %dx.dot +// CHECK: %hlsl.dot = call i32 @llvm.udot.v3i32(<3 x i32> %0, <3 x i32> %1) +// CHECK: ret i32 %hlsl.dot uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1) -// CHECK: ret i32 %dx.dot +// CHECK: %hlsl.dot = call i32 @llvm.udot.v4i32(<4 x i32> %0, <4 x i32> %1) +// CHECK: ret i32 %hlsl.dot uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = mul i64 %0, %1 -// CHECK: ret i64 %dx.dot +// CHECK: %hlsl.dot = mul i64 %0, %1 +// CHECK: ret i64 %hlsl.dot int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1) -// CHECK: ret i64 %dx.dot +// CHECK: %hlsl.dot = call i64 @llvm.sdot.v2i64(<2 x i64> %0, <2 x i64> %1) +// CHECK: ret i64 %hlsl.dot int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1) -// CHECK: ret i64 %dx.dot +// CHECK: %hlsl.dot = call i64 @llvm.sdot.v3i64(<3 x i64> %0, <3 x i64> %1) +// CHECK: ret i64 %hlsl.dot int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1) -// CHECK: ret i64 %dx.dot +// CHECK: %hlsl.dot = call i64 @llvm.sdot.v4i64(<4 x i64> %0, <4 x i64> %1) +// CHECK: ret i64 %hlsl.dot int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = mul i64 %0, %1 -// CHECK: ret i64 %dx.dot +// CHECK: %hlsl.dot = mul i64 %0, %1 +// CHECK: ret i64 %hlsl.dot uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1) -// CHECK: ret i64 %dx.dot +// CHECK: %hlsl.dot = call i64 @llvm.udot.v2i64(<2 x i64> %0, <2 x i64> %1) +// CHECK: ret i64 %hlsl.dot uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1) -// CHECK: ret i64 %dx.dot +// CHECK: %hlsl.dot = call i64 @llvm.udot.v3i64(<3 x i64> %0, <3 x i64> %1) +// CHECK: ret i64 %hlsl.dot uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1) -// CHECK: ret i64 %dx.dot +// CHECK: %hlsl.dot = call i64 @llvm.udot.v4i64(<4 x i64> %0, <4 x i64> %1) +// CHECK: ret i64 %hlsl.dot uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = fmul half %0, %1 -// NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = fmul float %0, %1 -// NO_HALF: ret float %dx.dot +// NATIVE_HALF: %hlsl.dot = fmul half %0, %1 +// NATIVE_HALF: ret half %hlsl.dot +// NO_HALF: %hlsl.dot = fmul float %0, %1 +// NO_HALF: ret float %hlsl.dot half test_dot_half(half p0, half p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1) -// NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1) -// NO_HALF: ret float %dx.dot +// NATIVE_HALF: %hlsl.dot = call half @llvm.fdot.v2f16(<2 x half> %0, <2 x half> %1) +// NATIVE_HALF: ret half %hlsl.dot +// NO_HALF: %hlsl.dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1) +// NO_HALF: ret float %hlsl.dot half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1) -// NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1) -// NO_HALF: ret float %dx.dot +// NATIVE_HALF: %hlsl.dot = call half @llvm.fdot.v3f16(<3 x half> %0, <3 x half> %1) +// NATIVE_HALF: ret half %hlsl.dot +// NO_HALF: %hlsl.dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1) +// NO_HALF: ret float %hlsl.dot half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1) -// NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1) -// NO_HALF: ret float %dx.dot +// NATIVE_HALF: %hlsl.dot = call half @llvm.fdot.v4f16(<4 x half> %0, <4 x half> %1) +// NATIVE_HALF: ret half %hlsl.dot +// NO_HALF: %hlsl.dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1) +// NO_HALF: ret float %hlsl.dot half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = fmul float %0, %1 -// CHECK: ret float %dx.dot +// CHECK: %hlsl.dot = fmul float %0, %1 +// CHECK: ret float %hlsl.dot float test_dot_float(float p0, float p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1) -// CHECK: ret float %dx.dot +// CHECK: %hlsl.dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1) +// CHECK: ret float %hlsl.dot float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1) -// CHECK: ret float %dx.dot +// CHECK: %hlsl.dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1) +// CHECK: ret float %hlsl.dot float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1) -// CHECK: ret float %dx.dot +// CHECK: %hlsl.dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1) +// CHECK: ret float %hlsl.dot float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1) -// CHECK: ret float %dx.dot +// CHECK: %hlsl.dot = call float @llvm.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1) +// CHECK: ret float %hlsl.dot float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1) -// CHECK: ret float %dx.dot +// CHECK: %hlsl.dot = call float @llvm.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1) +// CHECK: ret float %hlsl.dot float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1) -// CHECK: ret float %dx.dot +// CHECK: %hlsl.dot = call float @llvm.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1) +// CHECK: ret float %hlsl.dot float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = fmul double %0, %1 -// CHECK: ret double %dx.dot +// CHECK: %hlsl.dot = fmul double %0, %1 +// CHECK: ret double %hlsl.dot double test_dot_double(double p0, double p1) { return dot(p0, p1); } diff --git a/llvm/docs/GlobalISel/GenericOpcode.rst b/llvm/docs/GlobalISel/GenericOpcode.rst index d32aeff5a69bb..5a6994452c0ef 100644 --- a/llvm/docs/GlobalISel/GenericOpcode.rst +++ b/llvm/docs/GlobalISel/GenericOpcode.rst @@ -633,6 +633,14 @@ G_FCOS, G_FSIN, G_FTAN, G_FACOS, G_FASIN, G_FATAN, G_FCOSH, G_FSINH, G_FTANH These correspond to the standard C trigonometry functions of the same name. + +G_FDOTPROD, G_SDOTPROD, G_UDOTPROD +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +These represent the floating point, signed integer, and unsigned integer dot products respectively. +A dot product takes two equal-sized vectors and multiplies each element by the element in the corresponding +location of the other vector and then sums all the products, returning a scalar value. + G_INTRINSIC_TRUNC ^^^^^^^^^^^^^^^^^ diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst index 0ee4d7b444cfc..b494ec734099a 100644 --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -15674,6 +15674,138 @@ trapping or setting ``errno``. When specified with the fast-math-flag 'afn', the result may be approximated using a less accurate calculation. +'``llvm.fdot.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +This is an overloaded intrinsic. You can use ``llvm.fdot`` on a 2-4 element +vector of 16-bit or 32-bit floating-point types. Not all targets support +all types however. + +:: + + declare half @llvm.fdot.v2f16(<2 x half> %Vec0, <2 x half> %Vec1) + declare half @llvm.fdot.v3f16(<3 x half> %Vec0, <3 x half> %Vec1) + declare half @llvm.fdot.v4f16(<4 x half> %Vec0, <4 x half> %Vec1) + declare float @llvm.fdot.v2f32(<2 x float> %Vec0, <2 x float> %Vec1) + declare float @llvm.fdot.v3f32(<3 x float> %Vec0, <3 x float> %Vec1) + declare float @llvm.fdot.v4f32(<4 x float> %Vec0, <4 x float> %Vec1) + +Overview: +""""""""" + +The '``llvm.fdot.*``' intrinsics return the dot product of the two vector operands. +A dot product takes two equal-sized vectors and multiplies each element by the element in the corresponding +location of the other vector and then sums all the products, returning a scalar value. + + +Arguments: +"""""""""" + +The arguments are vectors to be elementwise multiplied and then summed. +The vectors may be of 2-4 elements and contain 16-bit or 32-bit float elements. +The arguments must be vectors of the same size and scalar element type. +The return type must match the scalar type of the elements. + +Semantics: +"""""""""" + +Return the summation of the products of the elements of the first and second arguments using +floating point arithmetic operations. + + +'``llvm.sdot.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +This is an overloaded intrinsic. You can use ``llvm.sdot`` on a 2-4 element +vector of 16-bit to 64-bit signed integer types. Not all targets support +all types however. + +:: + + declare i16 @llvm.sdot.v2i16(<2 x i16> %Vec0, <2 x i16> %Vec1) + declare i16 @llvm.sdot.v3i16(<3 x i16> %Vec0, <3 x i16> %Vec1) + declare i16 @llvm.sdot.v4i16(<4 x i16> %Vec0, <4 x i16> %Vec1) + declare i32 @llvm.sdot.v2i32(<2 x i32> %Vec0, <2 x i32> %Vec1) + declare i32 @llvm.sdot.v3i32(<3 x i32> %Vec0, <3 x i32> %Vec1) + declare i32 @llvm.sdot.v4i32(<4 x i32> %Vec0, <4 x i32> %Vec1) + declare i64 @llvm.sdot.v2i64(<2 x i64> %Vec0, <2 x i64> %Vec1) + declare i64 @llvm.sdot.v3i64(<3 x i64> %Vec0, <3 x i64> %Vec1) + declare i64 @llvm.sdot.v4i64(<4 x i64> %Vec0, <4 x i64> %Vec1) + +Overview: +""""""""" + +The '``llvm.sdot.*``' intrinsics return the dot product of the two vector operands. +A dot product takes two equal-sized vectors and multiplies each element by the element in the corresponding +location of the other vector and then sums all the products, returning a scalar value. + + +Arguments: +"""""""""" + +The arguments are vectors to be elementwise multiplied and then summed. +The vectors may be of 2-4 elements and contain 16-bit to 64-bit signed integer elements. +The arguments must be vectors of the same size and scalar element type. +The return type must match the scalar type of the elements. + +Semantics: +"""""""""" + +Return the summation of the products of the elements of the first and second arguments using +signed integer arithmetic operations. + + +'``llvm.udot.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +This is an overloaded intrinsic. You can use ``llvm.udot`` on a 2-4 element +vector of 16-bit to 64-bit unsigned integer types. Not all targets support +all types however. + +:: + + declare i16 @llvm.udot.v2i16(<2 x i16> %Vec0, <2 x i16> %Vec1) + declare i16 @llvm.udot.v3i16(<3 x i16> %Vec0, <3 x i16> %Vec1) + declare i16 @llvm.udot.v4i16(<4 x i16> %Vec0, <4 x i16> %Vec1) + declare i32 @llvm.udot.v2i32(<2 x i32> %Vec0, <2 x i32> %Vec1) + declare i32 @llvm.udot.v3i32(<3 x i32> %Vec0, <3 x i32> %Vec1) + declare i32 @llvm.udot.v4i32(<4 x i32> %Vec0, <4 x i32> %Vec1) + declare i64 @llvm.udot.v2i64(<2 x i64> %Vec0, <2 x i64> %Vec1) + declare i64 @llvm.udot.v3i64(<3 x i64> %Vec0, <3 x i64> %Vec1) + declare i64 @llvm.udot.v4i64(<4 x i64> %Vec0, <4 x i64> %Vec1) + +Overview: +""""""""" + +The '``llvm.udot.*``' intrinsics return the dot product of the two vector operands. +A dot product takes two equal-sized vectors and multiplies each element by the element in the corresponding +location of the other vector and then sums all the products, returning a scalar value. + + +Arguments: +"""""""""" + +The arguments are vectors to be elementwise multiplied and then summed. +The vectors may be of 2-4 elements and contain 16-bit to 64-bit unsigned integer elements. +The arguments must be vectors of the same size and scalar element type. +The return type must match the scalar type of the elements. + +Semantics: +"""""""""" + +Return the summation of the products of the elements of the first and second arguments using +unsigned integer arithmetic operations. + + '``llvm.pow.*``' Intrinsic ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td index b4e758136b39f..d412ac4167cf8 100644 --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1045,6 +1045,15 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable, IntrWillReturn] in { def int_nearbyint : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>; def int_round : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>; def int_roundeven : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>; + def int_udot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], + [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, Commutative] >; + def int_sdot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], + [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, Commutative] >; + def int_fdot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], + [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, Commutative] >; // Truncate a floating point number with a specific rounding mode def int_fptrunc_round : DefaultAttrsIntrinsic<[ llvm_anyfloat_ty ], diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 312c3862f240d..8ce79eb7cbaaf 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -25,26 +25,18 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; -def int_dx_dot2 : - Intrinsic<[LLVMVectorElementType<0>], +def int_dx_dot2 : + Intrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, IntrWillReturn, Commutative] >; -def int_dx_dot3 : - Intrinsic<[LLVMVectorElementType<0>], +def int_dx_dot3 : + Intrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, IntrWillReturn, Commutative] >; -def int_dx_dot4 : - Intrinsic<[LLVMVectorElementType<0>], +def int_dx_dot4 : + Intrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, IntrWillReturn, Commutative] >; -def int_dx_sdot : - Intrinsic<[LLVMVectorElementType<0>], - [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], - [IntrNoMem, IntrWillReturn, Commutative] >; -def int_dx_udot : - Intrinsic<[LLVMVectorElementType<0>], - [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], - [IntrNoMem, IntrWillReturn, Commutative] >; def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>; diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def index 9fb6de49fb205..0808fd9d77be8 100644 --- a/llvm/include/llvm/Support/TargetOpcodes.def +++ b/llvm/include/llvm/Support/TargetOpcodes.def @@ -814,6 +814,15 @@ HANDLE_TARGET_OPCODE(G_FSINH) /// Floating point hyperbolic tangent. HANDLE_TARGET_OPCODE(G_FTANH) +/// Floating point vector dot product +HANDLE_TARGET_OPCODE(G_FDOTPROD) + +/// Unsigned integer vector dot product +HANDLE_TARGET_OPCODE(G_UDOTPROD) + +/// Signed integer vector dot product +HANDLE_TARGET_OPCODE(G_SDOTPROD) + /// Floating point square root. HANDLE_TARGET_OPCODE(G_FSQRT) diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td index 36a0a087ba457..ced103debf609 100644 --- a/llvm/include/llvm/Target/GenericOpcodes.td +++ b/llvm/include/llvm/Target/GenericOpcodes.td @@ -1088,6 +1088,27 @@ def G_FNEARBYINT : GenericInstruction { let hasSideEffects = false; } +/// Floating point vector dot product +def G_FDOTPROD : GenericInstruction { + let OutOperandList = (outs type0:$dst); + let InOperandList = (ins type1:$src1, type1:$src2); + let hasSideEffects = false; +} + +/// Signed integer vector dot product +def G_SDOTPROD : GenericInstruction { + let OutOperandList = (outs type0:$dst); + let InOperandList = (ins type1:$src1, type1:$src2); + let hasSideEffects = false; +} + +/// Unsigned integer vector dot product +def G_UDOTPROD : GenericInstruction { + let OutOperandList = (outs type0:$dst); + let InOperandList = (ins type1:$src1, type1:$src2); + let hasSideEffects = false; +} + //------------------------------------------------------------------------------ // Access to floating-point environment. //------------------------------------------------------------------------------ diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index 0169a0e466d87..d471b172a7dcb 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -1903,6 +1903,12 @@ unsigned IRTranslator::getSimpleIntrinsicOpcode(Intrinsic::ID ID) { return TargetOpcode::G_CTPOP; case Intrinsic::exp: return TargetOpcode::G_FEXP; + case Intrinsic::fdot: + return TargetOpcode::G_FDOTPROD; + case Intrinsic::sdot: + return TargetOpcode::G_SDOTPROD; + case Intrinsic::udot: + return TargetOpcode::G_UDOTPROD; case Intrinsic::exp2: return TargetOpcode::G_FEXP2; case Intrinsic::exp10: diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp index d22fbe322ec36..586ea587ef856 100644 --- a/llvm/lib/CodeGen/MachineVerifier.cpp +++ b/llvm/lib/CodeGen/MachineVerifier.cpp @@ -2092,6 +2092,36 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) { } break; } + + case TargetOpcode::G_SDOTPROD: + case TargetOpcode::G_UDOTPROD: + case TargetOpcode::G_FDOTPROD: { + LLT DstTy = MRI->getType(MI->getOperand(0).getReg()); + if (!DstTy.isScalar()) { + report("Destination must be a scalar", MI); + break; + } + LLT Src0Ty = MRI->getType(MI->getOperand(1).getReg()); + LLT Src1Ty = MRI->getType(MI->getOperand(2).getReg()); + LLT Src0EltTy = Src0Ty.getScalarType(); + LLT Src1EltTy = Src1Ty.getScalarType(); + + if (!Src0Ty.isVector() || !Src1Ty.isVector()) { + report("Sources must be vectors", MI); + break; + } + if (Src0EltTy == Src1EltTy) { + report("Source vectors must have the same scalar element types", MI); + } + if (Src0EltTy != DstTy.getScalarType()) { + report("Destination type must match source element types", MI); + break; + } + + if (!verifyVectorElementMatch(Src0Ty, Src1Ty, MI)) + break; + break; + } case TargetOpcode::G_PREFETCH: { const MachineOperand &AddrOp = MI->getOperand(0); if (!AddrOp.isReg() || !MRI->getType(AddrOp.getReg()).isPointer()) { diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 67015cff78a79..ac79b84a1e910 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -637,7 +637,7 @@ def UMad : DXILOp<49, tertiary> { def Dot2 : DXILOp<54, dot2> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " - "a[n]*b[n] where n is between 0 and 1"; + "a[n]*b[n] where n is 0 to 1 inclusive"; let LLVMIntrinsic = int_dx_dot2; let arguments = !listsplat(overloadTy, 4); let result = overloadTy; @@ -648,7 +648,7 @@ def Dot2 : DXILOp<54, dot2> { def Dot3 : DXILOp<55, dot3> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " - "a[n]*b[n] where n is between 0 and 2"; + "a[n]*b[n] where n is 0 to 2 inclusive"; let LLVMIntrinsic = int_dx_dot3; let arguments = !listsplat(overloadTy, 6); let result = overloadTy; @@ -659,7 +659,7 @@ def Dot3 : DXILOp<55, dot3> { def Dot4 : DXILOp<56, dot4> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " - "a[n]*b[n] where n is between 0 and 3"; + "a[n]*b[n] where n is 0 to 3 inclusive"; let LLVMIntrinsic = int_dx_dot4; let arguments = !listsplat(overloadTy, 8); let result = overloadTy; diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index ac85859af8a53..b67a0147bca71 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -43,14 +43,15 @@ static bool isIntrinsicExpansion(Function &F) { case Intrinsic::dx_uclamp: case Intrinsic::dx_lerp: case Intrinsic::dx_length: - case Intrinsic::dx_sdot: - case Intrinsic::dx_udot: + case Intrinsic::sdot: + case Intrinsic::udot: + case Intrinsic::fdot: return true; } return false; } -static bool expandAbs(CallInst *Orig) { +static Value *expandAbs(CallInst *Orig) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); @@ -65,43 +66,76 @@ static bool expandAbs(CallInst *Orig) { auto *V = Builder.CreateSub(Zero, X); auto *MaxCall = Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max"); - Orig->replaceAllUsesWith(MaxCall); - Orig->eraseFromParent(); - return true; + return MaxCall; } -static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) { - assert(DotIntrinsic == Intrinsic::dx_sdot || - DotIntrinsic == Intrinsic::dx_udot); - Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot - ? Intrinsic::dx_imad - : Intrinsic::dx_umad; +// Create DXIL dot intrinsics for floating point dot operations +static Value *expandFloatDotIntrinsic(CallInst *Orig) { Value *A = Orig->getOperand(0); Value *B = Orig->getOperand(1); - [[maybe_unused]] Type *ATy = A->getType(); + + Type *ATy = A->getType(); [[maybe_unused]] Type *BTy = B->getType(); assert(ATy->isVectorTy() && BTy->isVectorTy()); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); + + auto *AVec = dyn_cast(ATy); + + assert(ATy->getScalarType()->isFloatingPointTy()); + + Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4; + switch (AVec->getNumElements()) { + case 2: + DotIntrinsic = Intrinsic::dx_dot2; + break; + case 3: + DotIntrinsic = Intrinsic::dx_dot3; + break; + case 4: + DotIntrinsic = Intrinsic::dx_dot4; + break; + default: + llvm_unreachable("dot product with vector outside 2-4 range"); + } + return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, + ArrayRef{A, B}, nullptr, "dot"); +} + +// Expand integer dot product to multiply and add ops +static Value *expandIntegerDotIntrinsic(CallInst *Orig, + Intrinsic::ID DotIntrinsic) { + assert(DotIntrinsic == Intrinsic::sdot || DotIntrinsic == Intrinsic::udot); + Value *A = Orig->getOperand(0); + Value *B = Orig->getOperand(1); + + Type *ATy = A->getType(); + [[maybe_unused]] Type *BTy = B->getType(); + assert(ATy->isVectorTy() && BTy->isVectorTy()); + + IRBuilder<> Builder(Orig); + + auto *AVec = dyn_cast(ATy); - auto *AVec = dyn_cast(A->getType()); + assert(ATy->getScalarType()->isIntegerTy()); + + Value *Result; + Intrinsic::ID MadIntrinsic = + DotIntrinsic == Intrinsic::sdot ? Intrinsic::dx_imad : Intrinsic::dx_umad; Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0); Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0); - Value *Result = Builder.CreateMul(Elt0, Elt1); - for (unsigned I = 1; I < AVec->getNumElements(); I++) { - Elt0 = Builder.CreateExtractElement(A, I); - Elt1 = Builder.CreateExtractElement(B, I); + Result = Builder.CreateMul(Elt0, Elt1); + for (unsigned i = 1; i < AVec->getNumElements(); i++) { + Elt0 = Builder.CreateExtractElement(A, i); + Elt1 = Builder.CreateExtractElement(B, i); Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic, ArrayRef{Elt0, Elt1, Result}, nullptr, "dx.mad"); } - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Result; } -static bool expandExpIntrinsic(CallInst *Orig) { +static Value *expandExpIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); @@ -118,23 +152,21 @@ static bool expandExpIntrinsic(CallInst *Orig) { Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2"); Exp2Call->setTailCall(Orig->isTailCall()); Exp2Call->setAttributes(Orig->getAttributes()); - Orig->replaceAllUsesWith(Exp2Call); - Orig->eraseFromParent(); - return true; + return Exp2Call; } -static bool expandAnyIntrinsic(CallInst *Orig) { +static Value *expandAnyIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); Type *Ty = X->getType(); Type *EltTy = Ty->getScalarType(); + Value *Result = nullptr; if (!Ty->isVectorTy()) { - Value *Cond = EltTy->isFloatingPointTy() - ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0)) - : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0)); - Orig->replaceAllUsesWith(Cond); + Result = EltTy->isFloatingPointTy() + ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0)) + : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0)); } else { auto *XVec = dyn_cast(Ty); Value *Cond = @@ -147,18 +179,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) { X, ConstantVector::getSplat( ElementCount::getFixed(XVec->getNumElements()), ConstantInt::get(EltTy, 0))); - Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0); + Result = Builder.CreateExtractElement(Cond, (uint64_t)0); for (unsigned I = 1; I < XVec->getNumElements(); I++) { Value *Elt = Builder.CreateExtractElement(Cond, I); Result = Builder.CreateOr(Result, Elt); } - Orig->replaceAllUsesWith(Result); } - Orig->eraseFromParent(); - return true; + return Result; } -static bool expandLengthIntrinsic(CallInst *Orig) { +static Value *expandLengthIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); @@ -181,15 +211,11 @@ static bool expandLengthIntrinsic(CallInst *Orig) { Value *Mul = Builder.CreateFMul(Elt, Elt); Sum = Builder.CreateFAdd(Sum, Mul); } - Value *Result = Builder.CreateIntrinsic( - EltTy, Intrinsic::sqrt, ArrayRef{Sum}, nullptr, "elt.sqrt"); - - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef{Sum}, + nullptr, "elt.sqrt"); } -static bool expandLerpIntrinsic(CallInst *Orig) { +static Value *expandLerpIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); Value *Y = Orig->getOperand(1); Value *S = Orig->getOperand(2); @@ -197,14 +223,11 @@ static bool expandLerpIntrinsic(CallInst *Orig) { Builder.SetInsertPoint(Orig); auto *V = Builder.CreateFSub(Y, X); V = Builder.CreateFMul(S, V); - auto *Result = Builder.CreateFAdd(X, V, "dx.lerp"); - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Builder.CreateFAdd(X, V, "dx.lerp"); } -static bool expandLogIntrinsic(CallInst *Orig, - float LogConstVal = numbers::ln2f) { +static Value *expandLogIntrinsic(CallInst *Orig, + float LogConstVal = numbers::ln2f) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); @@ -220,16 +243,13 @@ static bool expandLogIntrinsic(CallInst *Orig, Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2"); Log2Call->setTailCall(Orig->isTailCall()); Log2Call->setAttributes(Orig->getAttributes()); - auto *Result = Builder.CreateFMul(Ln2Const, Log2Call); - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Builder.CreateFMul(Ln2Const, Log2Call); } -static bool expandLog10Intrinsic(CallInst *Orig) { +static Value *expandLog10Intrinsic(CallInst *Orig) { return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f); } -static bool expandPowIntrinsic(CallInst *Orig) { +static Value *expandPowIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); Value *Y = Orig->getOperand(1); @@ -244,9 +264,7 @@ static bool expandPowIntrinsic(CallInst *Orig) { Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2"); Exp2Call->setTailCall(Orig->isTailCall()); Exp2Call->setAttributes(Orig->getAttributes()); - Orig->replaceAllUsesWith(Exp2Call); - Orig->eraseFromParent(); - return true; + return Exp2Call; } static Intrinsic::ID getMaxForClamp(Type *ElemTy, @@ -275,7 +293,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy, return Intrinsic::minnum; } -static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) { +static Value *expandClampIntrinsic(CallInst *Orig, + Intrinsic::ID ClampIntrinsic) { Value *X = Orig->getOperand(0); Value *Min = Orig->getOperand(1); Value *Max = Orig->getOperand(2); @@ -284,41 +303,55 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) { Builder.SetInsertPoint(Orig); auto *MaxCall = Builder.CreateIntrinsic( Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max"); - auto *MinCall = - Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic), - {MaxCall, Max}, nullptr, "dx.min"); - - Orig->replaceAllUsesWith(MinCall); - Orig->eraseFromParent(); - return true; + return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic), + {MaxCall, Max}, nullptr, "dx.min"); } static bool expandIntrinsic(Function &F, CallInst *Orig) { + Value *Result = nullptr; switch (F.getIntrinsicID()) { case Intrinsic::abs: - return expandAbs(Orig); + Result = expandAbs(Orig); + break; case Intrinsic::exp: - return expandExpIntrinsic(Orig); + Result = expandExpIntrinsic(Orig); + break; case Intrinsic::log: - return expandLogIntrinsic(Orig); + Result = expandLogIntrinsic(Orig); + break; case Intrinsic::log10: - return expandLog10Intrinsic(Orig); + Result = expandLog10Intrinsic(Orig); + break; case Intrinsic::pow: - return expandPowIntrinsic(Orig); + Result = expandPowIntrinsic(Orig); + break; case Intrinsic::dx_any: - return expandAnyIntrinsic(Orig); + Result = expandAnyIntrinsic(Orig); + break; case Intrinsic::dx_uclamp: case Intrinsic::dx_clamp: - return expandClampIntrinsic(Orig, F.getIntrinsicID()); + Result = expandClampIntrinsic(Orig, F.getIntrinsicID()); + break; case Intrinsic::dx_lerp: - return expandLerpIntrinsic(Orig); + Result = expandLerpIntrinsic(Orig); + break; case Intrinsic::dx_length: - return expandLengthIntrinsic(Orig); - case Intrinsic::dx_sdot: - case Intrinsic::dx_udot: - return expandIntegerDot(Orig, F.getIntrinsicID()); + Result = expandLengthIntrinsic(Orig); + break; + case Intrinsic::fdot: + Result = expandFloatDotIntrinsic(Orig); + break; + case Intrinsic::sdot: + case Intrinsic::udot: + Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID()); + break; } - return false; + + if (Result) { + Orig->replaceAllUsesWith(Result); + Orig->eraseFromParent(); + } + return !!Result; } static bool expansionIntrinsics(Module &M) { diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index ed786bd33aa05..98b7847a61147 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -178,6 +178,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectRsqrt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I, int OpIdx) const; void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I, @@ -380,6 +383,20 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, MIB.addImm(V); return MIB.constrainAllUses(TII, TRI, RBI); } + + case TargetOpcode::G_FDOTPROD: { + MachineBasicBlock &BB = *I.getParent(); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpDot)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .addUse(I.getOperand(2).getReg()) + .constrainAllUses(TII, TRI, RBI); + } + case TargetOpcode::G_SDOTPROD: + case TargetOpcode::G_UDOTPROD: + return selectIntegerDot(ResVReg, ResType, I); + case TargetOpcode::G_MEMMOVE: case TargetOpcode::G_MEMCPY: case TargetOpcode::G_MEMSET: @@ -1366,6 +1383,67 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } +// Since there is no integer dot implementation, expand by piecewise multiplying +// and adding the results, making use of FMA operations where possible. +bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + assert(I.getNumOperands() == 3); + assert(I.getOperand(1).isReg()); + assert(I.getOperand(2).isReg()); + MachineBasicBlock &BB = *I.getParent(); + + // Multiply the vectors, then sum the results + Register Vec0 = I.getOperand(1).getReg(); + Register Vec1 = I.getOperand(2).getReg(); + Register TmpVec = MRI->createVirtualRegister(&SPIRV::IDRegClass); + SPIRVType *VecType = GR.getSPIRVTypeForVReg(Vec0); + + bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulV)) + .addDef(TmpVec) + .addUse(GR.getSPIRVTypeID(VecType)) + .addUse(Vec0) + .addUse(Vec1) + .constrainAllUses(TII, TRI, RBI); + + assert(GR.getScalarOrVectorComponentCount(VecType) > 1 && + "dot product requires a vector of at least 2 components"); + + Register Res = MRI->createVirtualRegister(&SPIRV::IDRegClass); + Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) + .addDef(Res) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(TmpVec) + .addImm(0) + .constrainAllUses(TII, TRI, RBI); + + for (unsigned i = 1; i < GR.getScalarOrVectorComponentCount(VecType); i++) { + Register Elt = MRI->createVirtualRegister(&SPIRV::IDRegClass); + + Result |= + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) + .addDef(Elt) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(TmpVec) + .addImm(i) + .constrainAllUses(TII, TRI, RBI); + + Register Sum = i < GR.getScalarOrVectorComponentCount(VecType) - 1 + ? MRI->createVirtualRegister(&SPIRV::IDRegClass) + : ResVReg; + + Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS)) + .addDef(Sum) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Res) + .addUse(Elt) + .constrainAllUses(TII, TRI, RBI); + Res = Sum; + } + + return Result; +} + bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index e775f8c57b048..1434cb31e4fba 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -303,6 +303,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct( allFloatScalarsAndVectors, allIntScalarsAndVectors); + getActionDefinitionsBuilder(G_FDOTPROD) + .legalForCartesianProduct(allFloatScalarsAndVectors, + allFloatScalarsAndVectors); + + getActionDefinitionsBuilder({G_SDOTPROD, G_UDOTPROD}) + .legalForCartesianProduct(allIntScalarsAndVectors, + allIntScalarsAndVectors); + if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { getActionDefinitionsBuilder( {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF}) diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir index 87a415b45cca9..12510b6e58b6e 100644 --- a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir +++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir @@ -716,6 +716,15 @@ # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}} # DEBUG-NEXT: .. the first uncovered type index: 1, OK # DEBUG-NEXT: .. the first uncovered imm index: 0, OK +# DEBUG-NEXT: G_FDOTPROD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices +# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined +# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined +# DEBUG-NEXT: G_UDOTPROD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices +# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined +# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined +# DEBUG-NEXT: G_SDOTPROD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices +# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined +# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined # DEBUG-NEXT: G_FSQRT (opcode {{[0-9]+}}): 1 type index, 0 imm indices # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}} # DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll index 56817a172ff9e..3eb39fd5d4bb7 100644 --- a/llvm/test/CodeGen/DirectX/fdot.ll +++ b/llvm/test/CodeGen/DirectX/fdot.ll @@ -1,94 +1,101 @@ +; RUN: opt -S -dxil-intrinsic-expansion -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK ; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -; Make sure dxil operation function calls for dot are generated for int/uint vectors. +; Make sure dxil operation function calls for dot are generated for float type vectors. ; CHECK-LABEL: dot_half2 define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) { entry: -; CHECK: extractelement <2 x half> %a, i32 0 -; CHECK: extractelement <2 x half> %a, i32 1 -; CHECK: extractelement <2 x half> %b, i32 0 -; CHECK: extractelement <2 x half> %b, i32 1 -; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) - %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b) +; DOPCHECK: extractelement <2 x half> %a, i32 0 +; DOPCHECK: extractelement <2 x half> %a, i32 1 +; DOPCHECK: extractelement <2 x half> %b, i32 0 +; DOPCHECK: extractelement <2 x half> %b, i32 1 +; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) +; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b) + %dx.dot = call half @llvm.fdot.v2f16(<2 x half> %a, <2 x half> %b) ret half %dx.dot } ; CHECK-LABEL: dot_half3 define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) { entry: -; CHECK: extractelement <3 x half> %a, i32 0 -; CHECK: extractelement <3 x half> %a, i32 1 -; CHECK: extractelement <3 x half> %a, i32 2 -; CHECK: extractelement <3 x half> %b, i32 0 -; CHECK: extractelement <3 x half> %b, i32 1 -; CHECK: extractelement <3 x half> %b, i32 2 -; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) - %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b) +; DOPCHECK: extractelement <3 x half> %a, i32 0 +; DOPCHECK: extractelement <3 x half> %a, i32 1 +; DOPCHECK: extractelement <3 x half> %a, i32 2 +; DOPCHECK: extractelement <3 x half> %b, i32 0 +; DOPCHECK: extractelement <3 x half> %b, i32 1 +; DOPCHECK: extractelement <3 x half> %b, i32 2 +; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) +; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b) + %dx.dot = call half @llvm.fdot.v3f16(<3 x half> %a, <3 x half> %b) ret half %dx.dot } ; CHECK-LABEL: dot_half4 define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) { entry: -; CHECK: extractelement <4 x half> %a, i32 0 -; CHECK: extractelement <4 x half> %a, i32 1 -; CHECK: extractelement <4 x half> %a, i32 2 -; CHECK: extractelement <4 x half> %a, i32 3 -; CHECK: extractelement <4 x half> %b, i32 0 -; CHECK: extractelement <4 x half> %b, i32 1 -; CHECK: extractelement <4 x half> %b, i32 2 -; CHECK: extractelement <4 x half> %b, i32 3 -; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) - %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b) +; DOPCHECK: extractelement <4 x half> %a, i32 0 +; DOPCHECK: extractelement <4 x half> %a, i32 1 +; DOPCHECK: extractelement <4 x half> %a, i32 2 +; DOPCHECK: extractelement <4 x half> %a, i32 3 +; DOPCHECK: extractelement <4 x half> %b, i32 0 +; DOPCHECK: extractelement <4 x half> %b, i32 1 +; DOPCHECK: extractelement <4 x half> %b, i32 2 +; DOPCHECK: extractelement <4 x half> %b, i32 3 +; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) +; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b) + %dx.dot = call half @llvm.fdot.v4f16(<4 x half> %a, <4 x half> %b) ret half %dx.dot } ; CHECK-LABEL: dot_float2 define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) { entry: -; CHECK: extractelement <2 x float> %a, i32 0 -; CHECK: extractelement <2 x float> %a, i32 1 -; CHECK: extractelement <2 x float> %b, i32 0 -; CHECK: extractelement <2 x float> %b, i32 1 -; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) - %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b) +; DOPCHECK: extractelement <2 x float> %a, i32 0 +; DOPCHECK: extractelement <2 x float> %a, i32 1 +; DOPCHECK: extractelement <2 x float> %b, i32 0 +; DOPCHECK: extractelement <2 x float> %b, i32 1 +; DOPCHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) +; EXPCHECK: call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b) + %dx.dot = call float @llvm.fdot.v2f32(<2 x float> %a, <2 x float> %b) ret float %dx.dot } ; CHECK-LABEL: dot_float3 define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) { entry: -; CHECK: extractelement <3 x float> %a, i32 0 -; CHECK: extractelement <3 x float> %a, i32 1 -; CHECK: extractelement <3 x float> %a, i32 2 -; CHECK: extractelement <3 x float> %b, i32 0 -; CHECK: extractelement <3 x float> %b, i32 1 -; CHECK: extractelement <3 x float> %b, i32 2 -; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) - %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b) +; DOPCHECK: extractelement <3 x float> %a, i32 0 +; DOPCHECK: extractelement <3 x float> %a, i32 1 +; DOPCHECK: extractelement <3 x float> %a, i32 2 +; DOPCHECK: extractelement <3 x float> %b, i32 0 +; DOPCHECK: extractelement <3 x float> %b, i32 1 +; DOPCHECK: extractelement <3 x float> %b, i32 2 +; DOPCHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) +; EXPCHECK: call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b) + %dx.dot = call float @llvm.fdot.v3f32(<3 x float> %a, <3 x float> %b) ret float %dx.dot } ; CHECK-LABEL: dot_float4 define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) { entry: -; CHECK: extractelement <4 x float> %a, i32 0 -; CHECK: extractelement <4 x float> %a, i32 1 -; CHECK: extractelement <4 x float> %a, i32 2 -; CHECK: extractelement <4 x float> %a, i32 3 -; CHECK: extractelement <4 x float> %b, i32 0 -; CHECK: extractelement <4 x float> %b, i32 1 -; CHECK: extractelement <4 x float> %b, i32 2 -; CHECK: extractelement <4 x float> %b, i32 3 -; CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) - %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b) +; DOPCHECK: extractelement <4 x float> %a, i32 0 +; DOPCHECK: extractelement <4 x float> %a, i32 1 +; DOPCHECK: extractelement <4 x float> %a, i32 2 +; DOPCHECK: extractelement <4 x float> %a, i32 3 +; DOPCHECK: extractelement <4 x float> %b, i32 0 +; DOPCHECK: extractelement <4 x float> %b, i32 1 +; DOPCHECK: extractelement <4 x float> %b, i32 2 +; DOPCHECK: extractelement <4 x float> %b, i32 3 +; DOPCHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) +; EXPCHECK: call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b) + %dx.dot = call float @llvm.fdot.v4f32(<4 x float> %a, <4 x float> %b) ret float %dx.dot } -declare half @llvm.dx.dot.v2f16(<2 x half> , <2 x half> ) -declare half @llvm.dx.dot.v3f16(<3 x half> , <3 x half> ) -declare half @llvm.dx.dot.v4f16(<4 x half> , <4 x half> ) -declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>) -declare float @llvm.dx.dot.v3f32(<3 x float>, <3 x float>) -declare float @llvm.dx.dot.v4f32(<4 x float>, <4 x float>) +declare half @llvm.fdot.v2f16(<2 x half> , <2 x half> ) +declare half @llvm.fdot.v3f16(<3 x half> , <3 x half> ) +declare half @llvm.fdot.v4f16(<4 x half> , <4 x half> ) +declare float @llvm.fdot.v2f32(<2 x float>, <2 x float>) +declare float @llvm.fdot.v3f32(<3 x float>, <3 x float>) +declare float @llvm.fdot.v4f32(<4 x float>, <4 x float>) diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll index eac1b91106dde..94822f92b4135 100644 --- a/llvm/test/CodeGen/DirectX/idot.ll +++ b/llvm/test/CodeGen/DirectX/idot.ll @@ -13,12 +13,12 @@ entry: ; CHECK: extractelement <2 x i16> %b, i64 1 ; EXPCHECK: call i16 @llvm.dx.imad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) ; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) - %dx.dot = call i16 @llvm.dx.sdot.v3i16(<2 x i16> %a, <2 x i16> %b) - ret i16 %dx.dot + %dot = call i16 @llvm.sdot.v3i16(<2 x i16> %a, <2 x i16> %b) + ret i16 %dot } -; CHECK-LABEL: sdot_int4 -define noundef i32 @sdot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) { +; CHECK-LABEL: dot_int4 +define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) { entry: ; CHECK: extractelement <4 x i32> %a, i64 0 ; CHECK: extractelement <4 x i32> %b, i64 0 @@ -35,8 +35,8 @@ entry: ; CHECK: extractelement <4 x i32> %b, i64 3 ; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) ; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b) - ret i32 %dx.dot + %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %a, <4 x i32> %b) + ret i32 %dot } ; CHECK-LABEL: dot_uint16_t3 @@ -53,8 +53,8 @@ entry: ; CHECK: extractelement <3 x i16> %b, i64 2 ; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) ; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) - %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b) - ret i16 %dx.dot + %dot = call i16 @llvm.udot.v3i16(<3 x i16> %a, <3 x i16> %b) + ret i16 %dot } ; CHECK-LABEL: dot_uint4 @@ -75,8 +75,8 @@ entry: ; CHECK: extractelement <4 x i32> %b, i64 3 ; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) ; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b) - ret i32 %dx.dot + %dot = call i32 @llvm.udot.v4i32(<4 x i32> %a, <4 x i32> %b) + ret i32 %dot } ; CHECK-LABEL: dot_uint64_t4 @@ -89,12 +89,12 @@ entry: ; CHECK: extractelement <2 x i64> %b, i64 1 ; EXPCHECK: call i64 @llvm.dx.umad.i64(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}}) ; DOPCHECK: call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}}) - %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b) - ret i64 %dx.dot + %dot = call i64 @llvm.udot.v2i64(<2 x i64> %a, <2 x i64> %b) + ret i64 %dot } -declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i16>) -declare i32 @llvm.dx.sdot.v4i32(<4 x i32>, <4 x i32>) -declare i16 @llvm.dx.udot.v3i32(<3 x i16>, <3 x i16>) -declare i32 @llvm.dx.udot.v4i32(<4 x i32>, <4 x i32>) -declare i64 @llvm.dx.udot.v2i64(<2 x i64>, <2 x i64>) +declare i16 @llvm.sdot.v2i16(<2 x i16>, <2 x i16>) +declare i32 @llvm.sdot.v4i32(<4 x i32>, <4 x i32>) +declare i16 @llvm.udot.v3i32(<3 x i16>, <3 x i16>) +declare i32 @llvm.udot.v4i32(<4 x i32>, <4 x i32>) +declare i64 @llvm.udot.v2i64(<2 x i64>, <2 x i64>) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll new file mode 100644 index 0000000000000..964decf237a5c --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll @@ -0,0 +1,75 @@ +; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; Make sure dxil operation function calls for dot are generated for float type vectors. + +; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#vec2_float_16:]] = OpTypeVector %[[#float_16]] 2 +; CHECK-DAG: %[[#vec3_float_16:]] = OpTypeVector %[[#float_16]] 3 +; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4 +; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#vec2_float_32:]] = OpTypeVector %[[#float_32]] 2 +; CHECK-DAG: %[[#vec3_float_32:]] = OpTypeVector %[[#float_32]] 3 +; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4 + + +define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_16]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_16]] +; CHECK: OpDot %[[#float_16]] %[[#arg0:]] %[[#arg1:]] + %dx.dot = call half @llvm.fdot.v2f16(<2 x half> %a, <2 x half> %b) + ret half %dx.dot +} + +define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_16]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_16]] +; CHECK: OpDot %[[#float_16]] %[[#arg0:]] %[[#arg1:]] + %dx.dot = call half @llvm.fdot.v3f16(<3 x half> %a, <3 x half> %b) + ret half %dx.dot +} + +define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]] +; CHECK: OpDot %[[#float_16]] %[[#arg0:]] %[[#arg1:]] + %dx.dot = call half @llvm.fdot.v4f16(<4 x half> %a, <4 x half> %b) + ret half %dx.dot +} + +define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_32]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_32]] +; CHECK: OpDot %[[#float_32]] %[[#arg0:]] %[[#arg1:]] + %dx.dot = call float @llvm.fdot.v2f32(<2 x float> %a, <2 x float> %b) + ret float %dx.dot +} + +define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_32]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_32]] +; CHECK: OpDot %[[#float_32]] %[[#arg0:]] %[[#arg1:]] + %dx.dot = call float @llvm.fdot.v3f32(<3 x float> %a, <3 x float> %b) + ret float %dx.dot +} + +define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]] +; CHECK: OpDot %[[#float_32]] %[[#arg0:]] %[[#arg1:]] + %dx.dot = call float @llvm.fdot.v4f32(<4 x float> %a, <4 x float> %b) + ret float %dx.dot +} + +declare half @llvm.fdot.v2f16(<2 x half> , <2 x half> ) +declare half @llvm.fdot.v3f16(<3 x half> , <3 x half> ) +declare half @llvm.fdot.v4f16(<4 x half> , <4 x half> ) +declare float @llvm.fdot.v2f32(<2 x float>, <2 x float>) +declare float @llvm.fdot.v3f32(<3 x float>, <3 x float>) +declare float @llvm.fdot.v4f32(<4 x float>, <4 x float>) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll new file mode 100644 index 0000000000000..05f28920e205b --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll @@ -0,0 +1,88 @@ +; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; Make sure dxil operation function calls for dot are generated for int/uint vectors. + +; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16 +; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2 +; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3 +; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 +; CHECK-DAG: %[[#vec4_int_32:]] = OpTypeVector %[[#int_32]] 4 +; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64 +; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2 + +define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]] +; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]] +; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 +; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 +; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] + %dot = call i16 @llvm.sdot.v3i16(<2 x i16> %a, <2 x i16> %b) + ret i16 %dot +} + +define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]] +; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] +; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 +; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 +; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] +; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 +; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]] +; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 +; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]] + %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %a, <4 x i32> %b) + ret i32 %dot +} + +define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]] +; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]] +; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 +; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 +; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] +; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2 +; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]] + %dot = call i16 @llvm.udot.v3i16(<3 x i16> %a, <3 x i16> %b) + ret i16 %dot +} + +define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]] +; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] +; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 +; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 +; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] +; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 +; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]] +; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 +; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]] + %dot = call i32 @llvm.udot.v4i32(<4 x i32> %a, <4 x i32> %b) + ret i32 %dot +} + +define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]] +; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]] +; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0 +; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1 +; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]] + %dot = call i64 @llvm.udot.v2i64(<2 x i64> %a, <2 x i64> %b) + ret i64 %dot +} + +declare i16 @llvm.sdot.v2i16(<2 x i16>, <2 x i16>) +declare i32 @llvm.sdot.v4i32(<4 x i32>, <4 x i32>) +declare i16 @llvm.udot.v3i32(<3 x i16>, <3 x i16>) +declare i32 @llvm.udot.v4i32(<4 x i32>, <4 x i32>) +declare i64 @llvm.udot.v2i64(<2 x i64>, <2 x i64>)