[NFC][SPIR-V] Use getScalarOrVectorComponent{Count,Type} instead of raw operand access#193410
Open
[NFC][SPIR-V] Use getScalarOrVectorComponent{Count,Type} instead of raw operand access#193410
Conversation
Member
|
@llvm/pr-subscribers-backend-spir-v Author: Arseniy Obolenskiy (aobolensk) ChangesReplace direct accesses to SPIR-V type instruction operands with the existing getScalarOrVectorComponentCount() and getScalarOrVectorComponentType() helpers Full diff: https://github.com/llvm/llvm-project/pull/193410.diff 5 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 1e6d3c52990f7..099c88d123223 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -446,7 +446,7 @@ buildBoolRegister(MachineIRBuilder &MIRBuilder, SPIRVTypeInst ResultType,
SPIRVTypeInst BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);
if (ResultType->getOpcode() == SPIRV::OpTypeVector) {
- unsigned VectorElements = ResultType->getOperand(2).getImm();
+ unsigned VectorElements = GR->getScalarOrVectorComponentCount(ResultType);
BoolType = GR->getOrCreateSPIRVVectorType(BoolType, VectorElements,
MIRBuilder, true);
const FixedVectorType *LLVMVectorType =
@@ -1334,7 +1334,7 @@ static bool generateRelationalInst(const SPIRV::IncomingCall *Call,
if ((Opcode == SPIRV::OpAny || Opcode == SPIRV::OpAll) &&
!GR->isScalarOrVectorOfType(Arguments[0], SPIRV::OpTypeBool)) {
SPIRVTypeInst ArgType = GR->getSPIRVTypeForVReg(Arguments[0]);
- unsigned NumElts = ArgType->getOperand(2).getImm();
+ unsigned NumElts = GR->getScalarOrVectorComponentCount(ArgType);
SPIRVTypeInst BoolVecTy = GR->getOrCreateSPIRVVectorType(
GR->getOrCreateSPIRVBoolType(MIRBuilder, /*EmitIR=*/true), NumElts,
MIRBuilder, /*EmitIR=*/true);
@@ -1796,8 +1796,8 @@ static bool generateBuiltinVar(const SPIRV::IncomingCall *Call,
unsigned BitWidth = GR->getScalarOrVectorBitWidth(Call->ReturnType);
LLT LLType;
if (Call->ReturnType->getOpcode() == SPIRV::OpTypeVector)
- LLType =
- LLT::fixed_vector(Call->ReturnType->getOperand(2).getImm(), BitWidth);
+ LLType = LLT::fixed_vector(
+ GR->getScalarOrVectorComponentCount(Call->ReturnType), BitWidth);
else
LLType = LLT::scalar(BitWidth);
@@ -2159,9 +2159,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
// vector, expect only a single size component. Otherwise get the number of
// expected components.
unsigned NumExpectedRetComponents =
- Call->ReturnType->getOpcode() == SPIRV::OpTypeVector
- ? Call->ReturnType->getOperand(2).getImm()
- : 1;
+ GR->getScalarOrVectorComponentCount(Call->ReturnType);
// Get the actual number of query result/size components.
SPIRVTypeInst ImgType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
unsigned NumActualRetComponents = getNumSizeComponents(ImgType);
@@ -2201,10 +2199,12 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
SPIRVTypeInst NewType = nullptr;
if (QueryResultType->getOpcode() == SPIRV::OpTypeVector) {
- Register NewTypeReg = QueryResultType->getOperand(1).getReg();
- if (TypeReg != NewTypeReg &&
- (NewType = GR->getSPIRVTypeForVReg(NewTypeReg)))
+ NewType = GR->getScalarOrVectorComponentType(QueryResultType);
+ Register NewTypeReg = GR->getSPIRVTypeID(NewType);
+ if (TypeReg != NewTypeReg)
TypeReg = NewTypeReg;
+ else
+ NewType = nullptr;
}
MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
.addDef(Call->ReturnRegister)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index d5827e30ce17d..f991512143e71 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -645,7 +645,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(const APInt &Val,
ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
unsigned BW = getScalarOrVectorBitWidth(SpvType);
return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
- SpvType->getOperand(2).getImm(),
+ getScalarOrVectorComponentCount(SpvType),
ZeroAsNull);
}
@@ -664,7 +664,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
unsigned BW = getScalarOrVectorBitWidth(SpvType);
return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
- SpvType->getOperand(2).getImm(),
+ getScalarOrVectorComponentCount(SpvType),
ZeroAsNull);
}
@@ -745,9 +745,9 @@ Register SPIRVGlobalRegistry::getOrCreateConsIntVector(
auto ConstVec =
ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
unsigned BW = getScalarOrVectorBitWidth(SpvType);
- return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
- ConstVec, BW,
- SpvType->getOperand(2).getImm());
+ return getOrCreateIntCompositeOrNull(
+ Val, MIRBuilder, SpvType, EmitIR, ConstVec, BW,
+ getScalarOrVectorComponentCount(SpvType));
}
Register
@@ -1448,14 +1448,11 @@ SPIRVGlobalRegistry::getScalarOrVectorComponentType(SPIRVTypeInst Type) const {
unsigned
SPIRVGlobalRegistry::getScalarOrVectorBitWidth(SPIRVTypeInst Type) const {
assert(Type && "Invalid Type pointer");
- if (Type->getOpcode() == SPIRV::OpTypeVector) {
- auto EleTypeReg = Type->getOperand(1).getReg();
- Type = getSPIRVTypeForVReg(EleTypeReg);
- }
- if (Type->getOpcode() == SPIRV::OpTypeInt ||
- Type->getOpcode() == SPIRV::OpTypeFloat)
- return Type->getOperand(1).getImm();
- if (Type->getOpcode() == SPIRV::OpTypeBool)
+ SPIRVTypeInst ScalarType = getScalarOrVectorComponentType(Type);
+ if (ScalarType->getOpcode() == SPIRV::OpTypeInt ||
+ ScalarType->getOpcode() == SPIRV::OpTypeFloat)
+ return ScalarType->getOperand(1).getImm();
+ if (ScalarType->getOpcode() == SPIRV::OpTypeBool)
return 1;
llvm_unreachable("Attempting to get bit width of non-integer/float type.");
}
@@ -1463,22 +1460,19 @@ SPIRVGlobalRegistry::getScalarOrVectorBitWidth(SPIRVTypeInst Type) const {
unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
SPIRVTypeInst Type) const {
assert(Type && "Invalid Type pointer");
- unsigned NumElements = 1;
- if (Type->getOpcode() == SPIRV::OpTypeVector) {
- NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
- Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
- }
- return Type->getOpcode() == SPIRV::OpTypeInt ||
- Type->getOpcode() == SPIRV::OpTypeFloat
- ? NumElements * Type->getOperand(1).getImm()
+ unsigned NumElements = getScalarOrVectorComponentCount(Type);
+ SPIRVTypeInst ScalarType = getScalarOrVectorComponentType(Type);
+ return ScalarType->getOpcode() == SPIRV::OpTypeInt ||
+ ScalarType->getOpcode() == SPIRV::OpTypeFloat
+ ? NumElements * ScalarType->getOperand(1).getImm()
: 0;
}
SPIRVTypeInst
SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(SPIRVTypeInst Type) const {
- if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
- Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
- return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
+ SPIRVTypeInst ScalarType = getScalarOrVectorComponentType(Type);
+ return ScalarType && ScalarType->getOpcode() == SPIRV::OpTypeInt ? ScalarType
+ : nullptr;
}
bool SPIRVGlobalRegistry::isScalarOrVectorSigned(SPIRVTypeInst Type) const {
@@ -2099,8 +2093,7 @@ SPIRVGlobalRegistry::getRegClass(SPIRVTypeInst SpvType) const {
case SPIRV::OpTypePointer:
return &SPIRV::pIDRegClass;
case SPIRV::OpTypeVector: {
- SPIRVTypeInst ElemType =
- getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
+ SPIRVTypeInst ElemType = getScalarOrVectorComponentType(SpvType);
unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
if (ElemOpcode == SPIRV::OpTypeFloat)
return &SPIRV::vfIDRegClass;
@@ -2128,8 +2121,7 @@ LLT SPIRVGlobalRegistry::getRegType(SPIRVTypeInst SpvType) const {
case SPIRV::OpTypePointer:
return LLT::pointer(getAS(SpvType), getPointerSize());
case SPIRV::OpTypeVector: {
- SPIRVTypeInst ElemType =
- getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
+ SPIRVTypeInst ElemType = getScalarOrVectorComponentType(SpvType);
LLT ET;
switch (ElemType ? ElemType->getOpcode() : 0) {
case SPIRV::OpTypePointer:
@@ -2143,8 +2135,7 @@ LLT SPIRVGlobalRegistry::getRegType(SPIRVTypeInst SpvType) const {
default:
ET = LLT::scalar(64);
}
- return LLT::fixed_vector(
- static_cast<unsigned>(SpvType->getOperand(2).getImm()), ET);
+ return LLT::fixed_vector(getScalarOrVectorComponentCount(SpvType), ET);
}
}
return LLT::scalar(64);
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 2fcb71939322a..6267596e44596 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -546,12 +546,13 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
SPIRVTypeInst Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
SPIRVTypeInst RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
assert(RetType && "Expected return type");
- validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
- RetType->getOpcode() != SPIRV::OpTypeVector
- ? Int32Type
- : GR.getOrCreateSPIRVVectorType(
- Int32Type, RetType->getOperand(2).getImm(),
- MIB, false));
+ validatePtrTypes(
+ STI, MRI, GR, MI, MI.getNumOperands() - 1,
+ RetType->getOpcode() != SPIRV::OpTypeVector
+ ? Int32Type
+ : GR.getOrCreateSPIRVVectorType(
+ Int32Type, GR.getScalarOrVectorComponentCount(RetType),
+ MIB, false));
} break;
case SPIRV::OpenCLExtInst::fract:
case SPIRV::OpenCLExtInst::modf:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f1e0450bb20f9..881bc8f9f2c0e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2290,8 +2290,7 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
report_fatal_error(
"cannot select G_UNMERGE_VALUES with a non-vector argument");
- SPIRVTypeInst ScalarType =
- GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
+ SPIRVTypeInst ScalarType = GR.getScalarOrVectorComponentType(SrcType);
MachineBasicBlock &BB = *I.getParent();
unsigned CurrentIndex = 0;
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
@@ -2804,7 +2803,7 @@ bool SPIRVInstructionSelector::selectAnyOrAll(Register ResVReg,
NotEqualReg =
IsBoolTy ? InputRegister
: createVirtualRegister(SpvBoolTy, &GR, MRI, MRI->getMF());
- const unsigned NumElts = InputType->getOperand(2).getImm();
+ const unsigned NumElts = GR.getScalarOrVectorComponentCount(InputType);
SpvBoolTy = GR.getOrCreateSPIRVVectorType(SpvBoolTy, NumElts, I, TII);
}
@@ -2857,7 +2856,7 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
"dot product requires a vector of at least 2 components");
[[maybe_unused]] SPIRVTypeInst EltType =
- GR.getSPIRVTypeForVReg(VecType->getOperand(1).getReg());
+ GR.getScalarOrVectorComponentType(VecType);
assert(EltType->getOpcode() == SPIRV::OpTypeFloat);
@@ -3233,15 +3232,6 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return true;
}
-unsigned getVectorSizeOrOne(SPIRVTypeInst Type) {
-
- if (Type->getOpcode() != SPIRV::OpTypeVector)
- return 1;
-
- // Operand(2) is the vector size
- return Type->getOperand(2).getImm();
-}
-
bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
SPIRVTypeInst ResType,
MachineInstr &I) const {
@@ -3253,16 +3243,12 @@ bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
SPIRVTypeInst InputType = GR.getSPIRVTypeForVReg(InputReg);
// Determine if input is vector
- unsigned NumElems = getVectorSizeOrOne(InputType);
+ unsigned NumElems = GR.getScalarOrVectorComponentCount(InputType);
bool IsVector = NumElems > 1;
// Determine element types
- SPIRVTypeInst ElemInputType = InputType;
- SPIRVTypeInst ElemBoolType = ResType;
- if (IsVector) {
- ElemInputType = GR.getSPIRVTypeForVReg(InputType->getOperand(1).getReg());
- ElemBoolType = GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg());
- }
+ SPIRVTypeInst ElemInputType = GR.getScalarOrVectorComponentType(InputType);
+ SPIRVTypeInst ElemBoolType = GR.getScalarOrVectorComponentType(ResType);
// Subgroup scope constant
SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
@@ -3903,10 +3889,8 @@ bool SPIRVInstructionSelector::selectExp10(Register ResVReg,
MachineIRBuilder MIRBuilder(I);
- SPIRVTypeInst SpirvScalarType = ResType->getOpcode() == SPIRV::OpTypeVector
- ? SPIRVTypeInst(GR.getSPIRVTypeForVReg(
- ResType->getOperand(1).getReg()))
- : ResType;
+ SPIRVTypeInst SpirvScalarType =
+ GR.getScalarOrVectorComponentType(ResType);
assert(SpirvScalarType->getOperand(1).getImm() == 32 &&
"only float operands supported by GLSL extended math");
@@ -4098,7 +4082,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg,
unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType);
SPIRVTypeInst TmpType = GR.getOrCreateSPIRVIntegerType(BitWidth, I, TII);
if (ResType->getOpcode() == SPIRV::OpTypeVector) {
- const unsigned NumElts = ResType->getOperand(2).getImm();
+ const unsigned NumElts = GR.getScalarOrVectorComponentCount(ResType);
TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII);
}
SrcReg = createVirtualRegister(TmpType, &GR, MRI, MRI->getMF());
@@ -6477,10 +6461,7 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
assert(ResType->getOpcode() == SPIRV::OpTypeVector ||
ResType->getOpcode() == SPIRV::OpTypeFloat);
// TODO: Add matrix implementation once supported by the HLSL frontend.
- SPIRVTypeInst SpirvScalarType = ResType->getOpcode() == SPIRV::OpTypeVector
- ? SPIRVTypeInst(GR.getSPIRVTypeForVReg(
- ResType->getOperand(1).getReg()))
- : ResType;
+ SPIRVTypeInst SpirvScalarType = GR.getScalarOrVectorComponentType(ResType);
Register ScaleReg =
GR.buildConstantFP(APFloat(0.30103f), MIRBuilder, SpirvScalarType);
@@ -6692,12 +6673,10 @@ SPIRVTypeInst SPIRVInstructionSelector::widenTypeToVec4(SPIRVTypeInst Type,
if (Type->getOpcode() != SPIRV::OpTypeVector)
return GR.getOrCreateSPIRVVectorType(Type, 4, MIRBuilder, false);
- uint64_t VectorSize = Type->getOperand(2).getImm();
- if (VectorSize == 4)
+ if (GR.getScalarOrVectorComponentCount(Type) == 4)
return Type;
- Register ScalarTypeReg = Type->getOperand(1).getReg();
- const SPIRVTypeInst ScalarType = GR.getSPIRVTypeForVReg(ScalarTypeReg);
+ SPIRVTypeInst ScalarType = GR.getScalarOrVectorComponentType(Type);
return GR.getOrCreateSPIRVVectorType(ScalarType, 4, MIRBuilder, false);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d9e134552a269..8fe59ba4f8e9c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -317,7 +317,7 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
SPIRVTypeInst ScalarType = nullptr;
if (SPIRVTypeInst DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
assert(DefType->getOpcode() == SPIRV::OpTypeVector);
- ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ ScalarType = GR->getScalarOrVectorComponentType(DefType);
}
if (!ScalarType) {
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
…aw operand access Replace direct accesses to SPIR-V type instruction operands with the existing getScalarOrVectorComponentCount() and getScalarOrVectorComponentType() helpers
7a95fb0 to
9cc7fa6
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Replace direct accesses to SPIR-V type instruction operands with the existing getScalarOrVectorComponentCount() and getScalarOrVectorComponentType() helpers