Skip to content

Commit edbb80c

Browse files
committed
Add SPIRV generation for HLSL dot
Use the new LLVM dot intrinsics to build SPIRV instructions. This involves generating multiply and add operations for integers and the existing OpDot operation for floating point. This includes adding some generic opcodes for signed, unsigned and floats. These require updating an existing test for all such opcodes. New tests for generating SPIRV float and integer dot intrinsics are added as well. Fixes llvm#88056
1 parent 7ca6bc5 commit edbb80c

File tree

8 files changed

+289
-0
lines changed

8 files changed

+289
-0
lines changed

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,15 @@ HANDLE_TARGET_OPCODE(G_FSINH)
814814
/// Floating point hyperbolic tangent.
815815
HANDLE_TARGET_OPCODE(G_FTANH)
816816

817+
/// Floating point vector dot product. IRTranslator maps the llvm.fdot
/// intrinsic to this opcode.
818+
HANDLE_TARGET_OPCODE(G_FDOTPROD)
819+
820+
/// Unsigned integer vector dot product. IRTranslator maps the llvm.udot
/// intrinsic to this opcode.
821+
HANDLE_TARGET_OPCODE(G_UDOTPROD)
822+
823+
/// Signed integer vector dot product. IRTranslator maps the llvm.sdot
/// intrinsic to this opcode.
824+
HANDLE_TARGET_OPCODE(G_SDOTPROD)
825+
817826
/// Floating point square root.
818827
HANDLE_TARGET_OPCODE(G_FSQRT)
819828

llvm/include/llvm/Target/GenericOpcodes.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,27 @@ def G_FTANH : GenericInstruction {
10571057
let hasSideEffects = false;
10581058
}
10591059

1060+
/// Floating point vector dot product.
/// Defines one result ($dst) from two source operands ($src1, $src2) and is
/// declared free of side effects.
// NOTE(review): $dst and both sources share type index 0, but the SPIR-V tests
// for this opcode use vector sources with a scalar result. Confirm whether the
// sources need their own type index (type1).
1061+
def G_FDOTPROD : GenericInstruction {
1062+
let OutOperandList = (outs type0:$dst);
1063+
let InOperandList = (ins type0:$src1, type0:$src2);
1064+
let hasSideEffects = false;
1065+
}
1066+
1067+
/// Signed integer vector dot product.
/// Defines one result ($dst) from two source operands ($src1, $src2) and is
/// declared free of side effects.
// NOTE(review): $dst and both sources share type index 0, but the SPIR-V tests
// for this opcode use vector sources with a scalar result. Confirm whether the
// sources need their own type index (type1).
1068+
def G_SDOTPROD : GenericInstruction {
1069+
let OutOperandList = (outs type0:$dst);
1070+
let InOperandList = (ins type0:$src1, type0:$src2);
1071+
let hasSideEffects = false;
1072+
}
1073+
1074+
/// Unsigned integer vector dot product.
/// Defines one result ($dst) from two source operands ($src1, $src2) and is
/// declared free of side effects.
// NOTE(review): $dst and both sources share type index 0, but the SPIR-V tests
// for this opcode use vector sources with a scalar result. Confirm whether the
// sources need their own type index (type1).
1075+
def G_UDOTPROD : GenericInstruction {
1076+
let OutOperandList = (outs type0:$dst);
1077+
let InOperandList = (ins type0:$src1, type0:$src2);
1078+
let hasSideEffects = false;
1079+
}
1080+
10601081
// Floating point square root of a value.
10611082
// This returns NaN for negative nonzero values.
10621083
// NOTE: Unlike libm sqrt(), this never sets errno. In all other respects it's

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,6 +1903,12 @@ unsigned IRTranslator::getSimpleIntrinsicOpcode(Intrinsic::ID ID) {
19031903
return TargetOpcode::G_CTPOP;
19041904
case Intrinsic::exp:
19051905
return TargetOpcode::G_FEXP;
1906+
case Intrinsic::fdot:
1907+
return TargetOpcode::G_FDOTPROD;
1908+
case Intrinsic::sdot:
1909+
return TargetOpcode::G_SDOTPROD;
1910+
case Intrinsic::udot:
1911+
return TargetOpcode::G_UDOTPROD;
19061912
case Intrinsic::exp2:
19071913
return TargetOpcode::G_FEXP2;
19081914
case Intrinsic::exp10:

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
178178
bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
179179
MachineInstr &I) const;
180180

181+
bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
182+
MachineInstr &I) const;
183+
181184
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
182185
int OpIdx) const;
183186
void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -380,6 +383,20 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
380383
MIB.addImm(V);
381384
return MIB.constrainAllUses(TII, TRI, RBI);
382385
}
386+
387+
case TargetOpcode::G_FDOTPROD: {
388+
MachineBasicBlock &BB = *I.getParent();
389+
return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpDot))
390+
.addDef(ResVReg)
391+
.addUse(GR.getSPIRVTypeID(ResType))
392+
.addUse(I.getOperand(1).getReg())
393+
.addUse(I.getOperand(2).getReg())
394+
.constrainAllUses(TII, TRI, RBI);
395+
}
396+
case TargetOpcode::G_SDOTPROD:
397+
case TargetOpcode::G_UDOTPROD:
398+
return selectIntegerDot(ResVReg, ResType, I);
399+
383400
case TargetOpcode::G_MEMMOVE:
384401
case TargetOpcode::G_MEMCPY:
385402
case TargetOpcode::G_MEMSET:
@@ -1366,6 +1383,67 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
13661383
.constrainAllUses(TII, TRI, RBI);
13671384
}
13681385

1386+
// Since there is no integer dot implementation, expand by multiplying the two
1387+
// vectors element-wise and then summing the components of the product.
// The emitted sequence is one OpIMulV, an OpCompositeExtract per component,
// and a chain of OpIAddS whose last link defines ResVReg.
// Returns true only if every emitted instruction was constrained successfully.
1388+
bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
1389+
const SPIRVType *ResType,
1390+
MachineInstr &I) const {
1391+
assert(I.getNumOperands() == 3);
1392+
assert(I.getOperand(1).isReg());
1393+
assert(I.getOperand(2).isReg());
1394+
MachineBasicBlock &BB = *I.getParent();
1395+
1396+
// Multiply the vectors, then sum the results
1397+
Register Vec0 = I.getOperand(1).getReg();
1398+
Register Vec1 = I.getOperand(2).getReg();
1399+
Register TmpVec = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1400+
SPIRVType *VecType = GR.getSPIRVTypeForVReg(Vec0);
1401+
1402+
bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulV))
1403+
.addDef(TmpVec)
1404+
.addUse(GR.getSPIRVTypeID(VecType))
1405+
.addUse(Vec0)
1406+
.addUse(Vec1)
1407+
.constrainAllUses(TII, TRI, RBI);
1408+
1409+
assert(GR.getScalarOrVectorComponentCount(VecType) > 1 &&
1410+
"dot product requires a vector of at least 2 components");
1411+
1412+
Register Res = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1413+
// Accumulate with &= so that a single failed constraint makes the whole
// selection fail; |= would let an earlier success mask a later failure.
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1414+
.addDef(Res)
1415+
.addUse(GR.getSPIRVTypeID(ResType))
1416+
.addUse(TmpVec)
1417+
.addImm(0)
1418+
.constrainAllUses(TII, TRI, RBI);
1419+
1420+
for (unsigned i = 1; i < GR.getScalarOrVectorComponentCount(VecType); i++) {
1421+
Register Elt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1422+
1423+
Result &=
1424+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1425+
.addDef(Elt)
1426+
.addUse(GR.getSPIRVTypeID(ResType))
1427+
.addUse(TmpVec)
1428+
.addImm(i)
1429+
.constrainAllUses(TII, TRI, RBI);
1430+
1431+
// Intermediate sums get fresh registers; the final sum is written to ResVReg.
Register Sum = i < GR.getScalarOrVectorComponentCount(VecType) - 1
1432+
? MRI->createVirtualRegister(&SPIRV::IDRegClass)
1433+
: ResVReg;
1434+
1435+
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
1436+
.addDef(Sum)
1437+
.addUse(GR.getSPIRVTypeID(ResType))
1438+
.addUse(Res)
1439+
.addUse(Elt)
1440+
.constrainAllUses(TII, TRI, RBI);
1441+
Res = Sum;
1442+
}
1443+
1444+
return Result;
1445+
}
1446+
13691447
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
13701448
const SPIRVType *ResType,
13711449
MachineInstr &I) const {

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
285285
G_FCOSH,
286286
G_FSINH,
287287
G_FTANH,
288+
G_FDOTPROD,
289+
G_SDOTPROD,
290+
G_UDOTPROD,
288291
G_FSQRT,
289292
G_FFLOOR,
290293
G_FRINT,

llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,15 @@
716716
# DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
717717
# DEBUG-NEXT: .. the first uncovered type index: 1, OK
718718
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
719+
# DEBUG-NEXT: G_FDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
720+
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
721+
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
722+
# DEBUG-NEXT: G_UDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
723+
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
724+
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
725+
# DEBUG-NEXT: G_SDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
726+
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
727+
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
719728
# DEBUG-NEXT: G_FSQRT (opcode {{[0-9]+}}): 1 type index, 0 imm indices
720729
# DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
721730
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; Make sure SPIR-V OpDot instructions are generated for the fdot intrinsic on float type vectors.
5+
6+
; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
7+
; CHECK-DAG: %[[#vec2_float_16:]] = OpTypeVector %[[#float_16]] 2
8+
; CHECK-DAG: %[[#vec3_float_16:]] = OpTypeVector %[[#float_16]] 3
9+
; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
10+
; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
11+
; CHECK-DAG: %[[#vec2_float_32:]] = OpTypeVector %[[#float_32]] 2
12+
; CHECK-DAG: %[[#vec3_float_32:]] = OpTypeVector %[[#float_32]] 3
13+
; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
14+
15+
16+
; fdot on <2 x half> must become one OpDot whose operands are the two
; function parameters (the CHECK uses, not redefines, arg0/arg1).
define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
17+
entry:
18+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_16]]
19+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_16]]
20+
; CHECK: OpDot %[[#float_16]] %[[#arg0]] %[[#arg1]]
21+
%dx.dot = call half @llvm.fdot.v2f16(<2 x half> %a, <2 x half> %b)
22+
ret half %dx.dot
23+
}
24+
25+
; fdot on <3 x half> must become one OpDot whose operands are the two
; function parameters (the CHECK uses, not redefines, arg0/arg1).
define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
26+
entry:
27+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_16]]
28+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_16]]
29+
; CHECK: OpDot %[[#float_16]] %[[#arg0]] %[[#arg1]]
30+
%dx.dot = call half @llvm.fdot.v3f16(<3 x half> %a, <3 x half> %b)
31+
ret half %dx.dot
32+
}
33+
34+
; fdot on <4 x half> must become one OpDot whose operands are the two
; function parameters (the CHECK uses, not redefines, arg0/arg1).
define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
35+
entry:
36+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
37+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
38+
; CHECK: OpDot %[[#float_16]] %[[#arg0]] %[[#arg1]]
39+
%dx.dot = call half @llvm.fdot.v4f16(<4 x half> %a, <4 x half> %b)
40+
ret half %dx.dot
41+
}
42+
43+
; fdot on <2 x float> must become one OpDot whose operands are the two
; function parameters (the CHECK uses, not redefines, arg0/arg1).
define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
44+
entry:
45+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_32]]
46+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_32]]
47+
; CHECK: OpDot %[[#float_32]] %[[#arg0]] %[[#arg1]]
48+
%dx.dot = call float @llvm.fdot.v2f32(<2 x float> %a, <2 x float> %b)
49+
ret float %dx.dot
50+
}
51+
52+
; fdot on <3 x float> must become one OpDot whose operands are the two
; function parameters (the CHECK uses, not redefines, arg0/arg1).
define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
53+
entry:
54+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_32]]
55+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_32]]
56+
; CHECK: OpDot %[[#float_32]] %[[#arg0]] %[[#arg1]]
57+
%dx.dot = call float @llvm.fdot.v3f32(<3 x float> %a, <3 x float> %b)
58+
ret float %dx.dot
59+
}
60+
61+
; fdot on <4 x float> must become one OpDot whose operands are the two
; function parameters (the CHECK uses, not redefines, arg0/arg1).
define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
62+
entry:
63+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
64+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
65+
; CHECK: OpDot %[[#float_32]] %[[#arg0]] %[[#arg1]]
66+
%dx.dot = call float @llvm.fdot.v4f32(<4 x float> %a, <4 x float> %b)
67+
ret float %dx.dot
68+
}
69+
70+
declare half @llvm.fdot.v2f16(<2 x half> , <2 x half> )
71+
declare half @llvm.fdot.v3f16(<3 x half> , <3 x half> )
72+
declare half @llvm.fdot.v4f16(<4 x half> , <4 x half> )
73+
declare float @llvm.fdot.v2f32(<2 x float>, <2 x float>)
74+
declare float @llvm.fdot.v3f32(<3 x float>, <3 x float>)
75+
declare float @llvm.fdot.v4f32(<4 x float>, <4 x float>)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; Make sure the SPIR-V multiply/extract/add expansion is generated for the sdot and udot intrinsics on int/uint vectors.
5+
6+
; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16
7+
; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2
8+
; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3
9+
; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32
10+
; CHECK-DAG: %[[#vec4_int_32:]] = OpTypeVector %[[#int_32]] 4
11+
; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64
12+
; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2
13+
14+
; sdot on <2 x i16>: expect a vector OpIMul followed by two extracts and one
; OpIAdd. The intrinsic suffix matches the <2 x i16> operand type (v2i16).
define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
15+
entry:
16+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]]
17+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]]
18+
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
19+
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
20+
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
21+
; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
22+
%dot = call i16 @llvm.sdot.v2i16(<2 x i16> %a, <2 x i16> %b)
23+
ret i16 %dot
24+
}
25+
26+
; sdot on <4 x i32>: expect a vector OpIMul, then four extracts folded together
; by a chain of three OpIAdd instructions.
define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
27+
entry:
28+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
29+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
30+
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
31+
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
32+
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
33+
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
34+
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
35+
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
36+
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
37+
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
38+
%dot = call i32 @llvm.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
39+
ret i32 %dot
40+
}
41+
42+
; udot on <3 x i16>: expect a vector OpIMul, then three extracts folded together
; by a chain of two OpIAdd instructions.
define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
43+
entry:
44+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]]
45+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]]
46+
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
47+
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
48+
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
49+
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
50+
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
51+
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
52+
%dot = call i16 @llvm.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
53+
ret i16 %dot
54+
}
55+
56+
; udot on <4 x i32>: expect a vector OpIMul, then four extracts folded together
; by a chain of three OpIAdd instructions.
define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
57+
entry:
58+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
59+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
60+
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
61+
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
62+
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
63+
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
64+
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
65+
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
66+
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
67+
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
68+
%dot = call i32 @llvm.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
69+
ret i32 %dot
70+
}
71+
72+
; udot on <2 x i64>: expect a vector OpIMul followed by two extracts and one
; OpIAdd.
; NOTE(review): the function name ends in "4" but the operands are 2-wide
; vectors; presumably it was meant to be dot_uint64_t2 — confirm before renaming.
define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
73+
entry:
74+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
75+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
76+
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
77+
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
78+
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
79+
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
80+
%dot = call i64 @llvm.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
81+
ret i64 %dot
82+
}
83+
84+
declare i16 @llvm.sdot.v2i16(<2 x i16>, <2 x i16>)
85+
declare i32 @llvm.sdot.v4i32(<4 x i32>, <4 x i32>)
86+
; The overload suffix must match the <3 x i16> operand type used by the call in dot_uint16_t3.
declare i16 @llvm.udot.v3i16(<3 x i16>, <3 x i16>)
87+
declare i32 @llvm.udot.v4i32(<4 x i32>, <4 x i32>)
88+
declare i64 @llvm.udot.v2i64(<2 x i64>, <2 x i64>)

0 commit comments

Comments
 (0)