diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 8f6a73d0a2dd8..f639f0adb9c43 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -172,6 +172,14 @@ class VPBuilder { new VPInstruction(Opcode, Operands, *FMFs, DL, Name)); return createInstruction(Opcode, Operands, DL, Name); } + VPInstruction *createNaryOp(unsigned Opcode, + std::initializer_list Operands, + Type *ResultTy, + std::optional FMFs = {}, + DebugLoc DL = {}, const Twine &Name = "") { + return tryInsertInstruction(new VPInstructionWithType( + Opcode, Operands, ResultTy, FMFs.value_or(FastMathFlags()), DL, Name)); + } VPInstruction *createOverflowingOp(unsigned Opcode, std::initializer_list Operands, diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index a28cda9fe62b3..2f766b26222ff 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7934,7 +7934,8 @@ DenseMap LoopVectorizationPlanner::executePlan( BestVPlan, BestVF, TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)); VPlanTransforms::removeDeadRecipes(BestVPlan); - VPlanTransforms::convertToConcreteRecipes(BestVPlan); + VPlanTransforms::convertToConcreteRecipes(BestVPlan, + *Legal->getWidestInductionType()); // Perform the actual loop transformation. VPTransformState State(&TTI, BestVF, LI, DT, ILV.Builder, &ILV, &BestVPlan, diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 54d8f2e7449b0..94b5167c60089 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -884,6 +884,13 @@ class VPInstruction : public VPRecipeWithIRFlags, AnyOf, // Calculates the first active lane index of the vector predicate operand. FirstActiveLane, + + // The opcodes below are used for VPInstructionWithType. + // + /// Scale the first operand (vector step) by the second operand + /// (scalar-step). Casts both operands to the result type if needed. + WideIVStep, + }; private: @@ -1041,11 +1048,19 @@ class VPInstructionWithType : public VPInstruction { VPInstructionWithType(unsigned Opcode, ArrayRef Operands, Type *ResultTy, DebugLoc DL, const Twine &Name = "") : VPInstruction(Opcode, Operands, DL, Name), ResultTy(ResultTy) {} + VPInstructionWithType(unsigned Opcode, + std::initializer_list Operands, + Type *ResultTy, FastMathFlags FMFs, DebugLoc DL = {}, + const Twine &Name = "") + : VPInstruction(Opcode, Operands, FMFs, DL, Name), ResultTy(ResultTy) {} static inline bool classof(const VPRecipeBase *R) { // VPInstructionWithType are VPInstructions with specific opcodes requiring // type information. - return R->isScalarCast(); + if (R->isScalarCast()) + return true; + auto *VPI = dyn_cast(R); + return VPI && VPI->getOpcode() == VPInstruction::WideIVStep; } static inline bool classof(const VPUser *R) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 2cff343d915cf..2cc558f49ccce 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -873,7 +873,8 @@ bool VPInstruction::isFPMathOp() const { return Opcode == Instruction::FAdd || Opcode == Instruction::FMul || Opcode == Instruction::FNeg || Opcode == Instruction::FSub || Opcode == Instruction::FDiv || Opcode == Instruction::FRem || - Opcode == Instruction::FCmp || Opcode == Instruction::Select; + Opcode == Instruction::FCmp || Opcode == Instruction::Select || + Opcode == VPInstruction::WideIVStep; } #endif @@ -928,6 +929,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const { case VPInstruction::LogicalAnd: case VPInstruction::Not: case VPInstruction::PtrAdd: + case VPInstruction::WideIVStep: return false; default: return true; @@ -1097,9 +1099,19 @@ void VPInstructionWithType::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "EMIT "; printAsOperand(O, SlotTracker); - O << " = " << Instruction::getOpcodeName(getOpcode()) << " "; - printOperands(O, SlotTracker); - O << " to " << *ResultTy; + O << " = "; + + switch (getOpcode()) { + case VPInstruction::WideIVStep: + O << "wide-iv-step "; + printOperands(O, SlotTracker); + break; + default: + assert(Instruction::isCast(getOpcode()) && "unhandled opcode"); + O << Instruction::getOpcodeName(getOpcode()) << " "; + printOperands(O, SlotTracker); + O << " to " << *ResultTy; + } } #endif diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index dc9f953f7447b..77fe3c367dd38 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -1019,6 +1019,14 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { TypeInfo.inferScalarType(R.getOperand(1)) == TypeInfo.inferScalarType(R.getVPSingleValue())) return R.getVPSingleValue()->replaceAllUsesWith(R.getOperand(1)); + + if (match(&R, m_VPInstruction(m_VPValue(X), + m_SpecificInt(1)))) { + Type *WideStepTy = TypeInfo.inferScalarType(R.getVPSingleValue()); + if (TypeInfo.inferScalarType(X) != WideStepTy) + X = VPBuilder(&R).createWidenCast(Instruction::Trunc, X, WideStepTy); + R.getVPSingleValue()->replaceAllUsesWith(X); + } } void VPlanTransforms::simplifyRecipes(VPlan &Plan, Type &CanonicalIVTy) { @@ -2367,23 +2375,71 @@ void VPlanTransforms::createInterleaveGroups( } } -void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) { +void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan, + Type &CanonicalIVTy) { + using namespace llvm::VPlanPatternMatch; + + VPTypeAnalysis TypeInfo(&CanonicalIVTy); + SmallVector ToRemove; for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly( vp_depth_first_deep(Plan.getEntry()))) { - for (VPRecipeBase &R : make_early_inc_range(VPBB->phis())) { - if (!isa(&R)) + for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { + if (isa(&R)) { + auto *PhiR = cast(&R); + StringRef Name = + isa(PhiR) ? "index" : "evl.based.iv"; + auto *ScalarR = new VPInstruction( + Instruction::PHI, {PhiR->getStartValue(), PhiR->getBackedgeValue()}, + PhiR->getDebugLoc(), Name); + ScalarR->insertBefore(PhiR); + PhiR->replaceAllUsesWith(ScalarR); + ToRemove.push_back(PhiR); continue; - auto *PhiR = cast(&R); - StringRef Name = - isa(PhiR) ? "index" : "evl.based.iv"; - auto *ScalarR = new VPInstruction( - Instruction::PHI, {PhiR->getStartValue(), PhiR->getBackedgeValue()}, - PhiR->getDebugLoc(), Name); - ScalarR->insertBefore(PhiR); - PhiR->replaceAllUsesWith(ScalarR); - PhiR->eraseFromParent(); + } + + VPValue *VectorStep; + VPValue *ScalarStep; + if (!match(&R, m_VPInstruction( + m_VPValue(VectorStep), m_VPValue(ScalarStep)))) + continue; + + // Expand WideIVStep. + auto *VPI = cast(&R); + VPBuilder Builder(VPI); + Type *IVTy = TypeInfo.inferScalarType(VPI); + if (TypeInfo.inferScalarType(VectorStep) != IVTy) { + Instruction::CastOps CastOp = IVTy->isFloatingPointTy() + ? Instruction::UIToFP + : Instruction::Trunc; + VectorStep = Builder.createWidenCast(CastOp, VectorStep, IVTy); + } + + auto *ConstStep = + ScalarStep->isLiveIn() + ? dyn_cast(ScalarStep->getLiveInIRValue()) + : nullptr; + assert(!ConstStep || ConstStep->getValue() != 1); + if (TypeInfo.inferScalarType(ScalarStep) != IVTy) { + ScalarStep = + Builder.createWidenCast(Instruction::Trunc, ScalarStep, IVTy); + } + + std::optional FMFs; + if (IVTy->isFloatingPointTy()) + FMFs = VPI->getFastMathFlags(); + + unsigned MulOpc = + IVTy->isFloatingPointTy() ? Instruction::FMul : Instruction::Mul; + VPInstruction *Mul = Builder.createNaryOp( + MulOpc, {VectorStep, ScalarStep}, FMFs, R.getDebugLoc()); + VectorStep = Mul; + VPI->replaceAllUsesWith(VectorStep); + ToRemove.push_back(VPI); } } + + for (VPRecipeBase *R : ToRemove) + R->eraseFromParent(); } void VPlanTransforms::handleUncountableEarlyExit( diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index c23ff38265670..ee3642a8aff73 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -176,8 +176,9 @@ struct VPlanTransforms { BasicBlock *UncountableExitingBlock, VPRecipeBuilder &RecipeBuilder); - /// Lower abstract recipes to concrete ones, that can be codegen'd. - static void convertToConcreteRecipes(VPlan &Plan); + /// Lower abstract recipes to concrete ones, that can be codegen'd. Use \p + /// CanonicalIVTy as type for all un-typed live-ins in VPTypeAnalysis. + static void convertToConcreteRecipes(VPlan &Plan, Type &CanonicalIVTy); /// Perform instcombine-like simplifications on recipes in \p Plan. Use \p /// CanonicalIVTy as type for all un-typed live-ins in VPTypeAnalysis. diff --git a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp index a513a255344cc..b48a447834cc8 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp @@ -155,33 +155,13 @@ void UnrollState::unrollWidenInductionByUF( if (isa_and_present(ID.getInductionBinOp())) FMFs = ID.getInductionBinOp()->getFastMathFlags(); - VPValue *VectorStep = &Plan.getVF(); - VPBuilder Builder(PH); - if (TypeInfo.inferScalarType(VectorStep) != IVTy) { - Instruction::CastOps CastOp = - IVTy->isFloatingPointTy() ? Instruction::UIToFP : Instruction::Trunc; - VectorStep = Builder.createWidenCast(CastOp, VectorStep, IVTy); - ToSkip.insert(VectorStep->getDefiningRecipe()); - } - VPValue *ScalarStep = IV->getStepValue(); - auto *ConstStep = ScalarStep->isLiveIn() - ? dyn_cast(ScalarStep->getLiveInIRValue()) - : nullptr; - if (!ConstStep || ConstStep->getValue() != 1) { - if (TypeInfo.inferScalarType(ScalarStep) != IVTy) { - ScalarStep = - Builder.createWidenCast(Instruction::Trunc, ScalarStep, IVTy); - ToSkip.insert(ScalarStep->getDefiningRecipe()); - } + VPBuilder Builder(PH); + VPInstruction *VectorStep = Builder.createNaryOp( + VPInstruction::WideIVStep, {&Plan.getVF(), ScalarStep}, IVTy, FMFs, + IV->getDebugLoc()); - unsigned MulOpc = - IVTy->isFloatingPointTy() ? Instruction::FMul : Instruction::Mul; - VPInstruction *Mul = Builder.createNaryOp(MulOpc, {VectorStep, ScalarStep}, - FMFs, IV->getDebugLoc()); - VectorStep = Mul; - ToSkip.insert(Mul); - } + ToSkip.insert(VectorStep); // Now create recipes to compute the induction steps for part 1 .. UF. Part 0 // remains the header phi. Parts > 0 are computed by adding Step to the