Skip to content

[VPlan] Add new recipes for extended-reduction and mul-accumulate-reduction. NFC #137745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

ElvisWang123
Copy link
Contributor

@ElvisWang123 ElvisWang123 commented Apr 29, 2025

This patch adds two new recipes, one for extended reductions and one for mul-accumulate reductions.

  • VPExtendedReductionRecipe.
    • Contains widen-cast + reduction.
  • VPMulAccumulateReductionRecipe.
    • Contains widen-mul + widen-cast + reduction.

The transformations and the cost model for these recipes will be added in follow-up patches.

Split from #113903.

…on. NFC

This patch adds two new recipes, one for extended reductions and one for
mul-accumulate reductions.

Split from llvm#113904.
@llvmbot
Copy link
Member

llvmbot commented Apr 29, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: Elvis Wang (ElvisWang123)

Changes

This patch adds two new recipes, one for extended reductions and one for mul-accumulate reductions.

  • VPExtendedReductionRecipe.
    • Contains widen-cast + reduction.
  • VPMulAccumulateReductionRecipe.
    • Contains widen-mul + widen-cast + reduction.

The transformations and the cost model for these recipes will be added in follow-up patches.

Split from #113904.


Full diff: https://github.com/llvm/llvm-project/pull/137745.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+232-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+2)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+68)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+2)
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index afad73bcd3501..591a3f26d4d13 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -525,6 +525,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
+    case VPRecipeBase::VPMulAccumulateReductionSC:
+    case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
@@ -609,13 +611,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
   };
 
+  struct NonNegFlagsTy {
+    char NonNeg : 1;
+    NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
+  };
+
 private:
   struct ExactFlagsTy {
     char IsExact : 1;
   };
-  struct NonNegFlagsTy {
-    char NonNeg : 1;
-  };
   struct FastMathFlagsTy {
     char AllowReassoc : 1;
     char NoNaNs : 1;
@@ -709,6 +713,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
       : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
         DisjointFlags(DisjointFlags) {}
 
+  template <typename IterT>
+  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
+        NonNegFlags(NonNegFlags) {}
+
 protected:
   template <typename IterT>
   VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
@@ -728,7 +738,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
            R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
            R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
+           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -820,6 +832,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
 
   FastMathFlags getFastMathFlags() const;
 
+  /// Returns true if the recipe has non-negative flag.
+  bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; }
+
+  bool isNonNeg() const {
+    assert(OpType == OperationType::NonNegOp &&
+           "recipe doesn't have a NNEG flag");
+    return NonNegFlags.NonNeg;
+  }
+
   bool hasNoUnsignedWrap() const {
     assert(OpType == OperationType::OverflowingBinOp &&
            "recipe doesn't have a NUW flag");
@@ -2373,6 +2394,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
     setUnderlyingValue(I);
   }
 
+  /// For VPExtendedReductionRecipe.
+  /// Note that the debug location is from the extend.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
+  /// For VPMulAccumulateReductionRecipe.
+  /// Note that the NUW/NSW flags and the debug location are from the Mul.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
 public:
   VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
                     VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2381,6 +2424,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
                           IsOrdered, DL) {}
 
