Skip to content

[AMDGPU][MC][True16] VOP3dot instruction update for true16/fake16 #113474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions llvm/lib/Target/AMDGPU/VOP3Instructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,30 @@ class VOP3_DOT_Profile<VOPProfile P> : VOP3_Profile<P, VOP3_OPSEL> {
let HasOMod = 0;
}

// Profile for the true16 variants of the VOP3 dot instructions
// (passed as the true16 profile for V_DOT2_F16_F16 and V_DOT2_BF16_BF16).
// Builds on VOP3_Profile_True16 and then adjusts the modifier set.
class VOP3_DOT_Profile_t16<VOPProfile P, VOP3Features Features = VOP3_REGULAR>
    : VOP3_Profile_True16<P, Features> {
  // No output modifiers: neither omod nor clamp.
  let HasOMod = 0;
  let HasClamp = 0;

  // bf16 (i16) sources use the same modifiers as float sources, so the
  // per-source modifier flags are overridden to be on for all three sources.
  let HasSrc0Mods = 1;
  let HasSrc1Mods = 1;
  let HasSrc2Mods = 1;

  // Modifier operand kinds used by the VOP3 DPP forms, one per source.
  // src2 uses the 16-bit modifier operand in its non-fake16 flavour.
  let Src0ModVOP3DPP = FPVRegInputMods;
  let Src1ModVOP3DPP = FP32VCSrcInputMods;
  let Src2ModVOP3DPP = FPT16VCSrcInputMods</*IsFake16*/0>;
}

class VOP3_DOT_Profile_fake16<VOPProfile P, VOP3Features Features = VOP3_REGULAR> : VOP3_Profile_Fake16<P, Features> {
let HasClamp = 0;
let HasOMod = 0;
// Override modifiers for bf16(i16) (same as float modifiers).
let HasSrc0Mods = 1;
let HasSrc1Mods = 1;
let HasSrc2Mods = 1;
let AsmVOP3Base = getAsmVOP3Base<NumSrcArgs, HasDst, HasClamp,
HasOpSel, HasOMod, IsVOP3P, HasModifiers, 1/*HasSrc0Mods*/, 1/*HasSrc1Mods*/,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this continuation line is indented by a single space.

1/*HasSrc2Mods*/, DstVT>.ret;
}

