Skip to content

Commit 7b18520

Browse files
committed
[RFC][WIP][AMDGPU] Use bf16 instead of i16 for bfloat
Currently we generally use `i16` to represent `bf16` in the AMDGPU TableGen files. I'm not sure of the reason behind this; my guess is that the `bf16` type was not yet available when bfloat support was first enabled. This patch uses `bf16` directly in those TableGen files, with the aim of fixing #79369. A possible workaround for #79369 would be to treat all `INT16` variants as `BFloat` in `getOpFltSemantics`, but that does not look like a good solution to me. Since I'm fairly new to the AMDGPU backend, I'd appreciate it if you could point out anything I have misunderstood.
1 parent d42f395 commit 7b18520

File tree

7 files changed

+87
-19
lines changed

7 files changed

+87
-19
lines changed

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -2839,7 +2839,7 @@ def int_amdgcn_fdot2_f32_bf16 :
28392839
llvm_v2i16_ty, // %b
28402840
llvm_float_ty, // %c
28412841
llvm_i1_ty // %clamp
2842-
],
2842+
],
28432843
[IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<3>>]
28442844
>;
28452845

llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp

+44
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
474474

475475
bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }
476476

477+
bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }
478+
477479
bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }
478480

479481
bool isSSrcV2F16() const {
@@ -540,22 +542,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
540542
return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
541543
}
542544

545+
bool isVCSrcTBF16() const {
546+
return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
547+
}
548+
543549
bool isVCSrcTF16() const {
544550
return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
545551
}
546552

553+
bool isVCSrcTBF16_Lo128() const {
554+
return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
555+
}
556+
547557
bool isVCSrcTF16_Lo128() const {
548558
return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
549559
}
550560

561+
bool isVCSrcFake16BF16_Lo128() const {
562+
return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
563+
}
564+
551565
bool isVCSrcFake16F16_Lo128() const {
552566
return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
553567
}
554568

569+
bool isVCSrc_bf16() const {
570+
return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
571+
}
572+
555573
bool isVCSrc_f16() const {
556574
return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
557575
}
558576

577+
bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }
578+
559579
bool isVCSrc_v2f16() const { return isVCSrc_f16(); }
560580

561581
bool isVSrc_b32() const {
@@ -596,18 +616,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {
596616

597617
bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }
598618

619+
bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }
620+
599621
bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }
600622

623+
bool isVSrcT_bf16_Lo128() const {
624+
return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
625+
}
626+
601627
bool isVSrcT_f16_Lo128() const {
602628
return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
603629
}
604630

631+
bool isVSrcFake16_bf16_Lo128() const {
632+
return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
633+
}
634+
605635
bool isVSrcFake16_f16_Lo128() const {
606636
return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
607637
}
608638

639+
bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }
640+
609641
bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }
610642

643+
bool isVSrc_v2bf16() const {
644+
return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
645+
}
646+
611647
bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }
612648

613649
bool isVISrcB32() const {
@@ -634,6 +670,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
634670
return isVISrcF16() || isVISrcB32();
635671
}
636672

673+
bool isVISrc_64_bf16() const {
674+
return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
675+
}
676+
637677
bool isVISrc_64_f16() const {
638678
return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
639679
}
@@ -802,6 +842,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
802842
return isAISrc_128F16() || isAISrc_128_b32();
803843
}
804844

845+
bool isVISrc_128_bf16() const {
846+
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
847+
}
848+
805849
bool isVISrc_128_f16() const {
806850
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
807851
}

llvm/lib/Target/AMDGPU/SIDefines.h

