Skip to content

[HLSL][SPIRV][DXIL] Implement dot4add_i8packed intrinsic #113623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Nov 5, 2024
Merged
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4792,6 +4792,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot4add_i8packed"];
let Attributes = [NoThrow, Const];
let Prototype = "int(unsigned int, unsigned int, int)";
}

def HLSLFrac : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_frac"];
let Attributes = [NoThrow, Const];
Expand Down
12 changes: 11 additions & 1 deletion clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18855,7 +18855,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/T0->getScalarType(),
getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
} break;
}
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
Value *C = EmitScalarExpr(E->getArg(2));

Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic();
return Builder.CreateIntrinsic(
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
"hlsl.dot4add.i8packed");
}
case Builtin::BI__builtin_hlsl_lerp: {
Value *X = EmitScalarExpr(E->getArg(0));
Value *Y = EmitScalarExpr(E->getArg(1));
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)

Expand Down
10 changes: 10 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,16 @@ uint64_t dot(uint64_t3, uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t4, uint64_t4);

//===----------------------------------------------------------------------===//
// dot4add builtins
//===----------------------------------------------------------------------===//

/// \fn int dot4add_i8packed(uint A, uint B, int C)

_HLSL_AVAILABILITY(shadermodel, 6.4)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot4add_i8packed)
int dot4add_i8packed(unsigned int, unsigned int, int);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason you didn't do uint here? Its the same, so just curious.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops no. A discrepancy on my part.


//===----------------------------------------------------------------------===//
// exp builtins
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 17 additions & 0 deletions clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: %clang_cc1 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=dx
// RUN: %clang_cc1 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=spv

// Test basic lowering to runtime function call.

// CHECK-LABEL: test
int test(uint a, uint b, int c) {
// CHECK: %[[RET:.*]] = call [[TY:i32]] @llvm.[[TARGET]].dot4add.i8packed([[TY]] %[[#]], [[TY]] %[[#]], [[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return dot4add_i8packed(a, b, c);
}

// CHECK: declare [[TY]] @llvm.[[TARGET]].dot4add.i8packed([[TY]], [[TY]], [[TY]])
28 changes: 28 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/dot4add_i8packed-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

int test_too_few_arg0() {
return __builtin_hlsl_dot4add_i8packed();
// expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
}

int test_too_few_arg1(int p0) {
return __builtin_hlsl_dot4add_i8packed(p0);
// expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
}

int test_too_few_arg2(int p0) {
return __builtin_hlsl_dot4add_i8packed(p0, p0);
// expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
}

int test_too_many_arg(int p0) {
return __builtin_hlsl_dot4add_i8packed(p0, p0, p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
}

struct S { float f; };

int test_expr_struct_type_check(S p0, int p1) {
return __builtin_hlsl_dot4add_i8packed(p0, p1, p1);
// expected-error@-1 {{no viable conversion from 'S' to 'unsigned int'}}
}
2 changes: 2 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Provides additional information to a compiler, similar to the llvm.assume and llvm.expect intrinsics.
* - ``SPV_KHR_float_controls``
- Provides new execution modes to control floating-point computations by overriding an implementation’s default behavior for rounding modes, denormals, signed zero, and infinities.
* - ``SPV_KHR_integer_dot_product``
- Adds instructions for dot product operations on integer vectors with optional accumulation. Integer vectors includes 4-component vector of 8 bit integers and 4-component vectors of 8 bit integers packed into 32-bit integers.
* - ``SPV_KHR_linkonce_odr``
- Allows to use the LinkOnceODR linkage type that lets a function or global variable to be merged with other functions or global variables of the same name when linkage occurs.
* - ``SPV_KHR_no_integer_wrap_decoration``
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def int_dx_udot :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;

def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_degrees : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ let TargetPrefix = "spv" in {
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,16 @@ def SplitDouble : DXILOp<102, splitDouble> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
let Doc = "signed dot product of 4 x i8 vectors packed into i32, with "
"accumulate to i32";
let LLVMIntrinsic = int_dx_dot4add_i8packed;
let arguments = [Int32Ty, Int32Ty, Int32Ty];
let result = Int32Ty;
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
}

def AnnotateHandle : DXILOp<216, annotateHandle> {
let Doc = "annotate handle with resource properties";
let arguments = [HandleTy, ResPropsTy];
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
SPIRV::Extension::Extension::SPV_KHR_expect_assume},
{"SPV_KHR_bit_instructions",
SPIRV::Extension::Extension::SPV_KHR_bit_instructions},
{"SPV_KHR_integer_dot_product",
SPIRV::Extension::Extension::SPV_KHR_integer_dot_product},
{"SPV_KHR_linkonce_odr",
SPIRV::Extension::Extension::SPV_KHR_linkonce_odr},
{"SPV_INTEL_inline_assembly",
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,9 @@ defm OpISubBorrow: BinOpTypedGen<"OpISubBorrow", 150, subc, 0, 1>;
def OpUMulExtended: BinOp<"OpUMulExtended", 151>;
def OpSMulExtended: BinOp<"OpSMulExtended", 152>;

def OpSDot: BinOp<"OpSDot", 4450>;
def OpUDot: BinOp<"OpUDot", 4451>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand it right that you need both OpSDot/OpUDot and OpSDotKHR/OpUDotKHR here. The former for versions since 1.6, and the latter for previous SPIR-V versions (also by default at the moment)?


// 3.42.14 Bit Instructions

defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>;
Expand Down
132 changes: 125 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

template <bool Signed>
bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
template <bool Signed>
bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
int OpIdx) const;
void renderFImm64(MachineInstrBuilder &MIB, const MachineInstr &I,
Expand Down Expand Up @@ -1646,7 +1653,7 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
// Multiply the vectors, then sum the results
Register Vec0 = I.getOperand(2).getReg();
Register Vec1 = I.getOperand(3).getReg();
Register TmpVec = MRI->createVirtualRegister(&SPIRV::IDRegClass);
Register TmpVec = MRI->createVirtualRegister(GR.getRegClass(ResType));
SPIRVType *VecType = GR.getSPIRVTypeForVReg(Vec0);

bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulV))
Expand All @@ -1660,18 +1667,18 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
GR.getScalarOrVectorComponentCount(VecType) > 1 &&
"dot product requires a vector of at least 2 components");