+  VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs,
+                    VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL = {})
+      : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr,
+                          ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                          IsOrdered, DL) {}
+
   ~VPReductionRecipe() override = default;
 
   VPReductionRecipe *clone() override {
@@ -2391,7 +2441,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
 
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2471,6 +2523,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent inloop extended reduction operations, performing a
+/// reduction on a extended vector operand into a scalar value, and adding the
+/// result to a chain. This recipe is abstract and needs to be lowered to
+/// concrete recipes before codegen. The operands are {ChainOp, VecOp,
+/// [Condition]}.
+class VPExtendedReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend recipe will be lowered to.
+  Instruction::CastOps ExtOp;
+
+  Type *ResultTy;
+
+  /// For cloning VPExtendedReductionRecipe.
+  VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
+      : VPReductionRecipe(
+            VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
+            {ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
+            ExtRed->isOrdered(), ExtRed->getDebugLoc()),
+        ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
+    transferFlags(*ExtRed);
+  }
+
+public:
+  VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
+      : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
+                          {R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
+                          R->isOrdered(), Ext->getDebugLoc()),
+        ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
+    // Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
+    // the original recipe to prevent setting wrong flags.
+    transferFlags(*Ext);
+  }
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    auto *Copy = new VPExtendedReductionRecipe(this);
+    Copy->transferFlags(*this);
+    return Copy;
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transform to "
+                     "VPExtendedRecipe + VPReductionRecipe before execution.");
+  };
+
+  /// Return the cost of VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// The scalar type after extending.
+  Type *getResultType() const { return ResultTy; }
+
+  /// Is the extend ZExt?
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+
+  /// The opcode of extend recipe.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+};
+
+/// A recipe to represent inloop MulAccumulateReduction operations, performing a
+/// reduction.add on the result of vector operands (might be extended)
+/// multiplication into a scalar value, and adding the result to a chain. This
+/// recipe is abstract and needs to be lowered to concrete recipes before
+/// codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend recipe.
+  Instruction::CastOps ExtOp;
+
+  /// Non-neg flag of the extend recipe.
+  bool IsNonNeg = false;
+
+  Type *ResultTy;
+
+  /// For cloning VPMulAccumulateReductionRecipe.
+  VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
+            {MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
+            MulAcc->getCondOp(), MulAcc->isOrdered(),
+            WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
+            MulAcc->getDebugLoc()),
+        ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
+        ResultTy(MulAcc->getResultType()) {}
+
+public:
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
+                                 VPWidenCastRecipe *Ext0,
+                                 VPWidenCastRecipe *Ext1, Type *ResultTy)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateteReductionRecipe must "
+           "be Add");
+    // Only set the non-negative flag if the original recipe contains.
+    if (Ext0->hasNonNegFlag())
+      IsNonNeg = Ext0->isNonNeg();
+  }
+
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Instruction::CastOps::CastOpsEnd) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateReductionRecipe must be "
+           "Add");
+  }
+
+  ~VPMulAccumulateReductionRecipe() override = default;
+
+  VPMulAccumulateReductionRecipe *clone() override {
+    auto *Copy = new VPMulAccumulateReductionRecipe(this);
+    Copy->transferFlags(*this);
+    return Copy;
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccumulateReductionRecipe should transform to "
+                     "VPWidenCastRecipe + "
+                     "VPWidenRecipe + VPReductionRecipe before execution");
+  }
+
+  /// Return the cost of VPMulAccumulateReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  Type *getResultType() const {
+    assert(isExtended() && "Only support getResultType when this recipe "
+                           "contains implicit extend.");
+    return ResultTy;
+  }
+
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp0() const { return getOperand(1); }
+  VPValue *getVecOp1() const { return getOperand(2); }
+
+  /// Return if this MulAcc recipe contains extended operands.
+  bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+
+  /// Return the opcode of the extends for the operands.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+
+  /// Return if the operands are zero extended.
+  bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
+
+  /// Return the non negative flag of the ext recipe.
+  bool isNonNeg() const { return IsNonNeg; }
+};
+
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index c86815c84d8d9..7dcbd72c25191 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -273,6 +273,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             // TODO: Use info from interleave group.
             return V->getUnderlyingValue()->getType();
           })
+          .Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
+              [](const auto *R) { return R->getResultType(); })
           .Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
             return R->getSCEV()->getType();
           })
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 75d056026025a..8978a4d5e93cf 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -71,6 +71,8 @@ bool VPRecipeBase::mayWriteToMemory() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
   case VPWidenCastSC:
@@ -118,6 +120,8 @@ bool VPRecipeBase::mayReadFromMemory() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
   case VPWidenCastSC:
@@ -155,6 +159,8 @@ bool VPRecipeBase::mayHaveSideEffects() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPScalarIVStepsSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
@@ -2513,6 +2519,18 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                    Ctx.CostKind);
 }
 
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  return 0;
+}
+
+InstructionCost
+VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  return 0;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2555,6 +2573,56 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
   }
   O << ")";
 }
+
+void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                      VPSlotTracker &SlotTracker) const {
+  O << Indent << "EXTENDED-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  O << " reduce."
+    << Instruction::getOpcodeName(
+           RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+    << " (";
+  getVecOp()->printAsOperand(O, SlotTracker);
+  O << " extended to " << *getResultType();
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+}
+
+void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                           VPSlotTracker &SlotTracker) const {
+  O << Indent << "MULACC-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " + ";
+  O << "reduce."
+    << Instruction::getOpcodeName(
+           RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+    << " (";
+  O << "mul";
+  printFlags(O);
+  if (isExtended())
+    O << "(";
+  getVecOp0()->printAsOperand(O, SlotTracker);
+  if (isExtended())
+    O << " extended to " << *getResultType() << "), (";
+  else
+    O << ", ";
+  getVecOp1()->printAsOperand(O, SlotTracker);
+  if (isExtended())
+    O << " extended to " << *getResultType() << ")";
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 638156eab7a84..64065edd315f9 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -339,6 +339,8 @@ class VPDef {
     VPInterleaveSC,
     VPReductionEVLSC,
     VPReductionSC,
+    VPMulAccumulateReductionSC,
+    VPExtendedReductionSC,
     VPPartialReductionSC,
     VPReplicateSC,
     VPScalarIVStepsSC,

@ElvisWang123
Copy link
Contributor Author

The changes in this patch have been moved to #137746.

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants