Skip to content

[Spirv][HLSL] Add OpAll lowering and float vec support #87952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 10, 2024
218 changes: 191 additions & 27 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/TypedPointerType.h"
#include "llvm/Support/Casting.h"
#include <cassert>

using namespace llvm;
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
Expand All @@ -35,6 +40,15 @@ SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
return SpirvType;
}

SPIRVType *
SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
MachineInstr &I,
const SPIRVInstrInfo &TII) {
SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
return SpirvType;
}

SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
const SPIRVInstrInfo &TII) {
Expand Down Expand Up @@ -151,6 +165,8 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
Expand All @@ -164,9 +180,83 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
return std::make_tuple(Res, CI, NewInstr);
}

std::tuple<Register, ConstantFP *, bool, unsigned>
SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
MachineIRBuilder *MIRBuilder,
MachineInstr *I,
const SPIRVInstrInfo *TII) {
const Type *LLVMFloatTy;
LLVMContext &Ctx = CurMF->getFunction().getContext();
unsigned BitWidth = 32;
if (SpvType)
LLVMFloatTy = getTypeForSPIRVType(SpvType);
else {
LLVMFloatTy = Type::getFloatTy(Ctx);
if (MIRBuilder)
SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);
}
bool NewInstr = false;
// Find a constant in DT or build a new one.
auto *const CI = ConstantFP::get(Ctx, Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
if (SpvType)
BitWidth = getScalarOrVectorBitWidth(SpvType);
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
if (MIRBuilder)
assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
else
assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
DT.add(CI, CurMF, Res);
NewInstr = true;
}
return std::make_tuple(Res, CI, NewInstr, BitWidth);
}

Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
bool ZeroAsNull) {
assert(SpvType);
ConstantFP *CI;
Register Res;
bool New;
unsigned BitWidth;
std::tie(Res, CI, New, BitWidth) =
getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);
// If we have found Res register which is defined by the passed G_CONSTANT
// machine instruction, a new constant instruction should be created.
if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
return Res;
MachineInstrBuilder MIB;
MachineBasicBlock &BB = *I.getParent();
// In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
if (Val.isPosZero() && ZeroAsNull) {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
} else {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(
APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),
MIB);
}
const auto &ST = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
*ST.getRegisterInfo(), *ST.getRegBankInfo());
return Res;
}

Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII) {
const SPIRVInstrInfo &TII,
bool ZeroAsNull) {
assert(SpvType);
ConstantInt *CI;
Register Res;
Expand All @@ -179,7 +269,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
return Res;
MachineInstrBuilder MIB;
MachineBasicBlock &BB = *I.getParent();
if (Val) {
if (Val || !ZeroAsNull) {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
Expand Down Expand Up @@ -270,21 +360,46 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
return Res;
}

Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
Register SPIRVGlobalRegistry::getOrCreateBaseRegister(Constant *Val,
MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
unsigned BitWidth) {
SPIRVType *Type = SpvType;
if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
SpvType->getOpcode() == SPIRV::OpTypeArray) {
auto EleTypeReg = SpvType->getOperand(1).getReg();
Type = getSPIRVTypeForVReg(EleTypeReg);
}
if (Type->getOpcode() == SPIRV::OpTypeFloat) {
SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
return getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(), I,
SpvBaseType, TII);
}
assert(Type->getOpcode() == SPIRV::OpTypeInt);
SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
return getOrCreateConstInt(Val->getUniqueInteger().getSExtValue(), I,
SpvBaseType, TII);
}

Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
Constant *Val, MachineInstr &I, SPIRVType *SpvType,
const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
unsigned ElemCnt) {
unsigned ElemCnt, bool ZeroAsNull) {
// Find a constant vector in DT or build a new one.
Register Res = DT.find(CA, CurMF);
// If no values are attached, the composite is null constant.
bool IsNull = Val->isNullValue() && ZeroAsNull;
if (!Res.isValid()) {
SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
// SpvScalConst should be created before SpvVecConst to avoid undefined ID
// error on validation.
// TODO: can moved below once sorting of types/consts/defs is implemented.
Register SpvScalConst;
if (Val)
SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
// TODO: maybe use bitwidth of base type.
if (!IsNull)
SpvScalConst = getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth);

// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
Expand All @@ -293,7 +408,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
DT.add(CA, CurMF, SpvVecConst);
MachineInstrBuilder MIB;
MachineBasicBlock &BB = *I.getParent();
if (Val) {
if (!IsNull) {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
Expand All @@ -313,20 +428,42 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
return Res;
}

Register
SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII) {
Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
bool ZeroAsNull) {
const Type *LLVMTy = getTypeForSPIRVType(SpvType);
assert(LLVMTy->isVectorTy());
const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
Type *LLVMBaseTy = LLVMVecTy->getElementType();
const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
auto ConstVec =
ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
assert(LLVMBaseTy->isIntegerTy());
auto *ConstVal = ConstantInt::get(LLVMBaseTy, Val);
auto *ConstVec =
ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
unsigned BW = getScalarOrVectorBitWidth(SpvType);
return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW,
SpvType->getOperand(2).getImm());
return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
SpvType->getOperand(2).getImm(),
ZeroAsNull);
}

Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
bool ZeroAsNull) {
const Type *LLVMTy = getTypeForSPIRVType(SpvType);
assert(LLVMTy->isVectorTy());
const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
Type *LLVMBaseTy = LLVMVecTy->getElementType();
assert(LLVMBaseTy->isFloatingPointTy());
auto *ConstVal = ConstantFP::get(LLVMBaseTy, Val);
auto *ConstVec =
ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
unsigned BW = getScalarOrVectorBitWidth(SpvType);
return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
SpvType->getOperand(2).getImm(),
ZeroAsNull);
}