+7
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,12 @@ enum OperandType : unsigned {
196196
OPERAND_REG_IMM_INT16,
197197
OPERAND_REG_IMM_FP32,
198198
OPERAND_REG_IMM_FP64,
199+
OPERAND_REG_IMM_BF16,
199200
OPERAND_REG_IMM_FP16,
201+
OPERAND_REG_IMM_BF16_DEFERRED,
200202
OPERAND_REG_IMM_FP16_DEFERRED,
201203
OPERAND_REG_IMM_FP32_DEFERRED,
204+
OPERAND_REG_IMM_V2BF16,
202205
OPERAND_REG_IMM_V2FP16,
203206
OPERAND_REG_IMM_V2INT16,
204207
OPERAND_REG_IMM_V2INT32,
@@ -208,10 +211,12 @@ enum OperandType : unsigned {
208211
OPERAND_REG_INLINE_C_INT16,
209212
OPERAND_REG_INLINE_C_INT32,
210213
OPERAND_REG_INLINE_C_INT64,
214+
OPERAND_REG_INLINE_C_BF16,
211215
OPERAND_REG_INLINE_C_FP16,
212216
OPERAND_REG_INLINE_C_FP32,
213217
OPERAND_REG_INLINE_C_FP64,
214218
OPERAND_REG_INLINE_C_V2INT16,
219+
OPERAND_REG_INLINE_C_V2BF16,
215220
OPERAND_REG_INLINE_C_V2FP16,
216221
OPERAND_REG_INLINE_C_V2INT32,
217222
OPERAND_REG_INLINE_C_V2FP32,
@@ -226,10 +231,12 @@ enum OperandType : unsigned {
226231
/// Operands with an AccVGPR register or inline constant
227232
OPERAND_REG_INLINE_AC_INT16,
228233
OPERAND_REG_INLINE_AC_INT32,
234+
OPERAND_REG_INLINE_AC_BF16,
229235
OPERAND_REG_INLINE_AC_FP16,
230236
OPERAND_REG_INLINE_AC_FP32,
231237
OPERAND_REG_INLINE_AC_FP64,
232238
OPERAND_REG_INLINE_AC_V2INT16,
239+
OPERAND_REG_INLINE_AC_V2BF16,
233240
OPERAND_REG_INLINE_AC_V2FP16,
234241
OPERAND_REG_INLINE_AC_V2INT32,
235242
OPERAND_REG_INLINE_AC_V2FP32,

llvm/lib/Target/AMDGPU/SIInstrInfo.td

+12-15
Original file line numberDiff line numberDiff line change
@@ -1490,20 +1490,17 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
14901490
RegisterOperand ret =
14911491
!if(VT.isFP,
14921492
!if(!eq(VT.Size, 64),
1493-
VSrc_f64,
1494-
!if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
1495-
!if(IsTrue16,
1496-
!if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
1497-
VSrc_f16
1498-
),
1499-
!if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
1500-
VSrc_v2f16,
1501-
!if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
1502-
AVSrc_64,
1503-
VSrc_f32
1493+
VSrc_f64,
1494+
!if(!eq(VT.Value, f16.Value),
1495+
!if(IsTrue16, !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), VSrc_f16),
1496+
!if(!eq(VT.Value, bf16.Value),
1497+
!if(IsTrue16, !if(IsFake16, VSrcFake16_bf16_Lo128, VSrcT_bf16_Lo128), VSrc_bf16),
1498+
!if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
1499+
!if(!eq(VT.Value, v2f16.Value), VSrc_v2f16, VSrc_v2bf16),
1500+
!if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32)
1501+
)
15041502
)
1505-
)
1506-
)
1503+
)
15071504
),
15081505
!if(!eq(VT.Size, 64),
15091506
VSrc_b64,
@@ -2513,8 +2510,8 @@ def VOP_V2I16_F32_F32 : VOPProfile <[v2i16, f32, f32, untyped]>;
25132510
def VOP_V2I16_I32_I32 : VOPProfile <[v2i16, i32, i32, untyped]>;
25142511

25152512
def VOP_F16_V2F16_V2F16_F16 : VOPProfile <[f16, v2f16, v2f16, f16]>;
2516-
def VOP_I16_V2I16_V2I16_I16 : VOPProfile <[i16, v2i16, v2i16, i16]>;
2517-
def VOP_F32_V2I16_V2I16_F32 : VOPProfile <[f32, v2i16, v2i16, f32]>;
2513+
def VOP_BF16_V2BF16_V2BF16_BF16: VOPProfile <[bf16, v2bf16, v2bf16, bf16]>;
2514+
def VOP_F32_V2BF16_V2BF16_F32 : VOPProfile <[f32, v2bf16, v2bf16, f32]>;
25182515

25192516
def VOP_F32_V2F16_V2F16_V2F16 : VOPProfile <[f32, v2f16, v2f16, v2f16]>;
25202517

llvm/lib/Target/AMDGPU/SIRegisterInfo.td

+21-1
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ multiclass AVRegClass<int numRegs, list<ValueType> regTypes,
10661066
// Define the regular class.
10671067
def "" : VRegClassBase<numRegs, regTypes, (add vregList, aregList)>;
10681068

1069-
// Define 2-aligned variant
1069+
// Define 2-aligned variant
10701070
def _Align2 : VRegClassBase<numRegs, regTypes,
10711071
(add (decimate vregList, 2),
10721072
(decimate aregList, 2))> {
@@ -1115,6 +1115,7 @@ class RegOrImmOperand <string RegisterClassName, string OperandTypeName,
11151115
//===----------------------------------------------------------------------===//
11161116

11171117
def SSrc_b16 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_INT16", "_Imm16">;
1118+
def SSrc_bf16: RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_BF16", "_Imm16">;
11181119
def SSrc_f16 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_FP16", "_Imm16">;
11191120
def SSrc_b32 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_INT32", "_Imm32">;
11201121
def SSrc_f32 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_FP32", "_Imm32">;
@@ -1142,13 +1143,18 @@ def SCSrc_b64 : RegOrImmOperand <"SReg_64", "OPERAND_REG_INLINE_C_INT64", "_Imm6
11421143

11431144
// The current and temporary future default used case for VOP3.
11441145
def VSrc_b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_INT16", "_Imm16">;
1146+
def VSrc_bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_BF16", "_Imm16">;
11451147
def VSrc_f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_FP16", "_Imm16">;
11461148

11471149
// True16 VOP3 operands.
11481150
def VSrcT_b16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_INT16", "_Imm16"> {
11491151
let EncoderMethod = "getMachineOpValueT16";
11501152
let DecoderMethod = "decodeOperand_VSrcT16";
11511153
}
1154+
def VSrcT_bf16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_BF16", "_Imm16"> {
1155+
let EncoderMethod = "getMachineOpValueT16";
1156+
let DecoderMethod = "decodeOperand_VSrcT16";
1157+
}
11521158
def VSrcT_f16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_FP16", "_Imm16"> {
11531159
let EncoderMethod = "getMachineOpValueT16";
11541160
let DecoderMethod = "decodeOperand_VSrcT16";
@@ -1159,6 +1165,10 @@ def VSrcT_b16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_INT16", "
11591165
let EncoderMethod = "getMachineOpValueT16Lo128";
11601166
let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
11611167
}
1168+
def VSrcT_bf16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_BF16", "_Imm16"> {
1169+
let EncoderMethod = "getMachineOpValueT16Lo128";
1170+
let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
1171+
}
11621172
def VSrcT_f16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_FP16", "_Imm16"> {
11631173
let EncoderMethod = "getMachineOpValueT16Lo128";
11641174
let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
@@ -1167,11 +1177,13 @@ def VSrcT_f16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_FP16", "_
11671177
// The current and temporary future default used case for fake VOP1/2/C.
11681178
// For VOP1,2,C True16 instructions. _Lo128 use first 128 32-bit VGPRs only.
11691179
def VSrcFake16_b16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_INT16", "_Imm16">;
1180+
def VSrcFake16_bf16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_BF16", "_Imm16">;
11701181
def VSrcFake16_f16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_FP16", "_Imm16">;
11711182

