diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index afad73bcd3501..591a3f26d4d13 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -525,6 +525,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
+    case VPRecipeBase::VPMulAccumulateReductionSC:
+    case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
@@ -609,13 +611,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
   };
 
+  struct NonNegFlagsTy {
+    char NonNeg : 1;
+    NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
+  };
+
 private:
   struct ExactFlagsTy {
     char IsExact : 1;
   };
-  struct NonNegFlagsTy {
-    char NonNeg : 1;
-  };
   struct FastMathFlagsTy {
     char AllowReassoc : 1;
     char NoNaNs : 1;
@@ -709,6 +713,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
       : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
         DisjointFlags(DisjointFlags) {}
 
+  template <typename IterT>
+  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
+        NonNegFlags(NonNegFlags) {}
+
 protected:
   template <typename IterT>
   VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
@@ -728,7 +738,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
            R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
            R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
+           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -820,6 +832,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
 
   FastMathFlags getFastMathFlags() const;
 
+  /// Returns true if the recipe has the non-negative flag.
+  bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; }
+
+  bool isNonNeg() const {
+    assert(OpType == OperationType::NonNegOp &&
+           "recipe doesn't have a NNEG flag");
+    return NonNegFlags.NonNeg;
+  }
+
   bool hasNoUnsignedWrap() const {
     assert(OpType == OperationType::OverflowingBinOp &&
            "recipe doesn't have a NUW flag");
@@ -2373,6 +2394,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
     setUnderlyingValue(I);
   }
 
+  /// For VPExtendedReductionRecipe.
+  /// Note that the debug location is from the extend.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
+  /// For VPMulAccumulateReductionRecipe.
+  /// Note that the NUW/NSW flags and the debug location are from the Mul.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
 public:
   VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
                     VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2381,6 +2424,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
                           IsOrdered, DL) {}
 
+  VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs,
+                    VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL = {})
+      : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr,
+                          ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                          IsOrdered, DL) {}
+
   ~VPReductionRecipe() override = default;
 
   VPReductionRecipe *clone() override {
@@ -2391,7 +2441,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
 
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2471,6 +2523,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent inloop extended reduction operations, performing a
+/// reduction on an extended vector operand into a scalar value, and adding the
+/// result to a chain. This recipe is abstract and needs to be lowered to
+/// concrete recipes before codegen. The operands are {ChainOp, VecOp,
+/// [Condition]}.
+class VPExtendedReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend this recipe will be lowered to.
+  Instruction::CastOps ExtOp;
+
+  Type *ResultTy;
+
+  /// For cloning VPExtendedReductionRecipe.
+  VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
+      : VPReductionRecipe(
+            VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
+            {ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
+            ExtRed->isOrdered(), ExtRed->getDebugLoc()),
+        ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
+    transferFlags(*ExtRed);
+  }
+
+public:
+  VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
+      : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
+                          {R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
+                          R->isOrdered(), Ext->getDebugLoc()),
+        ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
+    // Not all VPWidenCastRecipes carry the nneg flag, so transfer the flags
+    // from the original recipe to avoid setting incorrect flags.
+    transferFlags(*Ext);
+  }
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    auto *Copy = new VPExtendedReductionRecipe(this);
+    Copy->transferFlags(*this);
+    return Copy;
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transformed to "
+                     "VPWidenCastRecipe + VPReductionRecipe before execution.");
+  }
+
+  /// Return the cost of VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// The scalar type after extending.
+  Type *getResultType() const { return ResultTy; }
+
+  /// Is the extend ZExt?
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+
+  /// The opcode of the extend recipe.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+};
+
+/// A recipe to represent inloop MulAccumulateReduction operations, performing
+/// a reduction.add on the result of multiplying two (possibly extended) vector
+/// operands into a scalar value, and adding the result to a chain. This recipe
+/// is abstract and needs to be lowered to concrete recipes before codegen. The
+/// operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend recipe.
+  Instruction::CastOps ExtOp;
+
+  /// Non-neg flag of the extend recipe.
+  bool IsNonNeg = false;
+
+  Type *ResultTy;
+
+  /// For cloning VPMulAccumulateReductionRecipe.
+  VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
+            {MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
+            MulAcc->getCondOp(), MulAcc->isOrdered(),
+            WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
+            MulAcc->getDebugLoc()),
+        ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
+        ResultTy(MulAcc->getResultType()) {}
+
+public:
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
+                                 VPWidenCastRecipe *Ext0,
+                                 VPWidenCastRecipe *Ext1, Type *ResultTy)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateReductionRecipe must "
+           "be Add");
+    // Only set the non-negative flag if the original recipe has it.
+    if (Ext0->hasNonNegFlag())
+      IsNonNeg = Ext0->isNonNeg();
+  }
+
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Instruction::CastOps::CastOpsEnd) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateReductionRecipe must be "
+           "Add");
+  }
+
+  ~VPMulAccumulateReductionRecipe() override = default;
+
+  VPMulAccumulateReductionRecipe *clone() override {
+    auto *Copy = new VPMulAccumulateReductionRecipe(this);
+    Copy->transferFlags(*this);
+    return Copy;
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccumulateReductionRecipe should be transformed to "
+                     "VPWidenCastRecipe + VPWidenRecipe + VPReductionRecipe "
+                     "before execution");
+  }
+
+  /// Return the cost of VPMulAccumulateReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  Type *getResultType() const {
+    assert(isExtended() && "Only support getResultType when this recipe "
+                           "contains an implicit extend.");
+    return ResultTy;
+  }
+
+  /// The VPValues of the vector operands to be extended and reduced.
+  VPValue *getVecOp0() const { return getOperand(1); }
+  VPValue *getVecOp1() const { return getOperand(2); }
+
+  /// Return if this MulAcc recipe contains extended operands.
+  bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+
+  /// Return the opcode of the extends for the operands.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+
+  /// Return if the operands are zero-extended.
+  bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
+
+  /// Return the non-negative flag of the extend recipe.
+  bool isNonNeg() const { return IsNonNeg; }
+};
+
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index c86815c84d8d9..7dcbd72c25191 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -273,6 +273,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             // TODO: Use info from interleave group.
             return V->getUnderlyingValue()->getType();
           })
+          .Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
+              [](const auto *R) { return R->getResultType(); })
          .Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
            return R->getSCEV()->getType();
          })
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 75d056026025a..8978a4d5e93cf 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -71,6 +71,8 @@ bool VPRecipeBase::mayWriteToMemory() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
   case VPWidenCastSC:
@@ -118,6 +120,8 @@ bool VPRecipeBase::mayReadFromMemory() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
   case VPWidenCastSC:
@@ -155,6 +159,8 @@ bool VPRecipeBase::mayHaveSideEffects() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPScalarIVStepsSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
@@ -2513,6 +2519,18 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                  Ctx.CostKind);
 }
 
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  return 0;
+}
+
+InstructionCost
+VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  return 0;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2555,6 +2573,56 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
   }
   O << ")";
 }
+
+void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                      VPSlotTracker &SlotTracker) const {
+  O << Indent << "EXTENDED-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  O << " reduce."
+    << Instruction::getOpcodeName(
+           RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+    << " (";
+  getVecOp()->printAsOperand(O, SlotTracker);
+  O << " extended to " << *getResultType();
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+}
+
+void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                           VPSlotTracker &SlotTracker) const {
+  O << Indent << "MULACC-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " + ";
+  O << "reduce."
+    << Instruction::getOpcodeName(
+           RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+    << " (";
+  O << "mul";
+  printFlags(O);
+  if (isExtended())
+    O << "(";
+  getVecOp0()->printAsOperand(O, SlotTracker);
+  if (isExtended())
+    O << " extended to " << *getResultType() << "), (";
+  else
+    O << ", ";
+  getVecOp1()->printAsOperand(O, SlotTracker);
+  if (isExtended())
+    O << " extended to " << *getResultType() << ")";
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 638156eab7a84..64065edd315f9 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -339,6 +339,8 @@ class VPDef {
     VPInterleaveSC,
     VPReductionEVLSC,
     VPReductionSC,
+    VPMulAccumulateReductionSC,
+    VPExtendedReductionSC,
     VPPartialReductionSC,
     VPReplicateSC,
     VPScalarIVStepsSC,
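
Note (not part of the patch): given the print methods added in VPlanRecipes.cpp above, the bundled recipes would appear in VPlan dumps roughly as follows; the operand names below are invented for illustration, and the exact spelling depends on the VPSlotTracker and on which flags were recorded.

  EXTENDED-REDUCE vp<%rdx.next> = vp<%rdx> + reduce.add (ir<%a> extended to i32)
  MULACC-REDUCE vp<%rdx.next> = vp<%rdx> + reduce.add (mul nuw nsw (ir<%a> extended to i32), (ir<%b> extended to i32))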
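
Note (illustrative sketch, not part of the patch): both new recipes are abstract; they are meant to be formed from existing recipes by a VPlan-to-VPlan transform and lowered back to VPWidenCastRecipe/VPWidenRecipe/VPReductionRecipe before codegen, which is why execute() is unreachable and computeCost() currently returns 0. Below is a minimal sketch of how a bundling transform for reduce.add(mul(ext(A), ext(B))) could look, using only the constructors introduced above plus existing VPlan helpers (getDefiningRecipe, insertBefore, replaceAllUsesWith, eraseFromParent) and assuming the usual VPlan/Casting includes. The function name tryToCreateMulAccRecipe is hypothetical, and the actual bundling and expansion transforms are not part of this excerpt.

static VPReductionRecipe *tryToCreateMulAccRecipe(VPReductionRecipe *Red) {
  // Only add-reductions can be rewritten as a multiply-accumulate.
  if (RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()) !=
      Instruction::Add)
    return Red;

  // The reduced vector operand must be produced by a widened multiply.
  auto *Mul =
      dyn_cast_or_null<VPWidenRecipe>(Red->getVecOp()->getDefiningRecipe());
  if (!Mul || Mul->getOpcode() != Instruction::Mul)
    return Red;

  // Check whether both multiply operands are matching extends.
  auto *Ext0 = dyn_cast_or_null<VPWidenCastRecipe>(
      Mul->getOperand(0)->getDefiningRecipe());
  auto *Ext1 = dyn_cast_or_null<VPWidenCastRecipe>(
      Mul->getOperand(1)->getDefiningRecipe());

  VPMulAccumulateReductionRecipe *MulAcc = nullptr;
  if (Ext0 && Ext1 && Ext0->getOpcode() == Ext1->getOpcode() &&
      Ext0->getResultType() == Ext1->getResultType())
    // reduce.add(mul(ext(A), ext(B))): fold the extends into the recipe.
    MulAcc = new VPMulAccumulateReductionRecipe(Red, Mul, Ext0, Ext1,
                                                Ext0->getResultType());
  else if (!Ext0 && !Ext1)
    // reduce.add(mul(A, B)): bundle without extends.
    MulAcc = new VPMulAccumulateReductionRecipe(Red, Mul);

  if (!MulAcc)
    return Red;

  // Replace the original reduction; the now-dead Mul/Ext recipes are left for
  // later simplification/DCE to clean up.
  MulAcc->insertBefore(Red);
  Red->replaceAllUsesWith(MulAcc);
  Red->eraseFromParent();
  return MulAcc;
}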