Register
Expand All @@ -337,13 +474,13 @@ SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
assert(LLVMTy->isArrayTy());
const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
Type *LLVMBaseTy = LLVMArrTy->getElementType();
const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
auto ConstArr =
auto *ConstInt = ConstantInt::get(LLVMBaseTy, Val);
auto *ConstArr =
ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
LLVMArrTy->getNumElements());
return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
LLVMArrTy->getNumElements());
}

Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
Expand Down Expand Up @@ -1093,21 +1230,48 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
return SpirvType;
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
MachineInstr &I,
const SPIRVInstrInfo &TII,
unsigned SPIRVOPcode,
Type *LLVMTy) {
Register Reg = DT.find(LLVMTy, CurMF);
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(0);
DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
return finishCreatingSPIRVType(LLVMTy, MIB);
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
LLVMContext &Ctx = CurMF->getFunction().getContext();
Type *LLVMTy;
switch (BitWidth) {
case 16:
LLVMTy = Type::getHalfTy(Ctx);
break;
case 32:
LLVMTy = Type::getFloatTy(Ctx);
break;
case 64:
LLVMTy = Type::getDoubleTy(Ctx);
break;
default:
llvm_unreachable("Bit width is of unexpected size.");
}
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
}

SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
return getOrCreateSPIRVType(
Expand Down
42 changes: 33 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "SPIRVDuplicatesTracker.h"
#include "SPIRVInstrInfo.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/IR/Constant.h"

namespace llvm {
using SPIRVType = const MachineInstr;
Expand Down Expand Up @@ -234,6 +235,8 @@ class SPIRVGlobalRegistry {
bool EmitIR = true);
SPIRVType *assignIntTypeToVReg(unsigned BitWidth, Register VReg,
MachineInstr &I, const SPIRVInstrInfo &TII);
SPIRVType *assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
MachineInstr &I, const SPIRVInstrInfo &TII);
SPIRVType *assignVectTypeToVReg(SPIRVType *BaseType, unsigned NumElements,
Register VReg, MachineInstr &I,
const SPIRVInstrInfo &TII);
Expand Down Expand Up @@ -367,12 +370,20 @@ class SPIRVGlobalRegistry {
std::tuple<Register, ConstantInt *, bool> getOrCreateConstIntReg(
uint64_t Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
std::tuple<Register, ConstantFP *, bool, unsigned> getOrCreateConstFloatReg(
APFloat Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
SPIRVType *finishCreatingSPIRVType(const Type *LLVMTy, SPIRVType *SpirvType);
Register getOrCreateIntCompositeOrNull(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
Constant *CA, unsigned BitWidth,
unsigned ElemCnt);
Register getOrCreateBaseRegister(Constant *Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
unsigned BitWidth);
Register getOrCreateCompositeOrNull(Constant *Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII, Constant *CA,
unsigned BitWidth, unsigned ElemCnt,
bool ZeroAsNull = true);

Register getOrCreateIntCompositeOrNull(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
Expand All @@ -383,12 +394,20 @@ class SPIRVGlobalRegistry {
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType = nullptr, bool EmitIR = true);
Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII);
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType = nullptr);
Register getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII);

Register getOrCreateConstVector(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
Register getOrCreateConstVector(APFloat Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII);
Expand Down Expand Up @@ -418,6 +437,11 @@ class SPIRVGlobalRegistry {
MachineIRBuilder &MIRBuilder);
SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineInstr &I,
const SPIRVInstrInfo &TII);
SPIRVType *getOrCreateSPIRVType(unsigned BitWidth, MachineInstr &I,
const SPIRVInstrInfo &TII,
unsigned SPIRVOPcode, Type *LLVMTy);
SPIRVType *getOrCreateSPIRVFloatType(unsigned BitWidth, MachineInstr &I,
const SPIRVInstrInfo &TII);
SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder);
SPIRVType *getOrCreateSPIRVBoolType(MachineInstr &I,
const SPIRVInstrInfo &TII);
Expand Down
Loading