11721183
def VSrc_b32 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_INT32", "_Imm32">;
11731184
def VSrc_f32 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_FP32", "_Imm32">;
11741185
def VSrc_v2b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2INT16", "_ImmV2I16">;
1186+
def VSrc_v2bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2BF16", "_ImmV2F16">;
11751187
def VSrc_v2f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2FP16", "_ImmV2F16">;
11761188
def VSrc_b64 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_INT64", "_Imm64">;
11771189
def VSrc_f64 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_FP64", "_Imm64"> {
@@ -1185,9 +1197,13 @@ def VSrc_v2f32 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_V2FP32", "_Imm32">;
11851197
// with FMAMK/FMAAK
11861198
//===----------------------------------------------------------------------===//
11871199

1200+
def VSrc_bf16_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_BF16_DEFERRED", "_Deferred_Imm16">;
11881201
def VSrc_f16_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_FP16_DEFERRED", "_Deferred_Imm16">;
11891202
def VSrc_f32_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_FP32_DEFERRED", "_Deferred_Imm32">;
11901203

1204+
def VSrcFake16_bf16_Lo128_Deferred : RegOrImmOperand<"VS_32_Lo128",
1205+
"OPERAND_REG_IMM_BF16_DEFERRED",
1206+
"_Deferred_Imm16">;
11911207
def VSrcFake16_f16_Lo128_Deferred : RegOrImmOperand<"VS_32_Lo128",
11921208
"OPERAND_REG_IMM_FP16_DEFERRED",
11931209
"_Deferred_Imm16">;
@@ -1252,19 +1268,23 @@ def ARegSrc_32 : AVOperand<AGPR_32, "decodeSrcA9", "OPW32">;
12521268
//===----------------------------------------------------------------------===//
12531269

