Skip to content

Commit 4682250

Browse files
committed
[HLSL][SPIRV][DXIL] Implement dot4add_i8packed intrinsic
- Create a Clang builtin in Builtins.td.
- Expose dot4add_i8packed in hlsl_intrinsics.h.
- Add lowering to the SPIR-V backend in SPIRVInstructionSelector.cpp by expanding the operation, since OpSDot is not available before SPIR-V 1.6.
- Add the dot4add_i8packed intrinsic to IntrinsicsDirectX.td and map it to the DXIL.td op Dot4AddI8Packed.
- Add tests for lowering the HLSL intrinsic to the dx/spv intrinsics in dot4add_i8packed.hlsl.
- Add tests for the Sema checks in dot4add_i8packed-errors.hlsl.
- Add a test for the SPIR-V lowering in SPIRV/dot4add_i8packed.ll.
- Add a test for the DXIL lowering in DirectX/dot4add_i8packed.ll.
1 parent eaa7b38 commit 4682250

File tree

12 files changed

+227
-1
lines changed

12 files changed

+227
-1
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4792,6 +4792,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
47924792
let Prototype = "void(...)";
47934793
}
47944794

4795+
def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
4796+
let Spellings = ["__builtin_hlsl_dot4add_i8packed"];
4797+
let Attributes = [NoThrow, Const];
4798+
let Prototype = "int(unsigned int, unsigned int, int)";
4799+
}
4800+
47954801
def HLSLFrac : LangBuiltin<"HLSL_LANG"> {
47964802
let Spellings = ["__builtin_hlsl_elementwise_frac"];
47974803
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18722,7 +18722,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1872218722
/*ReturnType=*/T0->getScalarType(),
1872318723
getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
1872418724
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
18725-
} break;
18725+
}
18726+
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
18727+
Value *A = EmitScalarExpr(E->getArg(0));
18728+
Value *B = EmitScalarExpr(E->getArg(1));
18729+
Value *C = EmitScalarExpr(E->getArg(2));
18730+
18731+
Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic();
18732+
return Builder.CreateIntrinsic(
18733+
/*ReturnType=*/C->getType(), ID,
18734+
ArrayRef<Value *>{A, B, C}, nullptr, "hlsl.dot4add.i8packed");
18735+
}
1872618736
case Builtin::BI__builtin_hlsl_lerp: {
1872718737
Value *X = EmitScalarExpr(E->getArg(0));
1872818738
Value *Y = EmitScalarExpr(E->getArg(1));

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class CGHLSLRuntime {
8989
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
9090
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
9191
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
92+
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
9293
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
9394
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
9495

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,16 @@ uint64_t dot(uint64_t3, uint64_t3);
894894
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
895895
uint64_t dot(uint64_t4, uint64_t4);
896896

897+
//===----------------------------------------------------------------------===//
898+
// dot4add builtins
899+
//===----------------------------------------------------------------------===//
900+
901+
/// \fn int dot4add_i8packed(uint A, uint B, int C)
902+
903+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.4)
904+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot4add_i8packed)
905+
int dot4add_i8packed(unsigned int, unsigned int, int);
906+
897907
//===----------------------------------------------------------------------===//
898908
// exp builtins
899909
//===----------------------------------------------------------------------===//
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s -DTARGET=dx
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s -DTARGET=spv
7+
8+
// Test basic lowering to runtime function call.
9+
10+
// CHECK-LABEL: test
11+
int test(uint a, uint b, int c) {
12+
// CHECK: %[[RET:.*]] = call [[TY:i32]] @llvm.[[TARGET]].dot4add.i8packed([[TY]] %[[#]], [[TY]] %[[#]], [[TY]] %[[#]])
13+
// CHECK: ret [[TY]] %[[RET]]
14+
return dot4add_i8packed(a, b, c);
15+
}
16+
17+
// CHECK: declare [[TY]] @llvm.[[TARGET]].dot4add.i8packed([[TY]], [[TY]], [[TY]])
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
int test_too_few_arg0() {
4+
return __builtin_hlsl_dot4add_i8packed();
5+
// expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
6+
}
7+
8+
int test_too_few_arg1(int p0) {
9+
return __builtin_hlsl_dot4add_i8packed(p0);
10+
// expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
11+
}
12+
13+
int test_too_few_arg2(int p0) {
14+
return __builtin_hlsl_dot4add_i8packed(p0, p0);
15+
// expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
16+
}
17+
18+
int test_too_many_arg(int p0) {
19+
return __builtin_hlsl_dot4add_i8packed(p0, p0, p0, p0);
20+
// expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
21+
}
22+
23+
struct S { float f; };
24+
25+
int test_expr_struct_type_check(S p0, int p1) {
26+
return __builtin_hlsl_dot4add_i8packed(p0, p1, p1);
27+
// expected-error@-1 {{no viable conversion from 'S' to 'unsigned int'}}
28+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def int_dx_udot :
6969
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
7070
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
7171
[IntrNoMem, Commutative] >;
72+
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem, Commutative] >;
7273

