Skip to content

[AMDGPU] Use bf16 instead of i16 for bfloat #80908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ typedef unsigned short __attribute__((ext_vector_type(2))) ushort2;
// CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 false)
// CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 true)
// CHECK: call half @llvm.amdgcn.fdot2.f16.f16(<2 x half> %v2hA, <2 x half> %v2hB, half %hC)
// CHECK: call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, i16 %sC)
// CHECK: [[s1:%[0-9]+]] = bitcast <2 x i16> %v2ssA to <2 x bfloat>
// CHECK-NEXT: [[s2:%[0-9]+]] = bitcast <2 x i16> %v2ssB to <2 x bfloat>
// CHECK-NEXT: [[s3:%[0-9]+]] = bitcast i16 %sC to bfloat
// CHECK-NEXT: [[d:%[0-9]+]] = tail call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> [[s1]], <2 x bfloat> [[s2]], bfloat [[s3]])
// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 false)
// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 true)
// CHECK: call i32 @llvm.amdgcn.udot4(i32 %uiA, i32 %uiB, i32 %uiC, i1 false)
Expand Down
8 changes: 4 additions & 4 deletions llvm/include/llvm/IR/IntrinsicsAMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 :
def int_amdgcn_fdot2_bf16_bf16 :
ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">,
DefaultAttrsIntrinsic<
[llvm_i16_ty], // %r
[llvm_bfloat_ty], // %r
[
llvm_v2i16_ty, // %a
llvm_v2i16_ty, // %b
llvm_i16_ty // %c
llvm_v2bf16_ty, // %a
llvm_v2bf16_ty, // %b
llvm_bfloat_ty // %c
],
[IntrNoMem, IntrSpeculatable]
>;
Expand Down
92 changes: 92 additions & 0 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {

// Scalar-source predicates: the operand is a matching SGPR-class register
// or a literal immediate of the given type. bf16 and f16 share the same
// 16-bit scalar register check (isSCSrcB16) and differ only in the type the
// literal is interpreted as.
bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }

bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }

bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }

bool isSSrcV2F16() const {
Expand Down Expand Up @@ -541,22 +543,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
}

// True-16 ("T") VGPR/SGPR source, bf16 flavor: a VS_16-class register or an
// inline constant, with no source modifiers applied.
bool isVCSrcTBF16() const {
return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
}

bool isVCSrcTF16() const {
return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
}

// Same as the above, restricted to the low-128 register range (VS_16_Lo128).
bool isVCSrcTBF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
}

bool isVCSrcTF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
}

// Fake-16 variants: a 16-bit value carried in a 32-bit register class
// (VS_32_Lo128).
bool isVCSrcFake16BF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
}

bool isVCSrcFake16F16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
}

// Plain 32-bit VGPR/SGPR source holding a 16-bit value.
bool isVCSrc_bf16() const {
return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
}

bool isVCSrc_f16() const {
return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
}

// Packed pairs delegate to the scalar predicates: the register/inline
// operand constraints are the same for the v2 forms.
bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }

bool isVCSrc_v2f16() const { return isVCSrc_f16(); }

bool isVSrc_b32() const {
Expand Down Expand Up @@ -597,18 +617,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {

bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }

// VSrc predicates: accept everything the corresponding VCSrc predicate
// accepts, plus a literal immediate of the matching type.

// True-16 (VS_16) forms.
bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }

bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }

bool isVSrcT_bf16_Lo128() const {
return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
}

bool isVSrcT_f16_Lo128() const {
return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
}

// Fake-16 forms (16-bit value in a 32-bit register).
bool isVSrcFake16_bf16_Lo128() const {
return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
}

bool isVSrcFake16_f16_Lo128() const {
return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
}

// Plain 32-bit register forms.
bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }

bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }

// Packed pairs: scalar register predicate plus a packed (v2) literal.
bool isVSrc_v2bf16() const {
return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
}

bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }

bool isVISrcB32() const {
Expand All @@ -635,6 +671,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isVISrcF16() || isVISrcB32();
}

// 64-bit VGPR (VReg_64) source, or an inline constant, with 16-bit element
// type — bf16 and f16 variants.
bool isVISrc_64_bf16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
}

bool isVISrc_64_f16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
}
Expand Down Expand Up @@ -803,6 +843,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isAISrc_128F16() || isAISrc_128_b32();
}

// 128-bit VGPR (VReg_128) source, or an inline constant, with 16-bit
// element type — bf16 and f16 variants.
bool isVISrc_128_bf16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
}

bool isVISrc_128_f16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
}
Expand Down Expand Up @@ -1890,6 +1934,14 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_KIMM16:
return &APFloat::IEEEhalf();
case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
return &APFloat::BFloat();
default:
llvm_unreachable("unsupported fp type");
}
Expand Down Expand Up @@ -2186,17 +2238,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
case AMDGPU::OPERAND_REG_IMM_INT16:
case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
case AMDGPU::OPERAND_REG_IMM_V2FP32:
Expand Down Expand Up @@ -2240,6 +2299,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2FP32:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
Expand Down Expand Up @@ -2295,6 +2355,22 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
setImmKindLiteral();
return;

case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
if (isSafeTruncation(Val, 16) &&
AMDGPU::isInlinableLiteralBF16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm())) {
Inst.addOperand(MCOperand::createImm(Val));
setImmKindConst();
return;
}

Inst.addOperand(MCOperand::createImm(Val & 0xffff));
setImmKindLiteral();
return;

case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
Expand All @@ -2306,6 +2382,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
Inst.addOperand(MCOperand::createImm(Val));
return;
}

case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: {
assert(isSafeTruncation(Val, 16));
assert(AMDGPU::isInlinableLiteralBF16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm()));