12541270
def VCSrc_b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_INT16", "_Imm16">;
1271+
def VCSrc_bf16: RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
12551272
def VCSrc_f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
12561273
def VCSrc_b32 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
12571274
def VCSrc_f32 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_FP32", "_Imm32">;
12581275
def VCSrc_v2b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2INT16", "_ImmV2I16">;
1276+
def VCSrc_v2bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2BF16", "_ImmV2F16">;
12591277
def VCSrc_v2f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2FP16", "_ImmV2F16">;
12601278

12611279
//===----------------------------------------------------------------------===//
12621280
// VISrc_* Operands with a VGPR or an inline constant
12631281
//===----------------------------------------------------------------------===//
12641282

1283+
def VISrc_64_bf16 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
12651284
def VISrc_64_f16 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
12661285
def VISrc_64_b32 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
12671286
def VISrc_64_f64 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_FP64", "_Imm64">;
1287+
def VISrc_128_bf16 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
12681288
def VISrc_128_f16 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
12691289
def VISrc_128_b32 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
12701290
def VISrc_128_f32 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_FP32", "_Imm32">;

llvm/lib/Target/AMDGPU/VOP3Instructions.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,7 @@ let SubtargetPredicate = isGFX12Plus, ReadsModeReg = 0 in {
904904

905905
let SubtargetPredicate = HasDot9Insts, IsDOT=1 in {
906906
defm V_DOT2_F16_F16 : VOP3Inst<"v_dot2_f16_f16", VOP3_DOT_Profile<VOP_F16_V2F16_V2F16_F16>, int_amdgcn_fdot2_f16_f16>;
907-
defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_I16_V2I16_V2I16_I16>, int_amdgcn_fdot2_bf16_bf16>;
907+
defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_BF16_V2BF16_V2BF16_BF16>, int_amdgcn_fdot2_bf16_bf16>;
908908
}
909909

910910
class VOP_Pseudo_Scalar<RegisterClass Dst, RegisterOperand SrcOp,

llvm/lib/Target/AMDGPU/VOP3PInstructions.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ defm V_DOT8_I32_I4 : VOP3PInst<"v_dot8_i32_i4",
396396
} // End OtherPredicates = [HasDot1Insts]
397397

398398
def DOT2_BF16_Profile
399-
: VOP3P_Profile<VOP_F32_V2I16_V2I16_F32, VOP3_REGULAR, /*HasDPP*/ 1> {
399+
: VOP3P_Profile<VOP_F32_V2BF16_V2BF16_F32, VOP3_REGULAR, /*HasDPP*/ 1> {
400400
let HasSrc1Mods = 1;
401401
}
402402

0 commit comments

Comments
 (0)