diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 1e6d3c52990f7..099c88d123223 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -446,7 +446,7 @@ buildBoolRegister(MachineIRBuilder &MIRBuilder, SPIRVTypeInst ResultType, SPIRVTypeInst BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true); if (ResultType->getOpcode() == SPIRV::OpTypeVector) { - unsigned VectorElements = ResultType->getOperand(2).getImm(); + unsigned VectorElements = GR->getScalarOrVectorComponentCount(ResultType); BoolType = GR->getOrCreateSPIRVVectorType(BoolType, VectorElements, MIRBuilder, true); const FixedVectorType *LLVMVectorType = @@ -1334,7 +1334,7 @@ static bool generateRelationalInst(const SPIRV::IncomingCall *Call, if ((Opcode == SPIRV::OpAny || Opcode == SPIRV::OpAll) && !GR->isScalarOrVectorOfType(Arguments[0], SPIRV::OpTypeBool)) { SPIRVTypeInst ArgType = GR->getSPIRVTypeForVReg(Arguments[0]); - unsigned NumElts = ArgType->getOperand(2).getImm(); + unsigned NumElts = GR->getScalarOrVectorComponentCount(ArgType); SPIRVTypeInst BoolVecTy = GR->getOrCreateSPIRVVectorType( GR->getOrCreateSPIRVBoolType(MIRBuilder, /*EmitIR=*/true), NumElts, MIRBuilder, /*EmitIR=*/true); @@ -1796,8 +1796,8 @@ static bool generateBuiltinVar(const SPIRV::IncomingCall *Call, unsigned BitWidth = GR->getScalarOrVectorBitWidth(Call->ReturnType); LLT LLType; if (Call->ReturnType->getOpcode() == SPIRV::OpTypeVector) - LLType = - LLT::fixed_vector(Call->ReturnType->getOperand(2).getImm(), BitWidth); + LLType = LLT::fixed_vector( + GR->getScalarOrVectorComponentCount(Call->ReturnType), BitWidth); else LLType = LLT::scalar(BitWidth); @@ -2159,9 +2159,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call, // vector, expect only a single size component. Otherwise get the number of // expected components. unsigned NumExpectedRetComponents = - Call->ReturnType->getOpcode() == SPIRV::OpTypeVector - ? Call->ReturnType->getOperand(2).getImm() - : 1; + GR->getScalarOrVectorComponentCount(Call->ReturnType); // Get the actual number of query result/size components. SPIRVTypeInst ImgType = GR->getSPIRVTypeForVReg(Call->Arguments[0]); unsigned NumActualRetComponents = getNumSizeComponents(ImgType); @@ -2201,10 +2199,12 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call, Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType); SPIRVTypeInst NewType = nullptr; if (QueryResultType->getOpcode() == SPIRV::OpTypeVector) { - Register NewTypeReg = QueryResultType->getOperand(1).getReg(); - if (TypeReg != NewTypeReg && - (NewType = GR->getSPIRVTypeForVReg(NewTypeReg))) + NewType = GR->getScalarOrVectorComponentType(QueryResultType); + Register NewTypeReg = GR->getSPIRVTypeID(NewType); + if (TypeReg != NewTypeReg) TypeReg = NewTypeReg; + else + NewType = nullptr; } MIRBuilder.buildInstr(SPIRV::OpCompositeExtract) .addDef(Call->ReturnRegister) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index d5827e30ce17d..f991512143e71 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -645,7 +645,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(const APInt &Val, ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal); unsigned BW = getScalarOrVectorBitWidth(SpvType); return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW, - SpvType->getOperand(2).getImm(), + getScalarOrVectorComponentCount(SpvType), ZeroAsNull); } @@ -664,7 +664,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val, ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal); unsigned BW = getScalarOrVectorBitWidth(SpvType); return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW, - SpvType->getOperand(2).getImm(), + getScalarOrVectorComponentCount(SpvType), ZeroAsNull); } @@ -745,9 +745,9 @@ Register SPIRVGlobalRegistry::getOrCreateConsIntVector( auto ConstVec = ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); unsigned BW = getScalarOrVectorBitWidth(SpvType); - return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, - ConstVec, BW, - SpvType->getOperand(2).getImm()); + return getOrCreateIntCompositeOrNull( + Val, MIRBuilder, SpvType, EmitIR, ConstVec, BW, + getScalarOrVectorComponentCount(SpvType)); } Register @@ -1448,14 +1448,11 @@ SPIRVGlobalRegistry::getScalarOrVectorComponentType(SPIRVTypeInst Type) const { unsigned SPIRVGlobalRegistry::getScalarOrVectorBitWidth(SPIRVTypeInst Type) const { assert(Type && "Invalid Type pointer"); - if (Type->getOpcode() == SPIRV::OpTypeVector) { - auto EleTypeReg = Type->getOperand(1).getReg(); - Type = getSPIRVTypeForVReg(EleTypeReg); - } - if (Type->getOpcode() == SPIRV::OpTypeInt || - Type->getOpcode() == SPIRV::OpTypeFloat) - return Type->getOperand(1).getImm(); - if (Type->getOpcode() == SPIRV::OpTypeBool) + SPIRVTypeInst ScalarType = getScalarOrVectorComponentType(Type); + if (ScalarType->getOpcode() == SPIRV::OpTypeInt || + ScalarType->getOpcode() == SPIRV::OpTypeFloat) + return ScalarType->getOperand(1).getImm(); + if (ScalarType->getOpcode() == SPIRV::OpTypeBool) return 1; llvm_unreachable("Attempting to get bit width of non-integer/float type."); } @@ -1463,22 +1460,19 @@ SPIRVGlobalRegistry::getScalarOrVectorBitWidth(SPIRVTypeInst Type) const { unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth( SPIRVTypeInst Type) const { assert(Type && "Invalid Type pointer"); - unsigned NumElements = 1; - if (Type->getOpcode() == SPIRV::OpTypeVector) { - NumElements = static_cast(Type->getOperand(2).getImm()); - Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg()); - } - return Type->getOpcode() == SPIRV::OpTypeInt || - Type->getOpcode() == SPIRV::OpTypeFloat - ? NumElements * Type->getOperand(1).getImm() + unsigned NumElements = getScalarOrVectorComponentCount(Type); + SPIRVTypeInst ScalarType = getScalarOrVectorComponentType(Type); + return ScalarType->getOpcode() == SPIRV::OpTypeInt || + ScalarType->getOpcode() == SPIRV::OpTypeFloat + ? NumElements * ScalarType->getOperand(1).getImm() : 0; } SPIRVTypeInst SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(SPIRVTypeInst Type) const { - if (Type && Type->getOpcode() == SPIRV::OpTypeVector) - Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg()); - return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr; + SPIRVTypeInst ScalarType = getScalarOrVectorComponentType(Type); + return ScalarType && ScalarType->getOpcode() == SPIRV::OpTypeInt ? ScalarType + : nullptr; } bool SPIRVGlobalRegistry::isScalarOrVectorSigned(SPIRVTypeInst Type) const { @@ -2099,8 +2093,7 @@ SPIRVGlobalRegistry::getRegClass(SPIRVTypeInst SpvType) const { case SPIRV::OpTypePointer: return &SPIRV::pIDRegClass; case SPIRV::OpTypeVector: { - SPIRVTypeInst ElemType = - getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); + SPIRVTypeInst ElemType = getScalarOrVectorComponentType(SpvType); unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0; if (ElemOpcode == SPIRV::OpTypeFloat) return &SPIRV::vfIDRegClass; @@ -2128,8 +2121,7 @@ LLT SPIRVGlobalRegistry::getRegType(SPIRVTypeInst SpvType) const { case SPIRV::OpTypePointer: return LLT::pointer(getAS(SpvType), getPointerSize()); case SPIRV::OpTypeVector: { - SPIRVTypeInst ElemType = - getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); + SPIRVTypeInst ElemType = getScalarOrVectorComponentType(SpvType); LLT ET; switch (ElemType ? ElemType->getOpcode() : 0) { case SPIRV::OpTypePointer: @@ -2143,8 +2135,7 @@ LLT SPIRVGlobalRegistry::getRegType(SPIRVTypeInst SpvType) const { default: ET = LLT::scalar(64); } - return LLT::fixed_vector( - static_cast(SpvType->getOperand(2).getImm()), ET); + return LLT::fixed_vector(getScalarOrVectorComponentCount(SpvType), ET); } } return LLT::scalar(64); diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index 2fcb71939322a..6267596e44596 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -546,12 +546,13 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { SPIRVTypeInst Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB); SPIRVTypeInst RetType = MRI->getVRegDef(MI.getOperand(1).getReg()); assert(RetType && "Expected return type"); - validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1, - RetType->getOpcode() != SPIRV::OpTypeVector - ? Int32Type - : GR.getOrCreateSPIRVVectorType( - Int32Type, RetType->getOperand(2).getImm(), - MIB, false)); + validatePtrTypes( + STI, MRI, GR, MI, MI.getNumOperands() - 1, + RetType->getOpcode() != SPIRV::OpTypeVector + ? Int32Type + : GR.getOrCreateSPIRVVectorType( + Int32Type, GR.getScalarOrVectorComponentCount(RetType), + MIB, false)); } break; case SPIRV::OpenCLExtInst::fract: case SPIRV::OpenCLExtInst::modf: diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index f1e0450bb20f9..63127d43aecea 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2290,8 +2290,7 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const { report_fatal_error( "cannot select G_UNMERGE_VALUES with a non-vector argument"); - SPIRVTypeInst ScalarType = - GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg()); + SPIRVTypeInst ScalarType = GR.getScalarOrVectorComponentType(SrcType); MachineBasicBlock &BB = *I.getParent(); unsigned CurrentIndex = 0; for (unsigned i = 0; i < I.getNumDefs(); ++i) { @@ -2804,7 +2803,7 @@ bool SPIRVInstructionSelector::selectAnyOrAll(Register ResVReg, NotEqualReg = IsBoolTy ? InputRegister : createVirtualRegister(SpvBoolTy, &GR, MRI, MRI->getMF()); - const unsigned NumElts = InputType->getOperand(2).getImm(); + const unsigned NumElts = GR.getScalarOrVectorComponentCount(InputType); SpvBoolTy = GR.getOrCreateSPIRVVectorType(SpvBoolTy, NumElts, I, TII); } @@ -2857,7 +2856,7 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg, "dot product requires a vector of at least 2 components"); [[maybe_unused]] SPIRVTypeInst EltType = - GR.getSPIRVTypeForVReg(VecType->getOperand(1).getReg()); + GR.getScalarOrVectorComponentType(VecType); assert(EltType->getOpcode() == SPIRV::OpTypeFloat); @@ -3233,15 +3232,6 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( return true; } -unsigned getVectorSizeOrOne(SPIRVTypeInst Type) { - - if (Type->getOpcode() != SPIRV::OpTypeVector) - return 1; - - // Operand(2) is the vector size - return Type->getOperand(2).getImm(); -} - bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg, SPIRVTypeInst ResType, MachineInstr &I) const { @@ -3253,16 +3243,12 @@ bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg, SPIRVTypeInst InputType = GR.getSPIRVTypeForVReg(InputReg); // Determine if input is vector - unsigned NumElems = getVectorSizeOrOne(InputType); + unsigned NumElems = GR.getScalarOrVectorComponentCount(InputType); bool IsVector = NumElems > 1; // Determine element types - SPIRVTypeInst ElemInputType = InputType; - SPIRVTypeInst ElemBoolType = ResType; - if (IsVector) { - ElemInputType = GR.getSPIRVTypeForVReg(InputType->getOperand(1).getReg()); - ElemBoolType = GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg()); - } + SPIRVTypeInst ElemInputType = GR.getScalarOrVectorComponentType(InputType); + SPIRVTypeInst ElemBoolType = GR.getScalarOrVectorComponentType(ResType); // Subgroup scope constant SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); @@ -3903,10 +3889,7 @@ bool SPIRVInstructionSelector::selectExp10(Register ResVReg, MachineIRBuilder MIRBuilder(I); - SPIRVTypeInst SpirvScalarType = ResType->getOpcode() == SPIRV::OpTypeVector - ? SPIRVTypeInst(GR.getSPIRVTypeForVReg( - ResType->getOperand(1).getReg())) - : ResType; + SPIRVTypeInst SpirvScalarType = GR.getScalarOrVectorComponentType(ResType); assert(SpirvScalarType->getOperand(1).getImm() == 32 && "only float operands supported by GLSL extended math"); @@ -4098,7 +4081,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg, unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType); SPIRVTypeInst TmpType = GR.getOrCreateSPIRVIntegerType(BitWidth, I, TII); if (ResType->getOpcode() == SPIRV::OpTypeVector) { - const unsigned NumElts = ResType->getOperand(2).getImm(); + const unsigned NumElts = GR.getScalarOrVectorComponentCount(ResType); TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII); } SrcReg = createVirtualRegister(TmpType, &GR, MRI, MRI->getMF()); @@ -6477,10 +6460,7 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg, assert(ResType->getOpcode() == SPIRV::OpTypeVector || ResType->getOpcode() == SPIRV::OpTypeFloat); // TODO: Add matrix implementation once supported by the HLSL frontend. - SPIRVTypeInst SpirvScalarType = ResType->getOpcode() == SPIRV::OpTypeVector - ? SPIRVTypeInst(GR.getSPIRVTypeForVReg( - ResType->getOperand(1).getReg())) - : ResType; + SPIRVTypeInst SpirvScalarType = GR.getScalarOrVectorComponentType(ResType); Register ScaleReg = GR.buildConstantFP(APFloat(0.30103f), MIRBuilder, SpirvScalarType); @@ -6692,12 +6672,10 @@ SPIRVTypeInst SPIRVInstructionSelector::widenTypeToVec4(SPIRVTypeInst Type, if (Type->getOpcode() != SPIRV::OpTypeVector) return GR.getOrCreateSPIRVVectorType(Type, 4, MIRBuilder, false); - uint64_t VectorSize = Type->getOperand(2).getImm(); - if (VectorSize == 4) + if (GR.getScalarOrVectorComponentCount(Type) == 4) return Type; - Register ScalarTypeReg = Type->getOperand(1).getReg(); - const SPIRVTypeInst ScalarType = GR.getSPIRVTypeForVReg(ScalarTypeReg); + SPIRVTypeInst ScalarType = GR.getScalarOrVectorComponentType(Type); return GR.getOrCreateSPIRVVectorType(ScalarType, 4, MIRBuilder, false); } diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp index d9e134552a269..8fe59ba4f8e9c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp @@ -317,7 +317,7 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF, SPIRVTypeInst ScalarType = nullptr; if (SPIRVTypeInst DefType = GR->getSPIRVTypeForVReg(SrcReg)) { assert(DefType->getOpcode() == SPIRV::OpTypeVector); - ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); + ScalarType = GR->getScalarOrVectorComponentType(DefType); } if (!ScalarType) {