Skip to content

[NFC][SPIR-V] Use getScalarOrVectorComponent{Count,Type} instead of raw operand access#193410

Open
aobolensk wants to merge 1 commit intollvm:mainfrom
aobolensk:llvm-spirv-getScalarOrVectorComponent
Open

[NFC][SPIR-V] Use getScalarOrVectorComponent{Count,Type} instead of raw operand access#193410
aobolensk wants to merge 1 commit intollvm:mainfrom
aobolensk:llvm-spirv-getScalarOrVectorComponent

Conversation

@aobolensk
Copy link
Copy Markdown
Contributor

Replace direct accesses to SPIR-V type instruction operands with the existing getScalarOrVectorComponentCount() and getScalarOrVectorComponentType() helpers

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 22, 2026

@llvm/pr-subscribers-backend-spir-v

Author: Arseniy Obolenskiy (aobolensk)

Changes

Replace direct accesses to SPIR-V type instruction operands with the existing getScalarOrVectorComponentCount() and getScalarOrVectorComponentType() helpers


Full diff: https://github.com/llvm/llvm-project/pull/193410.diff

5 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+10-10)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+21-30)
  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+7-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+12-33)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+1-1)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 1e6d3c52990f7..099c88d123223 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -446,7 +446,7 @@ buildBoolRegister(MachineIRBuilder &MIRBuilder, SPIRVTypeInst ResultType,
   SPIRVTypeInst BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);
 
   if (ResultType->getOpcode() == SPIRV::OpTypeVector) {
-    unsigned VectorElements = ResultType->getOperand(2).getImm();
+    unsigned VectorElements = GR->getScalarOrVectorComponentCount(ResultType);
     BoolType = GR->getOrCreateSPIRVVectorType(BoolType, VectorElements,
                                               MIRBuilder, true);
     const FixedVectorType *LLVMVectorType =
@@ -1334,7 +1334,7 @@ static bool generateRelationalInst(const SPIRV::IncomingCall *Call,
   if ((Opcode == SPIRV::OpAny || Opcode == SPIRV::OpAll) &&
       !GR->isScalarOrVectorOfType(Arguments[0], SPIRV::OpTypeBool)) {
     SPIRVTypeInst ArgType = GR->getSPIRVTypeForVReg(Arguments[0]);
-    unsigned NumElts = ArgType->getOperand(2).getImm();
+    unsigned NumElts = GR->getScalarOrVectorComponentCount(ArgType);
     SPIRVTypeInst BoolVecTy = GR->getOrCreateSPIRVVectorType(
         GR->getOrCreateSPIRVBoolType(MIRBuilder, /*EmitIR=*/true), NumElts,
         MIRBuilder, /*EmitIR=*/true);
@@ -1796,8 +1796,8 @@ static bool generateBuiltinVar(const SPIRV::IncomingCall *Call,
   unsigned BitWidth = GR->getScalarOrVectorBitWidth(Call->ReturnType);
   LLT LLType;
   if (Call->ReturnType->getOpcode() == SPIRV::OpTypeVector)
-    LLType =
-        LLT::fixed_vector(Call->ReturnType->getOperand(2).getImm(), BitWidth);
+    LLType = LLT::fixed_vector(
+        GR->getScalarOrVectorComponentCount(Call->ReturnType), BitWidth);
   else
     LLType = LLT::scalar(BitWidth);
 
@@ -2159,9 +2159,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
   // vector, expect only a single size component. Otherwise get the number of
   // expected components.
   unsigned NumExpectedRetComponents =
-      Call->ReturnType->getOpcode() == SPIRV::OpTypeVector
-          ? Call->ReturnType->getOperand(2).getImm()
-          : 1;
+      GR->getScalarOrVectorComponentCount(Call->ReturnType);
   // Get the actual number of query result/size components.
   SPIRVTypeInst ImgType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
   unsigned NumActualRetComponents = getNumSizeComponents(ImgType);
@@ -2201,10 +2199,12 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
     Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
     SPIRVTypeInst NewType = nullptr;
     if (QueryResultType->getOpcode() == SPIRV::OpTypeVector) {
-      Register NewTypeReg = QueryResultType->getOperand(1).getReg();
-      if (TypeReg != NewTypeReg &&
-          (NewType = GR->getSPIRVTypeForVReg(NewTypeReg)))
+      NewType = GR->getScalarOrVectorComponentType(QueryResultType);
+      Register NewTypeReg = GR->getSPIRVTypeID(NewType);
+      if (TypeReg != NewTypeReg)
         TypeReg = NewTypeReg;
+      else
+        NewType = nullptr;
     }
     MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
         .addDef(Call->ReturnRegister)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index d5827e30ce17d..f991512143e71 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -645,7 +645,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(const APInt &Val,
       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
   unsigned BW = getScalarOrVectorBitWidth(SpvType);
   return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
-                                    SpvType->getOperand(2).getImm(),
+                                    getScalarOrVectorComponentCount(SpvType),
                                     ZeroAsNull);
 }
 
@@ -664,7 +664,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
   unsigned BW = getScalarOrVectorBitWidth(SpvType);
   return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
-                                    SpvType->getOperand(2).getImm(),
+                                    getScalarOrVectorComponentCount(SpvType),
                                     ZeroAsNull);
 }
 
@@ -745,9 +745,9 @@ Register SPIRVGlobalRegistry::getOrCreateConsIntVector(
   auto ConstVec =
       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
   unsigned BW = getScalarOrVectorBitWidth(SpvType);
-  return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
-                                       ConstVec, BW,
-                                       SpvType->getOperand(2).getImm());
+  return getOrCreateIntCompositeOrNull(
+      Val, MIRBuilder, SpvType, EmitIR, ConstVec, BW,
+      getScalarOrVectorComponentCount(SpvType));
 }
 
 Register