7374
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
7475
def int_dx_degrees : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ let TargetPrefix = "spv" in {
8383
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
8484
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
8585
[IntrNoMem, Commutative] >;
86+
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem, Commutative] >;
8687
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8788
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
8889
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,16 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> {
779779
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
780780
}
781781

782+
// DXIL operation 163 (op class dot4AddPacked): signed dot product of two
// packed 4 x i8 vectors with accumulation into an i32. Calls to
// int_dx_dot4add_i8packed are lowered to this op; it takes three i32 operands
// and produces an i32 result.
// NOTE(review): the intrinsic operands are forwarded in their original order
// (a, b, c). Confirm this matches the operand order the DXIL op expects; the
// DXIL op table may list the accumulator first.
// NOTE(review): attributes and stages are keyed to DXIL1_0, while the
// HLSL-level intrinsic is tied to shader model 6.4. Confirm the minimum DXIL
// version for this op.
def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
783+
let Doc = "signed dot product of 4 x i8 vectors packed into i32, with "
784+
"accumulate to i32";
785+
let LLVMIntrinsic = int_dx_dot4add_i8packed;
786+
let arguments = [Int32Ty, Int32Ty, Int32Ty];
787+
let result = Int32Ty;
788+
// ReadNone: the op is declared as not accessing memory.
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
789+
let stages = [Stages<DXIL1_0, [all_stages]>];
790+
}
791+
782792
def AnnotateHandle : DXILOp<217, annotateHandle> {
783793
let Doc = "annotate handle with resource properties";
784794
let arguments = [HandleTy, ResPropsTy];

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
164164
bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
165165
MachineInstr &I) const;
166166

167+
template <bool Signed>
168+
bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
169+
MachineInstr &I) const;
170+
167171
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
168172
int OpIdx) const;
169173
void renderFImm64(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -1694,6 +1698,84 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
16941698
return Result;
16951699
}
16961700

1701+
// Since pre-1.6 SPIRV has no DotProductInput4x8BitPacked implementation,
1702+
// extract the elements of the packed inputs, multiply them and add the result
1703+
// to the accumulator.
1704+
template <bool Signed>
1705+
bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
1706+
const SPIRVType *ResType,
1707+
MachineInstr &I) const {
1708+
assert(I.getNumOperands() == 5);
1709+
assert(I.getOperand(2).isReg());
1710+
assert(I.getOperand(3).isReg());
1711+
assert(I.getOperand(4).isReg());
1712+
MachineBasicBlock &BB = *I.getParent();
1713+
1714+
bool Result = false;
1715+
1716+
// Acc = C
1717+
Register Acc = I.getOperand(4).getReg();
1718+
SPIRVType *EltType = GR.getOrCreateSPIRVIntegerType(8, I, TII);
1719+
auto ExtractOp =
1720+
Signed ? SPIRV::OpBitFieldSExtract : SPIRV::OpBitFieldUExtract;
1721+
1722+
// Extract the i8 element, multiply and add it to the accumulator
1723+
for (unsigned i = 0; i < 4; i++) {
1724+
// A[i]
1725+
Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1726+
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
1727+
.addDef(AElt)
1728+
.addUse(GR.getSPIRVTypeID(ResType))
1729+
.addUse(I.getOperand(2).getReg())
1730+
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
1731+
.addImm(8)
1732+
.constrainAllUses(TII, TRI, RBI);
1733+
1734+
// B[i]
1735+
Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1736+
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
1737+
.addDef(BElt)
1738+
.addUse(GR.getSPIRVTypeID(ResType))
1739+
.addUse(I.getOperand(3).getReg())
1740+
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
1741+
.addImm(8)
1742+
.constrainAllUses(TII, TRI, RBI);
1743+
1744+
// A[i] * B[i]
1745+
Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1746+
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
1747+
.addDef(Mul)
1748+
.addUse(GR.getSPIRVTypeID(ResType))
1749+
.addUse(AElt)
1750+
.addUse(BElt)
1751+
.constrainAllUses(TII, TRI, RBI);
1752+
1753+
// Discard 24 highest-bits so that stored i32 register is i8 equivalent
1754+
Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1755+
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
1756+
.addDef(MaskMul)
1757+
.addUse(GR.getSPIRVTypeID(ResType))
1758+
.addUse(Mul)
1759+
.addUse(GR.getOrCreateConstInt(0, I, EltType, TII))
1760+
.addImm(8)
1761+
.constrainAllUses(TII, TRI, RBI);
1762+
1763+
// Acc = Acc + A[i] * B[i]
1764+
Register Sum =
1765+
i < 3 ? MRI->createVirtualRegister(&SPIRV::IDRegClass) : ResVReg;
1766+
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
1767+
.addDef(Sum)
1768+
.addUse(GR.getSPIRVTypeID(ResType))
1769+
.addUse(Acc)
1770+
.addUse(MaskMul)
1771+
.constrainAllUses(TII, TRI, RBI);
1772+
1773+
Acc = Sum;
1774+
}
1775+
1776+
return Result;
1777+
}
1778+
16971779
/// Transform saturate(x) to clamp(x, 0.0f, 1.0f) as SPIRV
16981780
/// does not have a saturate builtin.
16991781
bool SPIRVInstructionSelector::selectSaturate(Register ResVReg,
@@ -2527,6 +2609,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
25272609
case Intrinsic::spv_udot:
25282610
case Intrinsic::spv_sdot:
25292611
return selectIntegerDot(ResVReg, ResType, I);
2612+
case Intrinsic::spv_dot4add_i8packed:
2613+
return selectDot4AddPacked<true>(ResVReg, ResType, I);
25302614
case Intrinsic::spv_all:
25312615
return selectAll(ResVReg, ResType, I);
25322616
case Intrinsic::spv_any:
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
define void @main(i32 %a, i32 %b, i32 %c) {
4+
entry:
5+
; CHECK: call i32 @dx.op.dot4AddPacked(i32 163, i32 %a, i32 %b, i32 %c)
6+
%0 = call i32 @llvm.dx.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
7+
ret void
8+
}
9+
10+
declare i32 @llvm.dx.dot4add.i8packed(i32, i32, i32)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
; RUN: llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
5+
; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8 0
6+
; CHECK-DAG: %[[#zero:]] = OpConstantNull %[[#int_8]]
7+
; CHECK-DAG: %[[#eight:]] = OpConstant %[[#int_8]] 8
8+
; CHECK-DAG: %[[#sixteen:]] = OpConstant %[[#int_8]] 16
9+
; CHECK-DAG: %[[#twentyfour:]] = OpConstant %[[#int_8]] 24
10+
; CHECK-LABEL: Begin function test_dot
11+
define noundef i32 @test_dot(i32 noundef %a, i32 noundef %b, i32 noundef %c) {
12+
entry:
13+
; CHECK: %[[#A:]] = OpFunctionParameter %[[#int_32]]
14+
; CHECK: %[[#B:]] = OpFunctionParameter %[[#int_32]]
15+
; CHECK: %[[#C:]] = OpFunctionParameter %[[#int_32]]
16+
17+
; First element of the packed vector
18+
; CHECK: %[[#A0:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#zero]] 8
19+
; CHECK: %[[#B0:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#zero]] 8
20+
; CHECK: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#A0]] %[[#B0]]
21+
; CHECK: %[[#MASK0:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] 8
22+
; CHECK: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#C]] %[[#MASK0]]
23+
24+
; Second element of the packed vector
25+
; CHECK: %[[#A1:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#eight]] 8
26+
; CHECK: %[[#B1:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#eight]] 8
27+
; CHECK: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#A1]] %[[#B1]]
28+
; CHECK: %[[#MASK1:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] 8
29+
; CHECK: %[[#ACC1:]] = OpIAdd %[[#int_32]] %[[#ACC0]] %[[#MASK1]]
30+
31+
; Third element of the packed vector
32+
; CHECK: %[[#A2:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#sixteen]] 8
33+
; CHECK: %[[#B2:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#sixteen]] 8
34+
; CHECK: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#A2]] %[[#B2]]
35+
; CHECK: %[[#MASK2:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] 8
36+
; CHECK: %[[#ACC2:]] = OpIAdd %[[#int_32]] %[[#ACC1]] %[[#MASK2]]
37+
38+
; Fourth element of the packed vector
39+
; CHECK: %[[#A3:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] 8
40+
; CHECK: %[[#B3:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] 8
41+
; CHECK: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#A3]] %[[#B3]]
42+
; CHECK: %[[#MASK3:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] 8
43+
; CHECK: %[[#ACC3:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
44+
45+
; CHECK: OpReturnValue %[[#ACC3]]
46+
%spv.dot = call i32 @llvm.spv.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
47+
ret i32 %spv.dot
48+
}

0 commit comments

Comments
 (0)