[AMDGPU] Replace isInlinableLiteral16 with specific version #84402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (1 commit, Mar 8, 2024)
52 changes: 29 additions & 23 deletions llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -3327,35 +3327,41 @@ bool AMDGPUDAGToDAGISel::SelectWMMAVISrc(SDValue In, SDValue &Src) const {

   // 16 bit splat
   SDValue SplatSrc32 = stripBitcast(In);
-  if (auto *SplatSrc32BV = dyn_cast<BuildVectorSDNode>(SplatSrc32)) {
+  if (auto *SplatSrc32BV = dyn_cast<BuildVectorSDNode>(SplatSrc32))
     if (SDValue Splat32 = SplatSrc32BV->getSplatValue()) {
       SDValue SplatSrc16 = stripBitcast(Splat32);
-      if (auto *SplatSrc16BV = dyn_cast<BuildVectorSDNode>(SplatSrc16)) {
+      if (auto *SplatSrc16BV = dyn_cast<BuildVectorSDNode>(SplatSrc16))
         if (SDValue Splat = SplatSrc16BV->getSplatValue()) {
-
-          // f16
-          if (isInlineImmediate(Splat.getNode())) {
-            const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat);
-            int64_t Imm = C->getValueAPF().bitcastToAPInt().getSExtValue();
-            Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i16);
-            return true;
-          }
-
-          // bf16
-          if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat)) {
-            const SIInstrInfo *TII = Subtarget->getInstrInfo();
-            APInt BF16Value = C->getAPIntValue();
-            APInt F32Value = BF16Value.zext(32).shl(16);
-            if (TII->isInlineConstant(F32Value)) {
-              int64_t Imm = F32Value.getSExtValue();
-              Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i32);
-              return true;
-            }
-          }
+          const SIInstrInfo *TII = Subtarget->getInstrInfo();
+          std::optional<APInt> RawValue;
+          if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat))
+            RawValue = C->getValueAPF().bitcastToAPInt();
+          else if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat))
+            RawValue = C->getAPIntValue();
+
+          if (RawValue.has_value()) {
+            EVT VT = In.getValueType().getScalarType();
+            if (VT.getSimpleVT() == MVT::f16 || VT.getSimpleVT() == MVT::bf16) {
+              APFloat FloatVal(VT.getSimpleVT() == MVT::f16
+                                   ? APFloatBase::IEEEhalf()
+                                   : APFloatBase::BFloat(),
+                               RawValue.value());
+              if (TII->isInlineConstant(FloatVal)) {
+                Src = CurDAG->getTargetConstant(RawValue.value(), SDLoc(In),
+                                                MVT::i16);
+                return true;
+              }
+            } else if (VT.getSimpleVT() == MVT::i16) {
+              if (TII->isInlineConstant(RawValue.value())) {
+                Src = CurDAG->getTargetConstant(RawValue.value(), SDLoc(In),
+                                                MVT::i16);
+                return true;
+              }
+            } else
+              llvm_unreachable("unknown 16-bit type");
+          }
         }
-      }
     }
-  }
 
   return false;
 }
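For illustration, the sketch below (a standalone program written from the IEEE-754 layouts; it is not part of the patch) shows why the splat check must pick the semantics from the element type before asking whether the value is an inline constant: the same 16-bit pattern decodes to different values under f16 and bf16.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 is f32 with the low 16 mantissa bits dropped.
static float decodeBF16(uint16_t V) {
  uint32_t Bits = static_cast<uint32_t>(V) << 16;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}

// IEEE half: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
static float decodeF16(uint16_t V) {
  unsigned Exp = (V >> 10) & 0x1F, Man = V & 0x3FF;
  float F;
  if (Exp == 0)
    F = std::ldexp(static_cast<float>(Man), -24); // subnormal
  else if (Exp == 31)
    F = Man ? NAN : INFINITY;
  else
    F = std::ldexp(1.0f + Man / 1024.0f, static_cast<int>(Exp) - 15);
  return (V >> 15) ? -F : F;
}

int main() {
  // 1.0 is an inline constant; 1.875 and 0.0078125 are not.
  std::printf("0x3C00: f16 = %g, bf16 = %g\n", decodeF16(0x3C00),
              decodeBF16(0x3C00)); // f16 = 1, bf16 = 0.0078125
  std::printf("0x3F80: f16 = %g, bf16 = %g\n", decodeF16(0x3F80),
              decodeBF16(0x3F80)); // f16 = 1.875, bf16 = 1
}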
97 changes: 75 additions & 22 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -1926,6 +1926,11 @@ static const fltSemantics *getFltSemantics(MVT VT) {

 static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
   switch (OperandType) {
+  // When a floating-point immediate is used as an operand of type i16, the
+  // 32-bit representation of the constant truncated to the 16 LSBs should be
+  // used.
+  case AMDGPU::OPERAND_REG_IMM_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
   case AMDGPU::OPERAND_REG_IMM_INT32:
   case AMDGPU::OPERAND_REG_IMM_FP32:
   case AMDGPU::OPERAND_REG_IMM_FP32_DEFERRED:
@@ -1949,13 +1954,10 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
   case AMDGPU::OPERAND_REG_INLINE_C_FP64:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP64:
     return &APFloat::IEEEdouble();
-  case AMDGPU::OPERAND_REG_IMM_INT16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
-  case AMDGPU::OPERAND_REG_INLINE_C_INT16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
-  case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
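A quick illustration of the comment above (the arithmetic is mine, not patch code): a floating-point literal on an i16 operand is materialized through its IEEE single-precision bit pattern, so inline-ability has to be judged on the full 32-bit value, not on the 16 LSBs alone.

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float Lit = 0.5f;
  uint32_t Bits;
  std::memcpy(&Bits, &Lit, sizeof(Bits));
  // 0x3F000000 is the inline f32 constant 0.5; its low 16 bits alone
  // (0x0000) would be indistinguishable from the integer 0.
  std::printf("bits = 0x%08X, low 16 = 0x%04X\n", (unsigned)Bits,
              (unsigned)(Bits & 0xFFFF));
}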
@@ -2001,13 +2003,15 @@ static bool isSafeTruncation(int64_t Val, unsigned Size) {
 }
 
 static bool isInlineableLiteralOp16(int64_t Val, MVT VT, bool HasInv2Pi) {
-  if (VT.getScalarType() == MVT::i16) {
-    // FP immediate values are broken.
-    return isInlinableIntLiteral(Val);
-  }
+  if (VT.getScalarType() == MVT::i16)
+    return isInlinableLiteral32(Val, HasInv2Pi);
+
+  if (VT.getScalarType() == MVT::f16)
+    return AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi);
 
-  // f16/v2f16 operands work correctly for all values.
-  return AMDGPU::isInlinableLiteral16(Val, HasInv2Pi);
+  assert(VT.getScalarType() == MVT::bf16);
+
+  return AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi);
 }
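For reference, here is a sketch of the constant sets behind isInlinableLiteralFP16 and isInlinableLiteralBF16, reconstructed from the IEEE encodings. The bit patterns below are mine; the real helpers in AMDGPUBaseInfo.cpp also accept the inline integers -16..64 and remain the authoritative definition.

#include <cstdint>

// +-0.0, +-0.5, +-1.0, +-2.0, +-4.0, and 1/(2*pi) in IEEE half encoding.
static bool isInlinableFP16Pattern(uint16_t V, bool HasInv2Pi) {
  switch (V) {
  case 0x0000:              // 0.0
  case 0x3800: case 0xB800: // +-0.5
  case 0x3C00: case 0xBC00: // +-1.0
  case 0x4000: case 0xC000: // +-2.0
  case 0x4400: case 0xC400: // +-4.0
    return true;
  case 0x3118:              // 0.15915494 = 1/(2*pi)
    return HasInv2Pi;
  default:
    return false;
  }
}

// The same constants in bfloat16 encoding; only 0.0 and +-2.0 share patterns
// with the f16 set above.
static bool isInlinableBF16Pattern(uint16_t V, bool HasInv2Pi) {
  switch (V) {
  case 0x0000:              // 0.0
  case 0x3F00: case 0xBF00: // +-0.5
  case 0x3F80: case 0xBF80: // +-1.0
  case 0x4000: case 0xC000: // +-2.0
  case 0x4080: case 0xC080: // +-4.0
    return true;
  case 0x3E22:              // 1/(2*pi) rounded to bf16
    return HasInv2Pi;
  default:
    return false;
  }
}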

 bool AMDGPUOperand::isInlinableImm(MVT type) const {
@@ -2041,9 +2045,30 @@ bool AMDGPUOperand::isInlinableImm(MVT type) const {
     return false;
 
   if (type.getScalarSizeInBits() == 16) {
-    return isInlineableLiteralOp16(
-        static_cast<int16_t>(FPLiteral.bitcastToAPInt().getZExtValue()),
-        type, AsmParser->hasInv2PiInlineImm());
+    bool Lost = false;
+    switch (type.getScalarType().SimpleTy) {
+    default:
+      llvm_unreachable("unknown 16-bit type");
+    case MVT::bf16:
+      FPLiteral.convert(APFloatBase::BFloat(), APFloat::rmNearestTiesToEven,
+                        &Lost);
+      break;
+    case MVT::f16:
+      FPLiteral.convert(APFloatBase::IEEEhalf(), APFloat::rmNearestTiesToEven,
+                        &Lost);
+      break;
+    case MVT::i16:
+      FPLiteral.convert(APFloatBase::IEEEsingle(),
+                        APFloat::rmNearestTiesToEven, &Lost);
+      break;
+    }
+    // We need the 32-bit representation here because when a floating-point
+    // inline constant is used as an i16 operand, its 32-bit representation
+    // will be used. We also need the 32-bit value to check whether it is an
+    // FP inline constant.
+    uint32_t ImmVal = FPLiteral.bitcastToAPInt().getZExtValue();
+    return isInlineableLiteralOp16(ImmVal, type,
+                                   AsmParser->hasInv2PiInlineImm());
   }
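A minimal sketch of the conversion step (assumes an LLVM checkout for llvm/ADT/APFloat.h; the sample value is mine): the literal, parsed as a double, is re-rounded into the operand's own semantics, and the resulting bit pattern is what feeds isInlineableLiteralOp16.

#include "llvm/ADT/APFloat.h"
#include <cstdint>
#include <cstdio>
using namespace llvm;

int main() {
  APFloat FPLiteral(0.5); // literals are parsed as IEEE double first
  bool Lost = false;
  FPLiteral.convert(APFloatBase::BFloat(), APFloat::rmNearestTiesToEven,
                    &Lost);
  uint32_t ImmVal = FPLiteral.bitcastToAPInt().getZExtValue();
  std::printf("0.5 as bf16: 0x%04X (lost info: %d)\n", (unsigned)ImmVal,
              (int)Lost);
  // Prints 0x3F00, an inline constant under bf16 semantics.
}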

   // Check if single precision literal is inlinable
@@ -2375,15 +2400,26 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_IMM_INT16:
-  case AMDGPU::OPERAND_REG_IMM_FP16:
-  case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
-  case AMDGPU::OPERAND_REG_INLINE_C_FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+    if (isSafeTruncation(Val, 16) &&
+        AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val))) {
+      Inst.addOperand(MCOperand::createImm(Val & 0xffffffff));
+      setImmKindConst();
+      return;
+    }
+
+    Inst.addOperand(MCOperand::createImm(Val & 0xffff));
+    setImmKindLiteral();
+    return;
+
+  case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     if (isSafeTruncation(Val, 16) &&
-        AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
-                                     AsmParser->hasInv2PiInlineImm())) {
+        AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
+                                       AsmParser->hasInv2PiInlineImm())) {
       Inst.addOperand(MCOperand::createImm(Val));
       setImmKindConst();
       return;
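Reduced to its decision logic, the new INT16 path behaves like this standalone sketch (the -16..64 inline range is from AMDGPUBaseInfo; the printing is mine):

#include <cstdint>
#include <cstdio>

static bool isInlinableIntLiteral(int64_t V) { return V >= -16 && V <= 64; }

int main() {
  for (int64_t Val : {7LL, 64LL, 65LL, -17LL}) {
    if (isInlinableIntLiteral(static_cast<int16_t>(Val)))
      std::printf("%4lld -> inline constant operand\n", (long long)Val);
    else // does not fit: emit a literal, truncated to 16 bits
      std::printf("%4lld -> literal 0x%04llX\n", (long long)Val,
                  (unsigned long long)(Val & 0xffff));
  }
}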
@@ -2410,12 +2446,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: {
+    assert(isSafeTruncation(Val, 16));
+    assert(AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val)));
+    Inst.addOperand(MCOperand::createImm(Val));
+    return;
+  }
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
-  case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
     assert(isSafeTruncation(Val, 16));
-    assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
-                                        AsmParser->hasInv2PiInlineImm()));
+    assert(AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
+                                          AsmParser->hasInv2PiInlineImm()));
 
     Inst.addOperand(MCOperand::createImm(Val));
     return;
@@ -3542,7 +3583,7 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
     if (OperandType == AMDGPU::OPERAND_REG_IMM_INT16 ||
         OperandType == AMDGPU::OPERAND_REG_INLINE_C_INT16 ||
         OperandType == AMDGPU::OPERAND_REG_INLINE_AC_INT16)
-      return AMDGPU::isInlinableIntLiteral(Val);
+      return AMDGPU::isInlinableLiteralI16(Val, hasInv2PiInlineImm());
 
     if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2INT16 ||
         OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2INT16 ||
@@ -3559,7 +3600,19 @@
         OperandType == AMDGPU::OPERAND_REG_IMM_V2BF16)
       return AMDGPU::isInlinableLiteralV2BF16(Val);
 
-    return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
+    if (OperandType == AMDGPU::OPERAND_REG_IMM_FP16 ||
+        OperandType == AMDGPU::OPERAND_REG_INLINE_C_FP16 ||
+        OperandType == AMDGPU::OPERAND_REG_INLINE_AC_FP16 ||
+        OperandType == AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED)
+      return AMDGPU::isInlinableLiteralFP16(Val, hasInv2PiInlineImm());
+
+    if (OperandType == AMDGPU::OPERAND_REG_IMM_BF16 ||
+        OperandType == AMDGPU::OPERAND_REG_INLINE_C_BF16 ||
+        OperandType == AMDGPU::OPERAND_REG_INLINE_AC_BF16 ||
+        OperandType == AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED)
+      return AMDGPU::isInlinableLiteralBF16(Val, hasInv2PiInlineImm());
+
+    llvm_unreachable("invalid operand type");
   }
   default:
     llvm_unreachable("invalid operand size");
29 changes: 15 additions & 14 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -451,19 +451,20 @@ void AMDGPUInstPrinter::printVINTRPDst(const MCInst *MI, unsigned OpNo,
 void AMDGPUInstPrinter::printImmediateInt16(uint32_t Imm,
                                             const MCSubtargetInfo &STI,
                                             raw_ostream &O) {
-  int16_t SImm = static_cast<int16_t>(Imm);
+  int32_t SImm = static_cast<int32_t>(Imm);
   if (isInlinableIntLiteral(SImm)) {
     O << SImm;
-  } else {
-    uint64_t Imm16 = static_cast<uint16_t>(Imm);
-    O << formatHex(Imm16);
+    return;
   }
+
+  if (printImmediateFloat32(Imm, STI, O))
+    return;
+
+  O << formatHex(static_cast<uint64_t>(Imm & 0xffff));
 }
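The resulting printing policy for INT16 operands, as a standalone reconstruction (not the printer's code; the float cases mirror printImmediateFloat32, with negatives and 1/(2*pi) omitted for brevity):

#include <cstdint>
#include <cstdio>

static bool printAsFloat32(uint32_t Imm) {
  switch (Imm) {
  case 0x3F000000: std::printf("0.5\n"); return true;
  case 0x3F800000: std::printf("1.0\n"); return true;
  case 0x40000000: std::printf("2.0\n"); return true;
  case 0x40800000: std::printf("4.0\n"); return true;
  default: return false;
  }
}

static void printInt16Imm(uint32_t Imm) {
  int32_t SImm = static_cast<int32_t>(Imm);
  if (SImm >= -16 && SImm <= 64) { // inline integer
    std::printf("%d\n", SImm);
    return;
  }
  if (printAsFloat32(Imm)) // inline f32 pattern, printed as a float
    return;
  std::printf("0x%04X\n", (unsigned)(Imm & 0xffff)); // 16-bit literal
}

int main() {
  printInt16Imm(64);         // 64
  printInt16Imm(0x3F800000); // 1.0 (this is why SImm is now 32-bit)
  printInt16Imm(0xABCD);     // 0xABCD
}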

 // This must accept a 32-bit immediate value to correctly handle packed 16-bit
 // operations.
-static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
-                                  raw_ostream &O) {
+static bool printImmediateFP16(uint32_t Imm, const MCSubtargetInfo &STI,
+                               raw_ostream &O) {
   if (Imm == 0x3C00)
     O << "1.0";
   else if (Imm == 0xBC00)
@@ -529,17 +530,17 @@ void AMDGPUInstPrinter::printImmediateBF16(uint32_t Imm,
   O << formatHex(static_cast<uint64_t>(Imm));
 }
 
-void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
-                                         const MCSubtargetInfo &STI,
-                                         raw_ostream &O) {
+void AMDGPUInstPrinter::printImmediateF16(uint32_t Imm,
+                                          const MCSubtargetInfo &STI,
+                                          raw_ostream &O) {
   int16_t SImm = static_cast<int16_t>(Imm);
   if (isInlinableIntLiteral(SImm)) {
     O << SImm;
     return;
   }
 
   uint16_t HImm = static_cast<uint16_t>(Imm);
-  if (printImmediateFloat16(HImm, STI, O))
+  if (printImmediateFP16(HImm, STI, O))
     return;
 
   uint64_t Imm16 = static_cast<uint16_t>(Imm);
@@ -566,7 +567,7 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     if (isUInt<16>(Imm) &&
-        printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
+        printImmediateFP16(static_cast<uint16_t>(Imm), STI, O))
       return;
     break;
   case AMDGPU::OPERAND_REG_IMM_V2BF16:
@@ -845,7 +846,7 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
-    printImmediate16(Op.getImm(), STI, O);
+    printImmediateF16(Op.getImm(), STI, O);
     break;
   case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
4 changes: 2 additions & 2 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
@@ -86,10 +86,10 @@ class AMDGPUInstPrinter : public MCInstPrinter {
                          raw_ostream &O);
   void printImmediateInt16(uint32_t Imm, const MCSubtargetInfo &STI,
                            raw_ostream &O);
-  void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
-                        raw_ostream &O);
   void printImmediateBF16(uint32_t Imm, const MCSubtargetInfo &STI,
                           raw_ostream &O);
+  void printImmediateF16(uint32_t Imm, const MCSubtargetInfo &STI,
+                         raw_ostream &O);
   void printImmediateV216(uint32_t Imm, uint8_t OpType,
                           const MCSubtargetInfo &STI, raw_ostream &O);
   bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,
11 changes: 5 additions & 6 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
@@ -116,11 +116,6 @@ static uint32_t getIntInlineImmEncoding(IntTy Imm) {
   return 0;
 }
 
-static uint32_t getLit16IntEncoding(uint16_t Val, const MCSubtargetInfo &STI) {
-  uint16_t IntImm = getIntInlineImmEncoding(static_cast<int16_t>(Val));
-  return IntImm == 0 ? 255 : IntImm;
-}
-
 static uint32_t getLit16Encoding(uint16_t Val, const MCSubtargetInfo &STI) {
   uint16_t IntImm = getIntInlineImmEncoding(static_cast<int16_t>(Val));
   if (IntImm != 0)
@@ -214,6 +209,10 @@ static uint32_t getLit32Encoding(uint32_t Val, const MCSubtargetInfo &STI) {
   return 255;
 }
 
+static uint32_t getLit16IntEncoding(uint32_t Val, const MCSubtargetInfo &STI) {
+  return getLit32Encoding(Val, STI);
+}
+
 static uint32_t getLit64Encoding(uint64_t Val, const MCSubtargetInfo &STI) {
   uint32_t IntImm = getIntInlineImmEncoding(static_cast<int64_t>(Val));
   if (IntImm != 0)
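The effect of routing INT16 literals through getLit32Encoding, sketched as a standalone table. The values follow the ISA's source-operand encoding (128+n for integers 0..64, 192+|n| for -1..-16, 240-247 for the f32 constants, 255 meaning a literal dword follows); the helper itself is a reconstruction, not the emitter's code.

#include <cstdint>
#include <cstdio>

static uint32_t encodeSrc(uint32_t Val) {
  int32_t S = static_cast<int32_t>(Val);
  if (S >= 0 && S <= 64)
    return 128 + S;            // inline integers 0..64
  if (S >= -16 && S <= -1)
    return 192 - S;            // inline integers -1..-16
  switch (Val) {               // inline f32 bit patterns
  case 0x3F000000: return 240; // 0.5
  case 0xBF000000: return 241; // -0.5
  case 0x3F800000: return 242; // 1.0
  case 0xBF800000: return 243; // -1.0
  case 0x40000000: return 244; // 2.0
  case 0xC0000000: return 245; // -2.0
  case 0x40800000: return 246; // 4.0
  case 0xC0800000: return 247; // -4.0
  }
  return 255;                  // literal constant in the next dword
}

int main() {
  std::printf("%u %u %u\n", encodeSrc(1), encodeSrc(0x3F800000),
              encodeSrc(0x1234)); // 129 242 255
}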
@@ -296,7 +295,7 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_IMM_INT16:
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
-    return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);
+    return getLit16IntEncoding(static_cast<uint32_t>(Imm), STI);
 
   case AMDGPU::OPERAND_REG_IMM_FP16:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
28 changes: 22 additions & 6 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -15495,16 +15495,32 @@ bool SITargetLowering::checkAsmConstraintVal(SDValue Op, StringRef Constraint,
llvm_unreachable("Invalid asm constraint");
}

bool SITargetLowering::checkAsmConstraintValA(SDValue Op,
uint64_t Val,
bool SITargetLowering::checkAsmConstraintValA(SDValue Op, uint64_t Val,
unsigned MaxSize) const {
unsigned Size = std::min<unsigned>(Op.getScalarValueSizeInBits(), MaxSize);
bool HasInv2Pi = Subtarget->hasInv2PiInlineImm();
if ((Size == 16 && AMDGPU::isInlinableLiteral16(Val, HasInv2Pi)) ||
(Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
(Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi))) {
return true;
if (Size == 16) {
MVT VT = Op.getSimpleValueType();
switch (VT.SimpleTy) {
default:
return false;
case MVT::i16:
return AMDGPU::isInlinableLiteralI16(Val, HasInv2Pi);
case MVT::f16:
return AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi);
case MVT::bf16:
return AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi);
case MVT::v2i16:
return AMDGPU::getInlineEncodingV2I16(Val).has_value();
case MVT::v2f16:
return AMDGPU::getInlineEncodingV2F16(Val).has_value();
case MVT::v2bf16:
return AMDGPU::getInlineEncodingV2BF16(Val).has_value();
}
}
if ((Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
(Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi)))
return true;
return false;
}
