-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[HLSL][SPIRV][DXIL] Implement dot4add_i8packed
intrinsic
#113623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
136d5cb
07d529d
2cd3e4a
af2aa50
1500192
9c98c3c
d2851ae
7617b82
d3b494c
efea661
3857d17
746d510
e69dbad
dd91f76
b77c090
5b5e084
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// RUN: %clang_cc1 -finclude-default-header -triple \ | ||
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ | ||
// RUN: FileCheck %s -DTARGET=dx | ||
// RUN: %clang_cc1 -finclude-default-header -triple \ | ||
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ | ||
// RUN: FileCheck %s -DTARGET=spv | ||
|
||
// Test basic lowering to runtime function call. | ||
|
||
// CHECK-LABEL: test | ||
int test(uint a, uint b, int c) { | ||
// CHECK: %[[RET:.*]] = call [[TY:i32]] @llvm.[[TARGET]].dot4add.i8packed([[TY]] %[[#]], [[TY]] %[[#]], [[TY]] %[[#]]) | ||
// CHECK: ret [[TY]] %[[RET]] | ||
return dot4add_i8packed(a, b, c); | ||
} | ||
|
||
// CHECK: declare [[TY]] @llvm.[[TARGET]].dot4add.i8packed([[TY]], [[TY]], [[TY]]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify | ||
|
||
int test_too_few_arg0() { | ||
return __builtin_hlsl_dot4add_i8packed(); | ||
// expected-error@-1 {{too few arguments to function call, expected 3, have 0}} | ||
} | ||
|
||
int test_too_few_arg1(int p0) { | ||
return __builtin_hlsl_dot4add_i8packed(p0); | ||
// expected-error@-1 {{too few arguments to function call, expected 3, have 1}} | ||
} | ||
|
||
int test_too_few_arg2(int p0) { | ||
return __builtin_hlsl_dot4add_i8packed(p0, p0); | ||
// expected-error@-1 {{too few arguments to function call, expected 3, have 2}} | ||
} | ||
|
||
int test_too_many_arg(int p0) { | ||
return __builtin_hlsl_dot4add_i8packed(p0, p0, p0, p0); | ||
// expected-error@-1 {{too many arguments to function call, expected 3, have 4}} | ||
} | ||
|
||
struct S { float f; }; | ||
|
||
int test_expr_struct_type_check(S p0, int p1) { | ||
return __builtin_hlsl_dot4add_i8packed(p0, p1, p1); | ||
// expected-error@-1 {{no viable conversion from 'S' to 'unsigned int'}} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -524,6 +524,9 @@ defm OpISubBorrow: BinOpTypedGen<"OpISubBorrow", 150, subc, 0, 1>; | |
def OpUMulExtended: BinOp<"OpUMulExtended", 151>; | ||
def OpSMulExtended: BinOp<"OpSMulExtended", 152>; | ||
|
||
def OpSDot: BinOp<"OpSDot", 4450>; | ||
def OpUDot: BinOp<"OpUDot", 4451>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do I understand it right that you need both OpSDot/OpUDot and OpSDotKHR/OpUDotKHR here. The former for versions since 1.6, and the latter for previous SPIR-V versions (also by default at the moment)? |
||
|
||
// 3.42.14 Bit Instructions | ||
|
||
defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -164,6 +164,13 @@ class SPIRVInstructionSelector : public InstructionSelector { | |
bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType, | ||
MachineInstr &I) const; | ||
|
||
template <bool Signed> | ||
bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType, | ||
MachineInstr &I) const; | ||
template <bool Signed> | ||
bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType, | ||
MachineInstr &I) const; | ||
|
||
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I, | ||
int OpIdx) const; | ||
void renderFImm64(MachineInstrBuilder &MIB, const MachineInstr &I, | ||
|
@@ -1646,7 +1653,7 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, | |
// Multiply the vectors, then sum the results | ||
Register Vec0 = I.getOperand(2).getReg(); | ||
Register Vec1 = I.getOperand(3).getReg(); | ||
Register TmpVec = MRI->createVirtualRegister(&SPIRV::IDRegClass); | ||
Register TmpVec = MRI->createVirtualRegister(GR.getRegClass(ResType)); | ||
SPIRVType *VecType = GR.getSPIRVTypeForVReg(Vec0); | ||
|
||
bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulV)) | ||
|
@@ -1660,18 +1667,18 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, | |
GR.getScalarOrVectorComponentCount(VecType) > 1 && | ||
"dot product requires a vector of at least 2 components"); | ||
|
||
Register Res = MRI->createVirtualRegister(&SPIRV::IDRegClass); | ||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) | ||
Register Res = MRI->createVirtualRegister(GR.getRegClass(ResType)); | ||
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) | ||
.addDef(Res) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
.addUse(TmpVec) | ||
.addImm(0) | ||
.constrainAllUses(TII, TRI, RBI); | ||
|
||
for (unsigned i = 1; i < GR.getScalarOrVectorComponentCount(VecType); i++) { | ||
Register Elt = MRI->createVirtualRegister(&SPIRV::IDRegClass); | ||
Register Elt = MRI->createVirtualRegister(GR.getRegClass(ResType)); | ||
|
||
Result |= | ||
Result &= | ||
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) | ||
.addDef(Elt) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
|
@@ -1680,10 +1687,10 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, | |
.constrainAllUses(TII, TRI, RBI); | ||
|
||
Register Sum = i < GR.getScalarOrVectorComponentCount(VecType) - 1 | ||
? MRI->createVirtualRegister(&SPIRV::IDRegClass) | ||
? MRI->createVirtualRegister(GR.getRegClass(ResType)) | ||
: ResVReg; | ||
|
||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS)) | ||
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS)) | ||
.addDef(Sum) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
.addUse(Res) | ||
|
@@ -1695,6 +1702,112 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, | |
return Result; | ||
} | ||
|
||
template <bool Signed> | ||
bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg, | ||
const SPIRVType *ResType, | ||
MachineInstr &I) const { | ||
assert(I.getNumOperands() == 5); | ||
assert(I.getOperand(2).isReg()); | ||
assert(I.getOperand(3).isReg()); | ||
assert(I.getOperand(4).isReg()); | ||
MachineBasicBlock &BB = *I.getParent(); | ||
|
||
auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot; | ||
Register Dot = MRI->createVirtualRegister(GR.getRegClass(ResType)); | ||
bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp)) | ||
.addDef(Dot) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
.addUse(I.getOperand(2).getReg()) | ||
.addUse(I.getOperand(3).getReg()) | ||
.constrainAllUses(TII, TRI, RBI); | ||
|
||
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS)) | ||
.addDef(ResVReg) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
.addUse(Dot) | ||
.addUse(I.getOperand(4).getReg()) | ||
.constrainAllUses(TII, TRI, RBI); | ||
|
||
return Result; | ||
} | ||
|
||
// Since pre-1.6 SPIRV has no DotProductInput4x8BitPacked implementation, | ||
// extract the elements of the packed inputs, multiply them and add the result | ||
// to the accumulator. | ||
template <bool Signed> | ||
bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( | ||
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { | ||
assert(I.getNumOperands() == 5); | ||
assert(I.getOperand(2).isReg()); | ||
assert(I.getOperand(3).isReg()); | ||
assert(I.getOperand(4).isReg()); | ||
MachineBasicBlock &BB = *I.getParent(); | ||
|
||
bool Result = false; | ||
|
||
// Acc = C | ||
Register Acc = I.getOperand(4).getReg(); | ||
SPIRVType *EltType = GR.getOrCreateSPIRVIntegerType(8, I, TII); | ||
auto ExtractOp = | ||
Signed ? SPIRV::OpBitFieldSExtract : SPIRV::OpBitFieldUExtract; | ||
|
||
// Extract the i8 element, multiply and add it to the accumulator | ||
for (unsigned i = 0; i < 4; i++) { | ||
// A[i] | ||
Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass); | ||
inbelic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) | ||
.addDef(AElt) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
.addUse(I.getOperand(2).getReg()) | ||
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII)) | ||
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII)) | ||
.constrainAllUses(TII, TRI, RBI); | ||
|
||
// B[i] | ||
Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass); | ||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) | ||
.addDef(BElt) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
.addUse(I.getOperand(3).getReg()) | ||
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII)) | ||
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII)) | ||
.constrainAllUses(TII, TRI, RBI); | ||
|
||
// A[i] * B[i] | ||
Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass); | ||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS)) | ||
.addDef(Mul) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
.addUse(AElt) | ||
.addUse(BElt) | ||
.constrainAllUses(TII, TRI, RBI); | ||
|
||
// Discard 24 highest-bits so that stored i32 register is i8 equivalent | ||
Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass); | ||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) | ||
.addDef(MaskMul) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
.addUse(Mul) | ||
.addUse(GR.getOrCreateConstInt(0, I, EltType, TII)) | ||
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII)) | ||
.constrainAllUses(TII, TRI, RBI); | ||
|
||
// Acc = Acc + A[i] * B[i] | ||
Register Sum = | ||
i < 3 ? MRI->createVirtualRegister(&SPIRV::IDRegClass) : ResVReg; | ||
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS)) | ||
.addDef(Sum) | ||
.addUse(GR.getSPIRVTypeID(ResType)) | ||
.addUse(Acc) | ||
.addUse(MaskMul) | ||
.constrainAllUses(TII, TRI, RBI); | ||
|
||
Acc = Sum; | ||
} | ||
|
||
return Result; | ||
} | ||
|
||
/// Transform saturate(x) to clamp(x, 0.0f, 1.0f) as SPIRV | ||
/// does not have a saturate builtin. | ||
bool SPIRVInstructionSelector::selectSaturate(Register ResVReg, | ||
|
@@ -2528,6 +2641,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, | |
case Intrinsic::spv_udot: | ||
case Intrinsic::spv_sdot: | ||
return selectIntegerDot(ResVReg, ResType, I); | ||
case Intrinsic::spv_dot4add_i8packed: | ||
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) || | ||
inbelic marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not valid. You have 3 cases, not 2 as now. Instead of selecting
As of now you get invalid SPIR-V when version is below 1.6 and SPV_KHR_integer_dot_product is enabled. In this case you will generate OpSDot/OpUDot which are missing before version 1.6. |
||
STI.isAtLeastSPIRVVer(VersionTuple(1, 6))) | ||
return selectDot4AddPacked<true>(ResVReg, ResType, I); | ||
return selectDot4AddPackedExpansion<true>(ResVReg, ResType, I); | ||
case Intrinsic::spv_all: | ||
return selectAll(ResVReg, ResType, I); | ||
case Intrinsic::spv_any: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there a reason you didn't do
uint
here? Its the same, so just curious.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoops no. A discrepancy on my part.