let SubtargetPredicate = isGFX11Plus in {
defm V_MAXMIN_F32 : VOP3Inst<"v_maxmin_f32", VOP3_Profile<VOP_F32_F32_F32_F32>>;
defm V_MINMAX_F32 : VOP3Inst<"v_minmax_f32", VOP3_Profile<VOP_F32_F32_F32_F32>>;
Expand Down Expand Up @@ -1409,9 +1433,15 @@ let SubtargetPredicate = isGFX12Plus, ReadsModeReg = 0 in {
defm V_MINIMUMMAXIMUM_F16 : VOP3Inst<"v_minimummaximum_f16", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>>;
} // End SubtargetPredicate = isGFX12Plus, ReadsModeReg = 0

let OtherPredicates = [HasDot9Insts], IsDOT=1 in {
defm V_DOT2_F16_F16 : VOP3Inst<"v_dot2_f16_f16", VOP3_DOT_Profile<VOP_F16_V2F16_V2F16_F16>, int_amdgcn_fdot2_f16_f16>;
defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_BF16_V2BF16_V2BF16_BF16>, int_amdgcn_fdot2_bf16_bf16>;
// VOP3 dot-product pseudo instructions. Everything defined in this scope gets
// SubtargetPredicate = HasDot9Insts and IsDOT = 1.
// Each defm hands VOP3Inst_t16_with_profiles the mnemonic, three profiles
// built from the same VOPProfile (the plain VOP3_DOT_Profile followed by its
// _t16 and _fake16 counterparts), and the int_amdgcn_fdot2_* record used for
// selection.
let SubtargetPredicate = HasDot9Insts, IsDOT=1 in {
// f16 result from two v2f16 sources plus an f16 source.
defm V_DOT2_F16_F16 : VOP3Inst_t16_with_profiles<"v_dot2_f16_f16", VOP3_DOT_Profile<VOP_F16_V2F16_V2F16_F16>,
VOP3_DOT_Profile_t16<VOP_F16_V2F16_V2F16_F16>,
VOP3_DOT_Profile_fake16<VOP_F16_V2F16_V2F16_F16>,
int_amdgcn_fdot2_f16_f16>;
// bf16 result from two v2bf16 sources plus a bf16 source.
defm V_DOT2_BF16_BF16 : VOP3Inst_t16_with_profiles<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_BF16_V2BF16_V2BF16_BF16>,
VOP3_DOT_Profile_t16<VOP_BF16_V2BF16_V2BF16_BF16>,
VOP3_DOT_Profile_fake16<VOP_BF16_V2BF16_V2BF16_BF16>,
int_amdgcn_fdot2_bf16_bf16>;
}

class VOP_Pseudo_Scalar<RegisterClass Dst, RegisterOperand SrcOp,
Expand Down Expand Up @@ -1609,8 +1639,10 @@ multiclass VOP3_Realtriple_with_name_gfx11_gfx12<bits<10> op, string opName,
VOP3_Realtriple_with_name<GFX11Gen, op, opName, asmName>,
VOP3_Realtriple_with_name<GFX12Gen, op, opName, asmName>;

multiclass VOP3Dot_Realtriple_gfx11_gfx12<bits<10> op> :
VOP3Dot_Realtriple<GFX11Gen, op>, VOP3Dot_Realtriple<GFX12Gen, op>;
// Emits the gfx11/gfx12 real encodings of a VOP3 dot instruction for both of
// its 16-bit pseudo flavours.
//   op      - 10-bit opcode shared by both flavours.
//   asmName - assembly mnemonic shared by both flavours.
//   opName  - base pseudo name (defaults to NAME); "_t16" / "_fake16" is
//             appended to select the pseudo for each flavour.
// The literal 0 is forwarded as the third argument of
// VOP3Dot_Realtriple_gfx11_gfx12 (isSingle in its definition in
// VOPInstructions.td).
multiclass VOP3Dot_Realtriple_t16_and_fake16_gfx11_gfx12<bits<10> op, string asmName, string opName = NAME> {
defm _t16: VOP3Dot_Realtriple_gfx11_gfx12<op, asmName, 0, opName#"_t16">;
defm _fake16: VOP3Dot_Realtriple_gfx11_gfx12<op, asmName, 0, opName#"_fake16">;
}

multiclass VOP3_Realtriple_t16_gfx11_gfx12<bits<10> op, string asmName, string opName = NAME,
string pseudo_mnemonic = "", bit isSingle = 0> :
Expand Down Expand Up @@ -1702,8 +1734,8 @@ defm V_MAXMIN_U32 : VOP3_Realtriple_gfx11_gfx12<0x262>;
defm V_MINMAX_U32 : VOP3_Realtriple_gfx11_gfx12<0x263>;
defm V_MAXMIN_I32 : VOP3_Realtriple_gfx11_gfx12<0x264>;
defm V_MINMAX_I32 : VOP3_Realtriple_gfx11_gfx12<0x265>;
defm V_DOT2_F16_F16 : VOP3Dot_Realtriple_gfx11_gfx12<0x266>;
defm V_DOT2_BF16_BF16 : VOP3Dot_Realtriple_gfx11_gfx12<0x267>;
// Real encodings for the dot2 instructions (opcodes 0x266 and 0x267). Each is
// emitted for both the _t16 and _fake16 pseudos, which share the mnemonic
// passed here.
defm V_DOT2_F16_F16 : VOP3Dot_Realtriple_t16_and_fake16_gfx11_gfx12<0x266, "v_dot2_f16_f16">;
defm V_DOT2_BF16_BF16 : VOP3Dot_Realtriple_t16_and_fake16_gfx11_gfx12<0x267, "v_dot2_bf16_bf16">;
defm V_DIV_SCALE_F32 : VOP3be_Real_gfx11_gfx12<0x2fc, "V_DIV_SCALE_F32", "v_div_scale_f32">;
defm V_DIV_SCALE_F64 : VOP3be_Real_gfx11_gfx12<0x2fd, "V_DIV_SCALE_F64", "v_div_scale_f64">;
defm V_MAD_U64_U32_gfx11 : VOP3be_Real_gfx11<0x2fe, "V_MAD_U64_U32_gfx11", "v_mad_u64_u32">;
Expand Down
41 changes: 29 additions & 12 deletions llvm/lib/Target/AMDGPU/VOPInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,12 @@ class VOP3FP8OpSel_src_bytesel_gfx11_gfx12<bits<10> op, VOPProfile p> : VOP3e_gf
let Inst{14-13} = byte_sel; // op_sel2/3
}

class VOP3DotOpSel_gfx11_gfx12<bits<10> op, VOPProfile p> : VOP3OpSel_gfx11_gfx12<op, p>{
// Encoding class for VOP3 dot instructions on gfx11/gfx12, layered on
// VOP3e_t16_gfx11_gfx12 and overriding instruction bits 11-14.
// NOTE(review): bits 11-14 look like the op_sel field of the VOP3 encoding
// - confirm against the ISA documentation.
class VOP3DotOpSel_gfx11_gfx12<bits<10> op, VOPProfile p> :
VOP3e_t16_gfx11_gfx12<op, p>{
// Bits 11 and 12 are left unset ('?').
let Inst{11} = ?;
let Inst{12} = ?;
// Bit 13 takes bit 2 of src2_modifiers when the profile has src2 modifiers,
// otherwise 0.
let Inst{13} = !if(p.HasSrc2Mods, src2_modifiers{2}, 0);
// Bit 14 takes bit 3 of src0_modifiers only when the profile has both a
// destination and src0 modifiers, otherwise 0.
let Inst{14} = !if(!and(p.HasDst, p.HasSrc0Mods), src0_modifiers{3}, 0);
}

// NB: For V_INTERP* opcodes, src0 is encoded as src1 and vice versa
Expand Down Expand Up @@ -1706,10 +1709,12 @@ multiclass VOP3_Real_Base<GFXGen Gen, bits<10> op, string opName = NAME,
}
}

