@@ -8682,12 +8682,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8682
8682
/// are valid so recipes can be formed later.
8683
8683
void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8684
8684
// Find all possible partial reductions.
8685
- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8685
+ SmallVector<std::pair<PartialReductionChain, unsigned >>
8686
8686
PartialReductionChains;
8687
- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8688
- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8689
- getScaledReduction (Phi, RdxDesc, Range))
8690
- PartialReductionChains. push_back (*Pair);
8687
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8688
+ if (auto SR = getScaledReduction (Phi, RdxDesc. getLoopExitInstr (), Range))
8689
+ PartialReductionChains. append (*SR);
8690
+ }
8691
8691
8692
8692
// A partial reduction is invalid if any of its extends are used by
8693
8693
// something that isn't another partial reduction. This is because the
@@ -8715,26 +8715,44 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8715
8715
}
8716
8716
}
8717
8717
8718
- std::optional<std::pair<PartialReductionChain, unsigned >>
8719
- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8720
- const RecurrenceDescriptor &Rdx,
8718
+ std::optional<SmallVector<std::pair<PartialReductionChain, unsigned >>>
8719
+ VPRecipeBuilder::getScaledReduction (Instruction *PHI, Instruction *RdxExitInstr,
8721
8720
VFRange &Range) {
8721
+
8722
+ if (!CM.TheLoop ->contains (RdxExitInstr))
8723
+ return std::nullopt;
8724
+
8722
8725
// TODO: Allow scaling reductions when predicating. The select at
8723
8726
// the end of the loop chooses between the phi value and most recent
8724
8727
// reduction result, both of which have different VFs to the active lane
8725
8728
// mask when scaling.
8726
- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8729
+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
8727
8730
return std::nullopt;
8728
8731
8729
- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8732
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8730
8733
if (!Update)
8731
8734
return std::nullopt;
8732
8735
8733
8736
Value *Op = Update->getOperand (0 );
8734
8737
Value *PhiOp = Update->getOperand (1 );
8735
- if (Op == PHI) {
8736
- Op = Update->getOperand (1 );
8737
- PhiOp = Update->getOperand (0 );
8738
+ if (Op == PHI)
8739
+ std::swap (Op, PhiOp);
8740
+
8741
+ SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
8742
+
8743
+ // Try and get a scaled reduction from the first non-phi operand.
8744
+ // If one is found, we use the discovered reduction instruction in
8745
+ // place of the accumulator for costing.
8746
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747
+ if (auto SR0 = getScaledReduction (PHI, OpInst, Range)) {
8748
+ Chains.append (*SR0);
8749
+ PHI = SR0->rbegin ()->first .Reduction ;
8750
+
8751
+ Op = Update->getOperand (0 );
8752
+ PhiOp = Update->getOperand (1 );
8753
+ if (Op == PHI)
8754
+ std::swap (Op, PhiOp);
8755
+ }
8738
8756
}
8739
8757
if (PhiOp != PHI)
8740
8758
return std::nullopt;
@@ -8757,7 +8775,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8757
8775
TTI::PartialReductionExtendKind OpBExtend =
8758
8776
TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8759
8777
8760
- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8778
+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
8761
8779
8762
8780
unsigned TargetScaleFactor =
8763
8781
PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8772,9 +8790,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8772
8790
return Cost.isValid ();
8773
8791
},
8774
8792
Range))
8775
- return std::make_pair (Chain, TargetScaleFactor);
8793
+ Chains. push_back ( std::make_pair (Chain, TargetScaleFactor) );
8776
8794
8777
- return std::nullopt ;
8795
+ return Chains ;
8778
8796
}
8779
8797
8780
8798
VPRecipeBase *
@@ -8869,12 +8887,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8869
8887
" Unexpected number of operands for partial reduction" );
8870
8888
8871
8889
VPValue *BinOp = Operands[0 ];
8872
- VPValue *Phi = Operands[1 ];
8873
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8874
- std::swap (BinOp, Phi);
8875
-
8876
- return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8877
- Reduction);
8890
+ VPValue *Accumulator = Operands[1 ];
8891
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
8892
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8893
+ isa<VPPartialReductionRecipe>(BinOpRecipe))
8894
+ std::swap (BinOp, Accumulator);
8895
+
8896
+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp,
8897
+ Accumulator, Reduction);
8878
8898
}
8879
8899
8880
8900
void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
0 commit comments