@@ -1448,14 +1448,11 @@ SPIRVGlobalRegistry::getScalarOrVectorComponentType(SPIRVTypeInst Type) const {
 unsigned
 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(SPIRVTypeInst Type) const {
   assert(Type && "Invalid Type pointer");
-  if (Type->getOpcode() == SPIRV::OpTypeVector) {
-    auto EleTypeReg = Type->getOperand(1).getReg();
-    Type = getSPIRVTypeForVReg(EleTypeReg);
-  }
-  if (Type->getOpcode() == SPIRV::OpTypeInt ||
-      Type->getOpcode() == SPIRV::OpTypeFloat)
-    return Type->getOperand(1).getImm();
-  if (Type->getOpcode() == SPIRV::OpTypeBool)
+  SPIRVTypeInst ScalarType = getScalarOrVectorComponentType(Type);
+  if (ScalarType->getOpcode() == SPIRV::OpTypeInt ||
+      ScalarType->getOpcode() == SPIRV::OpTypeFloat)
+    return ScalarType->getOperand(1).getImm();
+  if (ScalarType->getOpcode() == SPIRV::OpTypeBool)
     return 1;
   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
 }
@@ -1463,22 +1460,19 @@ SPIRVGlobalRegistry::getScalarOrVectorBitWidth(SPIRVTypeInst Type) const {
 unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
     SPIRVTypeInst Type) const {
   assert(Type && "Invalid Type pointer");
-  unsigned NumElements = 1;
-  if (Type->getOpcode() == SPIRV::OpTypeVector) {
-    NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
-    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
-  }
-  return Type->getOpcode() == SPIRV::OpTypeInt ||
-                 Type->getOpcode() == SPIRV::OpTypeFloat
-             ? NumElements * Type->getOperand(1).getImm()
+  unsigned NumElements = getScalarOrVectorComponentCount(Type);
+  SPIRVTypeInst ScalarType = getScalarOrVectorComponentType(Type);
+  return ScalarType->getOpcode() == SPIRV::OpTypeInt ||
+                 ScalarType->getOpcode() == SPIRV::OpTypeFloat
+             ? NumElements * ScalarType->getOperand(1).getImm()
              : 0;
 }
 
 SPIRVTypeInst
 SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(SPIRVTypeInst Type) const {
-  if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
-    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
-  return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
+  SPIRVTypeInst ScalarType = getScalarOrVectorComponentType(Type);
+  return ScalarType && ScalarType->getOpcode() == SPIRV::OpTypeInt ? ScalarType
+                                                                   : nullptr;
 }
 
 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(SPIRVTypeInst Type) const {
@@ -2099,8 +2093,7 @@ SPIRVGlobalRegistry::getRegClass(SPIRVTypeInst SpvType) const {
   case SPIRV::OpTypePointer:
     return &SPIRV::pIDRegClass;
   case SPIRV::OpTypeVector: {
-    SPIRVTypeInst ElemType =
-        getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
+    SPIRVTypeInst ElemType = getScalarOrVectorComponentType(SpvType);
     unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
     if (ElemOpcode == SPIRV::OpTypeFloat)
       return &SPIRV::vfIDRegClass;
@@ -2128,8 +2121,7 @@ LLT SPIRVGlobalRegistry::getRegType(SPIRVTypeInst SpvType) const {
   case SPIRV::OpTypePointer:
     return LLT::pointer(getAS(SpvType), getPointerSize());
   case SPIRV::OpTypeVector: {
-    SPIRVTypeInst ElemType =
-        getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
+    SPIRVTypeInst ElemType = getScalarOrVectorComponentType(SpvType);
     LLT ET;
     switch (ElemType ? ElemType->getOpcode() : 0) {
     case SPIRV::OpTypePointer:
@@ -2143,8 +2135,7 @@ LLT SPIRVGlobalRegistry::getRegType(SPIRVTypeInst SpvType) const {
     default:
       ET = LLT::scalar(64);
     }
-    return LLT::fixed_vector(
-        static_cast<unsigned>(SpvType->getOperand(2).getImm()), ET);
+    return LLT::fixed_vector(getScalarOrVectorComponentCount(SpvType), ET);
   }
   }
   return LLT::scalar(64);
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 2fcb71939322a..6267596e44596 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -546,12 +546,13 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
           SPIRVTypeInst Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
           SPIRVTypeInst RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
           assert(RetType && "Expected return type");
-          validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
-                           RetType->getOpcode() != SPIRV::OpTypeVector
-                               ? Int32Type
-                               : GR.getOrCreateSPIRVVectorType(
-                                     Int32Type, RetType->getOperand(2).getImm(),
-                                     MIB, false));
+          validatePtrTypes(
+              STI, MRI, GR, MI, MI.getNumOperands() - 1,
+              RetType->getOpcode() != SPIRV::OpTypeVector
+                  ? Int32Type
+                  : GR.getOrCreateSPIRVVectorType(
+                        Int32Type, GR.getScalarOrVectorComponentCount(RetType),
+                        MIB, false));
         } break;
         case SPIRV::OpenCLExtInst::fract:
         case SPIRV::OpenCLExtInst::modf:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f1e0450bb20f9..881bc8f9f2c0e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2290,8 +2290,7 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
     report_fatal_error(
         "cannot select G_UNMERGE_VALUES with a non-vector argument");
 
-  SPIRVTypeInst ScalarType =
-      GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
+  SPIRVTypeInst ScalarType = GR.getScalarOrVectorComponentType(SrcType);
   MachineBasicBlock &BB = *I.getParent();
   unsigned CurrentIndex = 0;
   for (unsigned i = 0; i < I.getNumDefs(); ++i) {
@@ -2804,7 +2803,7 @@ bool SPIRVInstructionSelector::selectAnyOrAll(Register ResVReg,
     NotEqualReg =
         IsBoolTy ? InputRegister
                  : createVirtualRegister(SpvBoolTy, &GR, MRI, MRI->getMF());
-    const unsigned NumElts = InputType->getOperand(2).getImm();
+    const unsigned NumElts = GR.getScalarOrVectorComponentCount(InputType);
     SpvBoolTy = GR.getOrCreateSPIRVVectorType(SpvBoolTy, NumElts, I, TII);
   }
 
@@ -2857,7 +2856,7 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
          "dot product requires a vector of at least 2 components");
 
   [[maybe_unused]] SPIRVTypeInst EltType =
-      GR.getSPIRVTypeForVReg(VecType->getOperand(1).getReg());
+      GR.getScalarOrVectorComponentType(VecType);
 
   assert(EltType->getOpcode() == SPIRV::OpTypeFloat);
 
@@ -3233,15 +3232,6 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
   return true;
 }
 
-unsigned getVectorSizeOrOne(SPIRVTypeInst Type) {
-
-  if (Type->getOpcode() != SPIRV::OpTypeVector)
-    return 1;
-
-  // Operand(2) is the vector size
-  return Type->getOperand(2).getImm();
-}
-
 bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
                                                         SPIRVTypeInst ResType,
                                                         MachineInstr &I) const {
@@ -3253,16 +3243,12 @@ bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
   SPIRVTypeInst InputType = GR.getSPIRVTypeForVReg(InputReg);
 
   // Determine if input is vector
-  unsigned NumElems = getVectorSizeOrOne(InputType);
+  unsigned NumElems = GR.getScalarOrVectorComponentCount(InputType);
   bool IsVector = NumElems > 1;
 
   // Determine element types
-  SPIRVTypeInst ElemInputType = InputType;
-  SPIRVTypeInst ElemBoolType = ResType;
-  if (IsVector) {
-    ElemInputType = GR.getSPIRVTypeForVReg(InputType->getOperand(1).getReg());
-    ElemBoolType = GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg());
-  }
+  SPIRVTypeInst ElemInputType = GR.getScalarOrVectorComponentType(InputType);
+  SPIRVTypeInst ElemBoolType = GR.getScalarOrVectorComponentType(ResType);
 
   // Subgroup scope constant
   SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
@@ -3903,10 +3889,8 @@ bool SPIRVInstructionSelector::selectExp10(Register ResVReg,
 
     MachineIRBuilder MIRBuilder(I);
 
-    SPIRVTypeInst SpirvScalarType = ResType->getOpcode() == SPIRV::OpTypeVector
-                                        ? SPIRVTypeInst(GR.getSPIRVTypeForVReg(
-                                              ResType->getOperand(1).getReg()))
-                                        : ResType;
+    SPIRVTypeInst SpirvScalarType =
+        GR.getScalarOrVectorComponentType(ResType);
 
     assert(SpirvScalarType->getOperand(1).getImm() == 32 &&
            "only float operands supported by GLSL extended math");
@@ -4098,7 +4082,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg,
     unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType);
     SPIRVTypeInst TmpType = GR.getOrCreateSPIRVIntegerType(BitWidth, I, TII);
     if (ResType->getOpcode() == SPIRV::OpTypeVector) {
-      const unsigned NumElts = ResType->getOperand(2).getImm();
+      const unsigned NumElts = GR.getScalarOrVectorComponentCount(ResType);
       TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII);
     }
     SrcReg = createVirtualRegister(TmpType, &GR, MRI, MRI->getMF());