multiclass VOP3Dot_Real_Base<GFXGen Gen, bits<10> op, string opName = NAME,
multiclass VOP3Dot_Real_Base<GFXGen Gen, bits<10> op, string asmName, string opName = NAME,
bit isSingle = 0> {
defvar ps = !cast<VOP_Pseudo>(opName#"_e64");
let IsSingle = !or(isSingle, ps.Pfl.IsSingle) in {
let AsmString = asmName # ps.AsmOperands,
DecoderNamespace = Gen.DecoderNamespace # !if(ps.Pfl.IsRealTrue16, "", "_FAKE16"),
IsSingle = !or(isSingle, ps.Pfl.IsSingle) in {
def _e64#Gen.Suffix :
VOP3_Real_Gen<ps, Gen>,
VOP3DotOpSel_gfx11_gfx12<op, ps.Pfl>;
Expand Down Expand Up @@ -1773,9 +1778,13 @@ multiclass VOP3_Real_dpp_Base<GFXGen Gen, bits<10> op, string opName = NAME> {
VOP3_DPP16_Gen<op, ps, Gen>;
}

multiclass VOP3Dot_Real_dpp_Base<GFXGen Gen, bits<10> op, string opName = NAME> {
multiclass VOP3Dot_Real_dpp_Base<GFXGen Gen, bits<10> op, string asmName, string opName = NAME> {
defvar ps = !cast<VOP_DPP_Pseudo>(opName#"_e64"#"_dpp");
def _e64_dpp#Gen.Suffix :
VOP3_DPP16_Gen<op, !cast<VOP_DPP_Pseudo>(opName#"_e64"#"_dpp"), Gen> {
VOP3_DPP16_Gen_t16<op, ps, Gen> {
let AsmString = asmName # ps.Pfl.AsmVOP3DPP16;
let DecoderNamespace = Gen.DecoderNamespace
# !if(ps.Pfl.IsRealTrue16, "", "_FAKE16");
let Inst{11} = ?;
let Inst{12} = ?;
}
Expand All @@ -1797,12 +1806,14 @@ multiclass VOP3_Real_dpp8_Base<GFXGen Gen, bits<10> op, string opName = NAME> {
}
}

multiclass VOP3Dot_Real_dpp8_Base<GFXGen Gen, bits<10> op, string opName = NAME> {
multiclass VOP3Dot_Real_dpp8_Base<GFXGen Gen, bits<10> op, string asmName, string opName = NAME> {
defvar ps = !cast<VOP3_Pseudo>(opName#"_e64");
def _e64_dpp8#Gen.Suffix : Base_VOP3_DPP8<op, ps> {
def _e64_dpp8#Gen.Suffix : Base_VOP3_DPP8_t16<op, ps> {
let Inst{11} = ?;
let Inst{12} = ?;
let DecoderNamespace = Gen.DecoderNamespace;
let AsmString = asmName # ps.Pfl.AsmVOP3DPP8;
let DecoderNamespace = Gen.DecoderNamespace
# !if(ps.Pfl.IsRealTrue16, "", "_FAKE16");
let AssemblerPredicate = Gen.AssemblerPredicate;
}
}
Expand Down Expand Up @@ -1855,11 +1866,11 @@ multiclass VOP3_Realtriple<GFXGen Gen, bits<10> op, bit isSingle = 0,
VOP3_Real_dpp_Base<Gen, op, opName>,
VOP3_Real_dpp8_Base<Gen, op, opName>;

multiclass VOP3Dot_Realtriple<GFXGen Gen, bits<10> op, bit isSingle = 0,
multiclass VOP3Dot_Realtriple<GFXGen Gen, bits<10> op, string asmName, bit isSingle = 0,
string opName = NAME> :
VOP3Dot_Real_Base<Gen, op, opName, isSingle>,
VOP3Dot_Real_dpp_Base<Gen, op, opName>,
VOP3Dot_Real_dpp8_Base<Gen, op, opName>;
VOP3Dot_Real_Base<Gen, op, asmName, opName, isSingle>,
VOP3Dot_Real_dpp_Base<Gen, op, asmName, opName>,
VOP3Dot_Real_dpp8_Base<Gen, op, asmName, opName>;

multiclass VOP3Only_Realtriple<GFXGen Gen, bits<10> op> :
VOP3_Realtriple<Gen, op, 1>;
Expand Down Expand Up @@ -1957,6 +1968,12 @@ multiclass VOP3Only_Realtriple_with_name_gfx11_gfx12<bits<10> op, string opName,
VOP3Only_Realtriple_with_name<GFX11Gen, op, opName, asmName>,
VOP3Only_Realtriple_with_name<GFX12Gen, op, opName, asmName>;

// Instantiates VOP3Dot_Realtriple once for GFX11Gen and once for GFX12Gen,
// forwarding the same arguments to both.
//   op       - 10-bit opcode.
//   asmName  - assembly mnemonic for the real instructions.
//   isSingle - forwarded unchanged (defaults to 0).
//   opName   - name of the pseudo to realize (defaults to NAME).
multiclass VOP3Dot_Realtriple_gfx11_gfx12<bits<10> op, string asmName, bit isSingle = 0,
string opName = NAME> :
VOP3Dot_Realtriple<GFX11Gen, op, asmName, isSingle, opName>,
VOP3Dot_Realtriple<GFX12Gen, op, asmName, isSingle, opName>;


//===----------------------------------------------------------------------===//

include "VOPCInstructions.td"
Expand Down
Loading
Loading