Inst.addOperand(MCOperand::createImm(Val));
return;
}

case AMDGPU::OPERAND_KIMM32:
Inst.addOperand(MCOperand::createImm(Literal.getLoBits(32).getZExtValue()));
setImmKindMandatoryLiteral();
Expand Down Expand Up @@ -3429,6 +3516,11 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
OperandType == AMDGPU::OPERAND_REG_IMM_V2FP16)
return AMDGPU::isInlinableLiteralV2F16(Val);

if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2BF16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2BF16 ||
OperandType == AMDGPU::OPERAND_REG_IMM_V2BF16)
return AMDGPU::isInlinableLiteralV2BF16(Val);

return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
}
default:
Expand Down
57 changes: 57 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,47 @@ static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
return true;
}

// Print the textual form of an inline bfloat16 floating-point constant.
// Returns false when Imm is not one of the named inline values, leaving the
// stream untouched so the caller can fall back to another representation.
static bool printImmediateBFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
                                   raw_ostream &O) {
  switch (Imm) {
  case 0x3F80: O << "1.0";  return true;
  case 0xBF80: O << "-1.0"; return true;
  case 0x3F00: O << "0.5";  return true;
  case 0xBF00: O << "-0.5"; return true;
  case 0x4000: O << "2.0";  return true;
  case 0xC000: O << "-2.0"; return true;
  case 0x4080: O << "4.0";  return true;
  case 0xC080: O << "-4.0"; return true;
  case 0x3E22:
    // 1/(2*pi) is only an inline constant on subtargets with the feature.
    if (STI.hasFeature(AMDGPU::FeatureInv2PiInlineImm)) {
      O << "0.15915494";
      return true;
    }
    return false;
  default:
    return false;
  }
}

// Print a bf16 immediate operand: inline integer constants print as decimal,
// inline FP constants print in their named form, and everything else falls
// back to a hex literal.
void AMDGPUInstPrinter::printImmediateBF16(uint32_t Imm,
                                           const MCSubtargetInfo &STI,
                                           raw_ostream &O) {
  const int16_t AsInt = static_cast<int16_t>(Imm);
  if (isInlinableIntLiteral(AsInt)) {
    O << AsInt;
  } else if (!printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O)) {
    // Not an inline constant at all; emit the raw bits as hex.
    O << formatHex(static_cast<uint64_t>(Imm));
  }
}

void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
Expand Down Expand Up @@ -528,6 +569,13 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
return;
break;
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
if (isUInt<16>(Imm) &&
printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O))
return;
break;
default:
llvm_unreachable("bad operand type");
}
Expand Down Expand Up @@ -799,11 +847,20 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
printImmediate16(Op.getImm(), STI, O);
break;
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
printImmediateBF16(Op.getImm(), STI, O);
break;
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
printImmediateV216(Op.getImm(), OpTy, STI, O);
break;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class AMDGPUInstPrinter : public MCInstPrinter {
raw_ostream &O);
void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateBF16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateV216(uint32_t Imm, uint8_t OpType,
const MCSubtargetInfo &STI, raw_ostream &O);
bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,
Expand Down
39 changes: 39 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,27 @@ static uint32_t getLit16Encoding(uint16_t Val, const MCSubtargetInfo &STI) {
return 255;
}

// Map a bf16 literal to its operand encoding: inline integer constants and
// the hardware's named inline FP constants get their dedicated encodings
// (240-248); anything else encodes as 255, i.e. a trailing literal.
//
// NOTE(review): this duplicates the inline-constant table (0x3F00, 0xBF00,
// ...) that also exists in getInlineEncodingV2BF16 and the printer. It
// cannot simply call getInlineEncodingV2BF16 yet — the uint16_t/uint32_t
// conversion there mishandles some cases — but the copies should eventually
// be unified in one place.
static uint32_t getLitBF16Encoding(uint16_t Val) {
  uint16_t IntImm = getIntInlineImmEncoding(static_cast<int16_t>(Val));
  if (IntImm != 0)
    return IntImm;

  // clang-format off
  switch (Val) {
  case 0x3F00: return 240; // 0.5
  case 0xBF00: return 241; // -0.5
  case 0x3F80: return 242; // 1.0
  case 0xBF80: return 243; // -1.0
  case 0x4000: return 244; // 2.0
  case 0xC000: return 245; // -2.0
  case 0x4080: return 246; // 4.0
  case 0xC080: return 247; // -4.0
  case 0x3E22: return 248; // 1.0 / (2.0 * pi)
  default: return 255;
  }
  // clang-format on
}

static uint32_t getLit32Encoding(uint32_t Val, const MCSubtargetInfo &STI) {
uint32_t IntImm = getIntInlineImmEncoding(static_cast<int32_t>(Val));
if (IntImm != 0)
Expand Down Expand Up @@ -276,23 +297,41 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);

case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
// FIXME Is this correct? What do inline immediates do on SI for f16 src
// which does not have f16 support?
return getLit16Encoding(static_cast<uint16_t>(Imm), STI);

case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
// We don't actually need to check Inv2Pi here because BF16 instructions can
// only be emitted for targets that already support the feature.
return getLitBF16Encoding(static_cast<uint16_t>(Imm));

case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
return AMDGPU::getInlineEncodingV2I16(static_cast<uint32_t>(Imm))
.value_or(255);

case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
return AMDGPU::getInlineEncodingV2F16(static_cast<uint32_t>(Imm))
.value_or(255);

case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
return AMDGPU::getInlineEncodingV2BF16(static_cast<uint32_t>(Imm))
.value_or(255);

case AMDGPU::OPERAND_KIMM32:
case AMDGPU::OPERAND_KIMM16:
return MO.getImm();
Expand Down
Loading