[LoopVectorizer][AArch64] Add support for partial reduce subtraction #123636

Merged
merged 9 commits into llvm:main on Feb 13, 2025

Conversation

NickGuy-Arm
Contributor

No description provided.

@NickGuy-Arm
Contributor Author

I'm having trouble isolating the tests, I believe because the loop vectorizer doesn't currently support subtraction recurrences. But I wanted to get the code itself visible even if the tests aren't ready, hence the draft.

@llvmbot
Member

llvmbot commented Jan 28, 2025

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-vectorizers

Author: Nicholas Guy (NickGuy-Arm)

Changes

Patch is 38.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123636.diff

6 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+11-7)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+25-4)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll (+58-46)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll (+3-3)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll (+100)
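
For context, the kind of source loop this targets is a dot product whose products are subtracted from the accumulator rather than added to it; a minimal C++ sketch (hypothetical function name, mirroring the shape of the added partial-reduce-sub.ll test):

// Subtracting dot product: the reduction update is a sub rather than an add.
// With this patch the vectorizer can negate the product and reuse the
// add-only partial reduction intrinsic to accumulate it.
int sub_dot(const signed char *a, const signed char *b, int n) {
  int acc = 0;
  for (int i = 0; i < n; ++i)
    acc -= (int)a[i] * (int)b[i];
  return acc;
}
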
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index aae2fdaf5bec37..9111efb9bb4e90 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4673,7 +4673,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
   InstructionCost Invalid = InstructionCost::getInvalid();
   InstructionCost Cost(TTI::TCC_Basic);
 
-  if (Opcode != Instruction::Add)
+  if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
     return Invalid;
 
   if (InputTypeA != InputTypeB)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 57b7358049bcef..c195230e8dcb32 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8697,8 +8697,9 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
 
   // Build up a set of partial reduction bin ops for efficient use checking.
   SmallSet<User *, 4> PartialReductionBinOps;
-  for (const auto &[PartialRdx, _] : PartialReductionChains)
+  for (const auto &[PartialRdx, _] : PartialReductionChains) {
     PartialReductionBinOps.insert(PartialRdx.BinOp);
+  }
 
   auto ExtendIsOnlyUsedByPartialReductions =
       [&PartialReductionBinOps](Instruction *Extend) {
@@ -8761,20 +8762,23 @@ bool VPRecipeBuilder::getScaledReductions(
     return false;
 
   using namespace llvm::PatternMatch;
+  BinaryOperator *ExtendedBinOp = BinOp;
+  match(BinOp, m_Neg(m_BinOp(ExtendedBinOp)));
+
   Value *A, *B;
-  if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
-      !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
+  if (!match(ExtendedBinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
+      !match(ExtendedBinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
     return false;
 
-  Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
-  Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+  Instruction *ExtA = cast<Instruction>(ExtendedBinOp->getOperand(0));
+  Instruction *ExtB = cast<Instruction>(ExtendedBinOp->getOperand(1));
 
   TTI::PartialReductionExtendKind OpAExtend =
       TargetTransformInfo::getPartialReductionExtendKind(ExtA);
   TTI::PartialReductionExtendKind OpBExtend =
       TargetTransformInfo::getPartialReductionExtendKind(ExtB);
 
-  PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
+  PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, ExtendedBinOp);
 
   unsigned TargetScaleFactor =
       PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8785,7 +8789,7 @@ bool VPRecipeBuilder::getScaledReductions(
             InstructionCost Cost = TTI->getPartialReductionCost(
                 Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
                 VF, OpAExtend, OpBExtend,
-                std::make_optional(BinOp->getOpcode()));
+                std::make_optional(ExtendedBinOp->getOpcode()));
             return Cost.isValid();
           },
           Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2679ed6b26b5d1..eed2476f3770ce 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -25,6 +25,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/VectorBuilder.h"
@@ -282,7 +283,20 @@ InstructionCost
 VPPartialReductionRecipe::computeCost(ElementCount VF,
                                       VPCostContext &Ctx) const {
   std::optional<unsigned> Opcode = std::nullopt;
-  VPRecipeBase *BinOpR = getOperand(0)->getDefiningRecipe();
+  VPValue *BinOp = getOperand(0);
+  VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
+
+  using namespace llvm::PatternMatch;
+  if (auto *UnderInst =
+          dyn_cast_if_present<Instruction>(BinOp->getUnderlyingValue())) {
+    if (match(UnderInst, m_Neg(m_BinOp()))) {
+      BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();
+    }
+  }
+  // BinOp is never used again, any further interaction should be via the
+  // defining recipe `BinOpR`
+  BinOp = nullptr;
+
   if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
     Opcode = std::make_optional(WidenR->getOpcode());
 
@@ -318,13 +332,20 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   State.setDebugLocFrom(getDebugLoc());
   auto &Builder = State.Builder;
 
-  assert(getOpcode() == Instruction::Add &&
-         "Unhandled partial reduction opcode");
-
   Value *BinOpVal = State.get(getOperand(0));
   Value *PhiVal = State.get(getOperand(1));
   assert(PhiVal && BinOpVal && "Phi and Mul must be set");
 
+  unsigned Opcode = getOpcode();
+
+  if (Opcode == Instruction::Sub) {
+    bool HasNSW = cast<Instruction>(BinOpVal)->hasNoSignedWrap();
+    BinOpVal = Builder.CreateNeg(BinOpVal, "", HasNSW);
+    Opcode = Instruction::Add;
+  }
+
+  assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode");
+
   Type *RetTy = PhiVal->getType();
 
   CallInst *V = Builder.CreateIntrinsic(
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index bedf8b6b3a9b56..2dc515649b2262 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -22,7 +22,7 @@ define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-NEON-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-NEON:       vector.body:
 ; CHECK-NEON-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEON-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEON-NEXT:    [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP0]]
 ; CHECK-NEON-NEXT:    [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP0]]
@@ -37,14 +37,15 @@ define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-NEON-NEXT:    [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
 ; CHECK-NEON-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NEON-NEXT:    [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
-; CHECK-NEON-NEXT:    [[TMP11:%.*]] = add <16 x i32> [[VEC_PHI]], [[TMP10]]
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
 ; CHECK-NEON-NEXT:    [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP9]]
-; CHECK-NEON-NEXT:    [[TMP13]] = sub <16 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEON-NEXT:    [[TMP13:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP12]]
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP13]])
 ; CHECK-NEON-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
 ; CHECK-NEON-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEON-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK-NEON:       middle.block:
-; CHECK-NEON-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP13]])
+; CHECK-NEON-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
 ; CHECK-NEON-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
 ; CHECK-NEON-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
 ;
@@ -114,7 +115,7 @@ define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-SVE-MAXBW-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-SVE-MAXBW:       vector.body:
 ; CHECK-SVE-MAXBW-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-SVE-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
@@ -129,14 +130,15 @@ define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP14:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP15:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP16:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-MAXBW-NEXT:    [[TMP17:%.*]] = add <vscale x 8 x i32> [[VEC_PHI]], [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP16]])
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP15]]
-; CHECK-SVE-MAXBW-NEXT:    [[TMP19]] = sub <vscale x 8 x i32> [[TMP17]], [[TMP18]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP19:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP18]]
+; CHECK-SVE-MAXBW-NEXT:    [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP19]])
 ; CHECK-SVE-MAXBW-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-SVE-MAXBW-NEXT:    br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK-SVE-MAXBW:       middle.block:
-; CHECK-SVE-MAXBW-NEXT:    [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP19]])
+; CHECK-SVE-MAXBW-NEXT:    [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
 ; CHECK-SVE-MAXBW-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
 ; CHECK-SVE-MAXBW-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
 ;
@@ -350,7 +352,7 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-NEON-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-NEON:       vector.body:
 ; CHECK-NEON-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEON-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEON-NEXT:    [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP0]]
 ; CHECK-NEON-NEXT:    [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP0]]
@@ -365,14 +367,15 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-NEON-NEXT:    [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
 ; CHECK-NEON-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NEON-NEXT:    [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
-; CHECK-NEON-NEXT:    [[TMP11:%.*]] = sub <16 x i32> [[VEC_PHI]], [[TMP10]]
+; CHECK-NEON-NEXT:    [[TMP11:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP10]]
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
 ; CHECK-NEON-NEXT:    [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP9]]
-; CHECK-NEON-NEXT:    [[TMP13]] = add <16 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP12]])
 ; CHECK-NEON-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
 ; CHECK-NEON-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEON-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK-NEON:       middle.block:
-; CHECK-NEON-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP13]])
+; CHECK-NEON-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
 ; CHECK-NEON-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
 ; CHECK-NEON-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
 ;
@@ -442,7 +445,7 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-SVE-MAXBW-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-SVE-MAXBW:       vector.body:
 ; CHECK-SVE-MAXBW-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-SVE-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
@@ -457,14 +460,15 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP14:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP15:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP16:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-MAXBW-NEXT:    [[TMP17:%.*]] = sub <vscale x 8 x i32> [[VEC_PHI]], [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP17:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP17]])
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP15]]
-; CHECK-SVE-MAXBW-NEXT:    [[TMP19]] = add <vscale x 8 x i32> [[TMP17]], [[TMP18]]
+; CHECK-SVE-MAXBW-NEXT:    [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP18]])
 ; CHECK-SVE-MAXBW-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-SVE-MAXBW-NEXT:    br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK-SVE-MAXBW:       middle.block:
-; CHECK-SVE-MAXBW-NEXT:    [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP19]])
+; CHECK-SVE-MAXBW-NEXT:    [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
 ; CHECK-SVE-MAXBW-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
 ; CHECK-SVE-MAXBW-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
 ;
@@ -516,7 +520,7 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-NEON-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-NEON:       vector.body:
 ; CHECK-NEON-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEON-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEON-NEXT:    [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP0]]
 ; CHECK-NEON-NEXT:    [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP0]]
@@ -531,14 +535,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-NEON-NEXT:    [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
 ; CHECK-NEON-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NEON-NEXT:    [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
-; CHECK-NEON-NEXT:    [[TMP11:%.*]] = sub <16 x i32> [[VEC_PHI]], [[TMP10]]
+; CHECK-NEON-NEXT:    [[TMP11:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP10]]
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
 ; CHECK-NEON-NEXT:    [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP9]]
-; CHECK-NEON-NEXT:    [[TMP13]] = sub <16 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEON-NEXT:    [[TMP13:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP12]]
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP13]])
 ; CHECK-NEON-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
 ; CHECK-NEON-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEON-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
 ; CHECK-NEON:       middle.block:
-; CHECK-NEON-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP13]])
+; CHECK-NEON-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
 ; CHECK-NEON-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
 ; CHECK-NEON-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
 ;
@@ -608,7 +614,7 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-SVE-MAXBW-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-SVE-MAXBW:       vector.body:
 ; CHECK-SVE-MAXBW-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-SVE-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
@@ -623,14 +629,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP14:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP15:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP16:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-MAXBW-NEXT:    [[TMP17:%.*]] = sub <vscale x 8 x i32> [[VEC_PHI]], [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP17:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP17]])
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP15]]
-; CHECK-SVE-MAXBW-NE...
[truncated]

Comment on lines 8700 to 8702
for (const auto &[PartialRdx, _] : PartialReductionChains) {
PartialReductionBinOps.insert(PartialRdx.BinOp);
}
Collaborator

I don't think we need this change.

Contributor Author

Correct, I needed the braces for some extra debug output. Removed.

Comment on lines 8765 to 8766
BinaryOperator *ExtendedBinOp = BinOp;
match(BinOp, m_Neg(m_BinOp(ExtendedBinOp)));
Collaborator

Could this just set BinOp inside the match instead of creating a new variable?

Contributor Author

Another remnant of a previous iteration. We did use some info from BinOp (I think initially it was for costing), but now we're costing from ExtendedBinOp. Removed the new variable.

Collaborator

@SamTebbs33 SamTebbs33 left a comment

Looks good to me, with a couple of suggestions.

Comment on lines 289 to 295
using namespace llvm::PatternMatch;
if (auto *UnderInst =
dyn_cast_if_present<Instruction>(BinOp->getUnderlyingValue())) {
if (match(UnderInst, m_Neg(m_BinOp()))) {
BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();
}
}
Collaborator

A comment explaining why this has to get the 2nd operand of the BinOp would be good.
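
One possible wording for such a comment (a sketch only; the committed text may differ): since m_Neg matches the canonical negation 'sub 0, %x', operand 0 of the matched recipe is the zero constant and operand 1 is the binary op actually being negated.

// The negation is canonically 'sub 0, %x', so operand 0 of the recipe is the
// zero live-in and operand 1 defines the extended binary op whose opcode we
// actually want to cost.
BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();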

Comment on lines 296 to 302
// BinOp is never used again, any further interaction should be via the
// defining recipe `BinOpR`
BinOp = nullptr;

Collaborator

It feels a bit strange to set something to null like this; perhaps it should just be left alone. Or we could just not declare the variable in the first place and use getOperand(0) instead.

Comment on lines 290 to 294
if (auto *UnderInst =
dyn_cast_if_present<Instruction>(BinOp->getUnderlyingValue())) {
if (match(UnderInst, m_Neg(m_BinOp()))) {
BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();
}
Contributor

Can't we match this on the VPValues/VPRecipes?

Contributor Author

We can indeed. Done.

@@ -2214,7 +2214,7 @@ define i32 @not_dotp_extend_user(ptr %a, ptr %b) #0 {
; CHECK-MAXBW-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-MAXBW: vector.body:
; CHECK-MAXBW-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-MAXBW-NEXT: [[VEC_PHI1:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP24:%.*]], [[VECTOR_BODY]] ]
; CHECK-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP23:%.*]], [[VECTOR_BODY]] ]
Contributor

Unrelated changes?

Contributor Author

Seemingly so; I'm not sure why the update script changed only these. Reverted, as there are no other changes in this file.

Comment on lines 341 to 345
if (Opcode == Instruction::Sub) {
bool HasNSW = cast<Instruction>(BinOpVal)->hasNoSignedWrap();
BinOpVal = Builder.CreateNeg(BinOpVal, "", HasNSW);
Opcode = Instruction::Add;
}
Contributor

The operands are always VPWidenRecipes, correct?

Could we instead adjust the input recipes to first negate the operand and use Add instead of Sub for the partial reduction recipe?

Contributor Author

They are always Widen recipes as far as I can tell; however, the incoming IR doesn't always match the pattern add(<a>, neg(<b>)). In the case of complex dot products, the second sub is explicitly represented as a sub, which we then transform here to follow the aforementioned pattern.

Unless I'm missing something, there doesn't seem to be a method of creating multiple VPRecipes from a single source instruction, which would be required for this suggestion.
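
For clarity, the rewrite discussed here rests on the identity acc - x == acc + (-x); a trivial scalar sketch of the two equivalent forms:

// Subtracting a value is the same as adding its negation (the signed-overflow
// corner cases are what the nsw-flag handling in the recipe is about). This
// equivalence is what lets the add-only partial-reduction intrinsic also
// handle 'sub' reductions.
int accumulate_sub(int acc, int prod) { return acc - prod; }
int accumulate_as_add(int acc, int prod) { return acc + (0 - prod); }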

Contributor

I think in general it should be possible to create multiple recipes for a single source instruction (we do it in multiple places already IIRC), as long as the additional instructions do not need to be mapped to IR instructions for lookup later.

A recent example is https://github.com/llvm/llvm-project/pull/124268/files#diff-da321d454a7246f8ae276bf1db2782bf26b5210b8133cb59e4d7fd45d0905decR8900

Contributor Author

Done, though it caused the cost model to evaluate the VPlan as more expensive, so it doesn't emit scalable vectors by default in this case. (I believe it's a similar case to why -vectorizer-maximize-bandwidth is now a thing, but even with that, fixed-width plans are chosen over scalable plans.)

A workaround for this is to use the opt arguments -force-vector-width=8/16 -scalable-vectorization=preferred in tandem (or in C++ by using a loop pragma vectorize_width(8/16, scalable)) to force a VF that is supported by partial reductions.
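
A rough sketch of the pragma form of that workaround (the width of 16 is chosen arbitrarily; any VF supported by partial reductions would do):

// Roughly equivalent to passing -force-vector-width=16
// -scalable-vectorization=preferred to opt, as described above.
int dot_sub(const signed char *a, const signed char *b, int n) {
  int acc = 0;
#pragma clang loop vectorize_width(16, scalable)
  for (int i = 0; i < n; ++i)
    acc -= (int)a[i] * (int)b[i];
  return acc;
}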

auto *const Zero = ConstantInt::get(Reduction->getType(), 0);
SmallVector<VPValue *, 2> Ops;
Ops.push_back(Plan.getOrAddLiveIn(Zero));
Ops.push_back(cast<VPWidenRecipe>(BinOp->getDefiningRecipe()));
Collaborator

Why does this need to cast? Can it just add BinOp to the list?

Contributor Author

@NickGuy-Arm NickGuy-Arm Feb 7, 2025

I was initially using the recipe, and after a few iterations got rid of the direct recipe usage. This stuck around as I didn't register at the time that BinOp in this context is the same as the recipe. It will be fixed in the next commit (assuming no test failures).

@@ -632,7 +632,7 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[TMP17:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP16]]
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP17]])
; CHECK-SVE-MAXBW-NEXT: [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP15]]
; CHECK-SVE-MAXBW-NEXT: [[TMP19:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP18]]
; CHECK-SVE-MAXBW-NEXT: [[TMP19:%.*]] = sub <vscale x 8 x i32> zeroinitializer, [[TMP18]]
Collaborator

Should the flags be discarded?

Contributor Author

In this case, they're not being discarded so much as they simply don't exist. This is the sub that is being created by matching an existing sub instruction and replacing it with add(%a, sub(0, %b)). I'm also not sure that the flags would do anything in this case, as the only time it would overflow is if TMP18 had already overflowed, but in that case it would already be a poison value.

I'm happy to add them if they're deemed necessary/helpful though.

Contributor Author

Correction: we do already propagate the flags; however, in this case the flags are missing because they were not on the respective sub instruction in the source IR.

Collaborator

@SamTebbs33 SamTebbs33 left a comment

LGTM!

github-actions bot commented Feb 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Comment on lines 8940 to 8942
VPBasicBlock *ParentBlock = Builder.getInsertBlock();
if (!ParentBlock)
return nullptr;
Contributor

Can the insert block ever be null? I'd expect it to always be set at this point?

Contributor Author

@NickGuy-Arm NickGuy-Arm Feb 12, 2025

I haven't seen any case where it would be, so this was just a bit of defensiveness for something I wasn't 100% sure on. I can switch it to an assert, or remove it altogether if you think the check is irrelevant.

Edit: Looks like VPBuilder::clearInsertPoint can result in getInsertPoint returning a nullptr. So I've changed it to an assert.
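
The resulting check presumably looks something like this (a sketch of the assert form, not necessarily the exact committed code):

VPBasicBlock *ParentBlock = Builder.getInsertBlock();
assert(ParentBlock && "Expected a valid insert block for the partial reduction");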

Contributor

@fhahn fhahn left a comment

LGTM, thanks!

@NickGuy-Arm NickGuy-Arm merged commit 9c89faa into llvm:main Feb 13, 2025
8 checks passed
flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
@@ -8800,6 +8800,10 @@ bool VPRecipeBuilder::getScaledReductions(
return false;

using namespace llvm::PatternMatch;
// Use the side-effect of match to replace BinOp only if the pattern is
// matched, we don't care at this point whether it actually matched.
match(BinOp, m_Neg(m_BinOp(BinOp)));
Contributor

@david-arm david-arm Feb 25, 2025

I don't understand how this works. We've already bailed out above if BinOp is not a BinaryOperator, which surely means that it cannot simultaneously be a unary operator, which is what m_Neg represents?

Contributor

Ah, I realise now that m_Neg is actually matching a sub 0, %x operation. Please ignore the noise!
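
For readers following along: m_Neg(X) in llvm/IR/PatternMatch.h expands to m_Sub(m_ZeroInt(), X), i.e. it matches the canonical negation 'sub 0, %x', so a BinaryOperator can indeed match it. A standalone sketch of the side-effect binding the patch relies on (hypothetical helper name):

#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace llvm::PatternMatch;

// If BinOp is 'sub i32 0, %mul' and %mul is itself a binary operator, the
// match succeeds and rebinds BinOp to %mul; otherwise BinOp is left
// untouched, so the caller can keep using it either way.
static void peelNegation(BinaryOperator *&BinOp) {
  (void)match(BinOp, m_Neg(m_BinOp(BinOp)));
}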
