Skip to content
Merged
38 changes: 31 additions & 7 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
Expand Down Expand Up @@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
case MVT::v4i8:
return Opcode_i32;
case MVT::f32:
return Opcode_f32;
Expand Down Expand Up @@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// Vector Setting
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
"Unexpected vector type");
// v2f16/v2bf16/v2i16 is loaded using ld.b32
fromTypeWidth = 32;
}
Expand Down Expand Up @@ -1254,19 +1257,23 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
SDLoc DL(N);
SDNode *LD;
SDValue Base, Offset, Addr;
EVT OrigType = N->getValueType(0);

EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
(EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
EltVT = OrigType;
NumElts /= 2;
} else if (OrigType == MVT::v4i8) {
EltVT = OrigType;
NumElts = 1;
}
}

Expand Down Expand Up @@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// concept of sign-/zero-extension, so emulate it here by adding an explicit
// CVT instruction. Ptxas should clean up any redundancies here.

EVT OrigType = N->getValueType(0);
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);

if (OrigType != EltVT &&
Expand Down Expand Up @@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
"Unexpected vector type");
// v2x16 is stored using st.b32
toTypeWidth = 32;
}
Expand Down Expand Up @@ -3563,6 +3570,23 @@ bool NVPTXDAGToDAGISel::SelectADDRri64(SDNode *OpNode, SDValue Addr,
return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i64);
}

bool NVPTXDAGToDAGISel::SelectExtractEltFromV4I8(SDValue N, SDValue &V,
SDValue &BitOffset) {
SDValue Vector = N->getOperand(0);
if (!(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
Vector->getValueType(0) == MVT::v4i8))
return false;

if (const ConstantSDNode *IdxConst =
dyn_cast<ConstantSDNode>(N->getOperand(1))) {
V = Vector;
BitOffset = CurDAG->getTargetConstant(IdxConst->getZExtValue() * 8,
SDLoc(N), MVT::i32);
return true;
}
return false;
}

bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N,
unsigned int spN) const {
const Value *Src = nullptr;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
SDValue &Offset);
bool SelectADDRsi64(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset);
bool SelectExtractEltFromV4I8(SDValue N, SDValue &Value, SDValue &Idx);

bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;

Expand Down
87 changes: 52 additions & 35 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
llvm_unreachable("Unexpected type");
}
NumElts /= 2;
} else if (EltVT.getSimpleVT() == MVT::i8 &&
(NumElts % 4 == 0 || NumElts == 3)) {
// v*i8 are formally lowered as v4i8
EltVT = MVT::v4i8;
NumElts = (NumElts + 3) / 4;
}
for (unsigned j = 0; j != NumElts; ++j) {
ValueVTs.push_back(EltVT);
Expand Down Expand Up @@ -458,6 +463,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
Expand Down Expand Up @@ -491,6 +497,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);

// TODO: we should eventually lower it as PRMT instruction.
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Expand);
setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);

// Operations not directly supported by NVPTX.
for (MVT VT :
{MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32, MVT::f64,
Expand Down Expand Up @@ -2150,45 +2160,47 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}

// We can init constant f16x2 with a single .b32 move. Normally it
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
// would get lowered as two constant loads and vector-packing move.
// mov.b16 %h1, 0x4000;
// mov.b16 %h2, 0x3C00;
// mov.b32 %hh2, {%h2, %h1};
// Instead we want just a constant move:
// mov.b32 %hh2, 0x40003C00
//
// This results in better SASS code with CUDA 7.x. Ptxas in CUDA 8.0
// generates good SASS in both cases.
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
EVT VT = Op->getValueType(0);
if (!(Isv2x16VT(VT)))
if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
return Op;

if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
isa<ConstantFPSDNode>(Operand);
}))
return Op;
APInt E0;
APInt E1;
if (VT == MVT::v2f16 || VT == MVT::v2bf16) {
if (!(isa<ConstantFPSDNode>(Op->getOperand(0)) &&
isa<ConstantFPSDNode>(Op->getOperand(1))))
return Op;

E0 = cast<ConstantFPSDNode>(Op->getOperand(0))
->getValueAPF()
.bitcastToAPInt();
E1 = cast<ConstantFPSDNode>(Op->getOperand(1))
->getValueAPF()
.bitcastToAPInt();
} else {
assert(VT == MVT::v2i16);
if (!(isa<ConstantSDNode>(Op->getOperand(0)) &&
isa<ConstantSDNode>(Op->getOperand(1))))
return Op;

E0 = cast<ConstantSDNode>(Op->getOperand(0))->getAPIntValue();
E1 = cast<ConstantSDNode>(Op->getOperand(1))->getAPIntValue();
// Get value or the Nth operand as an APInt(32). Undef values treated as 0.
auto GetOperand = [](SDValue Op, int N) -> APInt {
const SDValue &Operand = Op->getOperand(N);
EVT VT = Op->getValueType(0);
if (Operand->isUndef())
return APInt(32, 0);
APInt Value;
if (VT == MVT::v2f16 || VT == MVT::v2bf16)
Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
else if (VT == MVT::v2i16 || VT == MVT::v4i8)
Value = cast<ConstantSDNode>(Operand)->getAPIntValue();
else
llvm_unreachable("Unsupported type");
return Value.zext(32);
};
APInt Value;
if (Isv2x16VT(VT)) {
Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
} else if (VT == MVT::v4i8) {
Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
} else {
llvm_unreachable("Unsupported type");
}
SDValue Const =
DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
}

Expand Down Expand Up @@ -2631,7 +2643,7 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
return expandUnalignedStore(Store, DAG);

// v2f16, v2bf16 and v2i16 don't need special handling.
if (Isv2x16VT(VT))
if (Isv2x16VT(VT) || VT == MVT::v4i8)
return SDValue();

if (VT.isVector())
Expand Down Expand Up @@ -2903,7 +2915,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
EVT LoadVT = EltVT;
if (EltVT == MVT::i1)
LoadVT = MVT::i8;
else if (Isv2x16VT(EltVT))
else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
// getLoad needs a vector type, but it can't handle
// vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
Expand All @@ -2929,7 +2941,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (EltVT == MVT::i1)
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
// v2f16 was loaded as an i32. Now we must bitcast it back.
else if (Isv2x16VT(EltVT))
else if (EltVT != LoadVT)
Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);

// If a promoted integer type is used, truncate down to the original
Expand Down Expand Up @@ -5258,9 +5270,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
IsPTXVectorType(VectorVT.getSimpleVT()))
return SDValue(); // Native vector loads already combine nicely w/
// extract_vector_elt.
// extract_vector_elt, except for v4i8.
// Don't mess with singletons or v2*16 types, we already handle them OK.
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT))
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
VectorVT == MVT::v4i8)
return SDValue();

uint64_t VectorBits = VectorVT.getSizeInBits();
Expand Down Expand Up @@ -5289,6 +5302,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
// If element has non-integer type, bitcast it back to the expected type.
if (EltVT != EltIVT)
Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
// Past legalizer, we may need to extent i8 -> i16 to match the register type.
if (EltVT != N->getValueType(0))
Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, N->getValueType(0), Result);

return Result;
}

Expand Down
Loading