Register Res = MRI->createVirtualRegister(&SPIRV::IDRegClass);
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
Register Res = MRI->createVirtualRegister(GR.getRegClass(ResType));
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(Res)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(TmpVec)
.addImm(0)
.constrainAllUses(TII, TRI, RBI);

for (unsigned i = 1; i < GR.getScalarOrVectorComponentCount(VecType); i++) {
Register Elt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
Register Elt = MRI->createVirtualRegister(GR.getRegClass(ResType));

Result |=
Result &=
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(Elt)
.addUse(GR.getSPIRVTypeID(ResType))
Expand All @@ -1680,10 +1687,10 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);

Register Sum = i < GR.getScalarOrVectorComponentCount(VecType) - 1
? MRI->createVirtualRegister(&SPIRV::IDRegClass)
? MRI->createVirtualRegister(GR.getRegClass(ResType))
: ResVReg;

Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
.addDef(Sum)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Res)
Expand All @@ -1695,6 +1702,112 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
return Result;
}

template <bool Signed>
bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
assert(I.getNumOperands() == 5);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());
assert(I.getOperand(4).isReg());
MachineBasicBlock &BB = *I.getParent();

auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
Register Dot = MRI->createVirtualRegister(GR.getRegClass(ResType));
bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
.addDef(Dot)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg())
.constrainAllUses(TII, TRI, RBI);

Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Dot)
.addUse(I.getOperand(4).getReg())
.constrainAllUses(TII, TRI, RBI);

return Result;
}

// Since pre-1.6 SPIRV has no DotProductInput4x8BitPacked implementation,
// extract the elements of the packed inputs, multiply them and add the result
// to the accumulator.
template <bool Signed>
bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
assert(I.getNumOperands() == 5);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());
assert(I.getOperand(4).isReg());
MachineBasicBlock &BB = *I.getParent();

bool Result = false;

// Acc = C
Register Acc = I.getOperand(4).getReg();
SPIRVType *EltType = GR.getOrCreateSPIRVIntegerType(8, I, TII);
auto ExtractOp =
Signed ? SPIRV::OpBitFieldSExtract : SPIRV::OpBitFieldUExtract;

// Extract the i8 element, multiply and add it to the accumulator
for (unsigned i = 0; i < 4; i++) {
// A[i]
Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
.addDef(AElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(2).getReg())
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII))
.constrainAllUses(TII, TRI, RBI);

// B[i]
Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
.addDef(BElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(3).getReg())
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII))
.constrainAllUses(TII, TRI, RBI);

// A[i] * B[i]
Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
.addDef(Mul)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(AElt)
.addUse(BElt)
.constrainAllUses(TII, TRI, RBI);

// Discard 24 highest-bits so that stored i32 register is i8 equivalent
Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
.addDef(MaskMul)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Mul)
.addUse(GR.getOrCreateConstInt(0, I, EltType, TII))
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII))
.constrainAllUses(TII, TRI, RBI);

// Acc = Acc + A[i] * B[i]
Register Sum =
i < 3 ? MRI->createVirtualRegister(&SPIRV::IDRegClass) : ResVReg;
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
.addDef(Sum)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Acc)
.addUse(MaskMul)
.constrainAllUses(TII, TRI, RBI);

Acc = Sum;
}

return Result;
}

/// Transform saturate(x) to clamp(x, 0.0f, 1.0f) as SPIRV
/// does not have a saturate builtin.
bool SPIRVInstructionSelector::selectSaturate(Register ResVReg,
Expand Down Expand Up @@ -2528,6 +2641,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_udot:
case Intrinsic::spv_sdot:
return selectIntegerDot(ResVReg, ResType, I);
case Intrinsic::spv_dot4add_i8packed:
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not valid. You have 3 cases, not 2 as now. Instead of selecting OpSDot or Expansion, you should check:

  • STI.isAtLeastSPIRVVer(VersionTuple(1, 6)) allows to use OpSDot/OpUDot
  • STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) allows to use OpSDotKHR/OpUDotKHR
  • otherwise you may only expand with selectDot4AddPackedExpansion

As of now you get invalid SPIR-V when version is below 1.6 and SPV_KHR_integer_dot_product is enabled. In this case you will generate OpSDot/OpUDot which are missing before version 1.6.

STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
return selectDot4AddPacked<true>(ResVReg, ResType, I);
return selectDot4AddPackedExpansion<true>(ResVReg, ResType, I);
case Intrinsic::spv_all:
return selectAll(ResVReg, ResType, I);
case Intrinsic::spv_any:
Expand Down
Loading