@@ -6477,10 +6461,7 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
   assert(ResType->getOpcode() == SPIRV::OpTypeVector ||
          ResType->getOpcode() == SPIRV::OpTypeFloat);
   // TODO: Add matrix implementation once supported by the HLSL frontend.
-  SPIRVTypeInst SpirvScalarType = ResType->getOpcode() == SPIRV::OpTypeVector
-                                      ? SPIRVTypeInst(GR.getSPIRVTypeForVReg(
-                                            ResType->getOperand(1).getReg()))
-                                      : ResType;
+  SPIRVTypeInst SpirvScalarType = GR.getScalarOrVectorComponentType(ResType);
   Register ScaleReg =
       GR.buildConstantFP(APFloat(0.30103f), MIRBuilder, SpirvScalarType);
 
@@ -6692,12 +6673,10 @@ SPIRVTypeInst SPIRVInstructionSelector::widenTypeToVec4(SPIRVTypeInst Type,
   if (Type->getOpcode() != SPIRV::OpTypeVector)
     return GR.getOrCreateSPIRVVectorType(Type, 4, MIRBuilder, false);
 
-  uint64_t VectorSize = Type->getOperand(2).getImm();
-  if (VectorSize == 4)
+  if (GR.getScalarOrVectorComponentCount(Type) == 4)
     return Type;
 
-  Register ScalarTypeReg = Type->getOperand(1).getReg();
-  const SPIRVTypeInst ScalarType = GR.getSPIRVTypeForVReg(ScalarTypeReg);
+  SPIRVTypeInst ScalarType = GR.getScalarOrVectorComponentType(Type);
   return GR.getOrCreateSPIRVVectorType(ScalarType, 4, MIRBuilder, false);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d9e134552a269..8fe59ba4f8e9c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -317,7 +317,7 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
   SPIRVTypeInst ScalarType = nullptr;
   if (SPIRVTypeInst DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
     assert(DefType->getOpcode() == SPIRV::OpTypeVector);
-    ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+    ScalarType = GR->getScalarOrVectorComponentType(DefType);
   }
 
   if (!ScalarType) {

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Apr 22, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

…aw operand access

Replace direct accesses to SPIR-V type instruction operands with the existing getScalarOrVectorComponentCount() and getScalarOrVectorComponentType() helpers
@aobolensk aobolensk force-pushed the llvm-spirv-getScalarOrVectorComponent branch from 7a95fb0 to 9cc7fa6 Compare April 22, 2026 05:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants