-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[LoopVectorizer] Add support for partial reductions #92418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b2c34e9
50ea501
b67f1b7
78948b5
03245db
70b1b7f
708f6fe
1cbb030
72ddb66
d005cde
7c24f91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -211,6 +211,12 @@ typedef TargetTransformInfo TTI; | |
/// for IR-level transformations. | ||
class TargetTransformInfo { | ||
public: | ||
enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend }; | ||
|
||
/// Get the kind of extension that an instruction represents. | ||
static PartialReductionExtendKind | ||
getPartialReductionExtendKind(Instruction *I); | ||
|
||
/// Construct a TTI object using a type implementing the \c Concept | ||
/// API below. | ||
/// | ||
|
@@ -1274,6 +1280,18 @@ class TargetTransformInfo { | |
/// \return if target want to issue a prefetch in address space \p AS. | ||
bool shouldPrefetchAddressSpace(unsigned AS) const; | ||
|
||
/// \return The cost of a partial reduction, which is a reduction from a | ||
/// vector to another vector with fewer elements of larger size. They are | ||
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which | ||
/// takes an accumulator and a binary operation operand that itself is fed by | ||
/// two extends. An example of an operation that uses a partial reduction is a | ||
/// dot product, which reduces a vector to another of 4 times fewer elements. | ||
InstructionCost | ||
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType, | ||
ElementCount VF, PartialReductionExtendKind OpAExtend, | ||
PartialReductionExtendKind OpBExtend, | ||
std::optional<unsigned> BinOp = std::nullopt) const; | ||
|
||
/// \return The maximum interleave factor that any transform should try to | ||
/// perform for this target. This number depends on the level of parallelism | ||
/// and the number of execution units in the CPU. | ||
|
@@ -2098,6 +2116,18 @@ class TargetTransformInfo::Concept { | |
/// \return if target want to issue a prefetch in address space \p AS. | ||
virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0; | ||
|
||
/// \return The cost of a partial reduction, which is a reduction from a | ||
/// vector to another vector with fewer elements of larger size. They are | ||
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which | ||
/// takes an accumulator and a binary operation operand that itself is fed by | ||
/// two extends. An example of an operation that uses a partial reduction is a | ||
/// dot product, which reduces a vector to another of 4 times fewer elements. | ||
virtual InstructionCost | ||
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType, | ||
Comment on lines
+2125
to
+2126
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please document, this should include the definition of what partial reduction means in this context (possibly tying to the definition of the intrinsic?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
ElementCount VF, PartialReductionExtendKind OpAExtend, | ||
PartialReductionExtendKind OpBExtend, | ||
std::optional<unsigned> BinOp) const = 0; | ||
|
||
virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0; | ||
virtual InstructionCost getArithmeticInstrCost( | ||
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind, | ||
|
@@ -2772,6 +2802,15 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { | |
return Impl.shouldPrefetchAddressSpace(AS); | ||
} | ||
|
||
InstructionCost getPartialReductionCost( | ||
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF, | ||
PartialReductionExtendKind OpAExtend, | ||
PartialReductionExtendKind OpBExtend, | ||
std::optional<unsigned> BinOp = std::nullopt) const override { | ||
return Impl.getPartialReductionCost(Opcode, InputType, AccumType, VF, | ||
OpAExtend, OpBExtend, BinOp); | ||
} | ||
|
||
unsigned getMaxInterleaveFactor(ElementCount VF) override { | ||
return Impl.getMaxInterleaveFactor(VF); | ||
} | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7606,6 +7606,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan, | |||||
} | ||||||
continue; | ||||||
} | ||||||
// The VPlan-based cost model is more accurate for partial reduction and | ||||||
// comparing against the legacy cost isn't desirable. | ||||||
if (isa<VPPartialReductionRecipe>(&R)) | ||||||
return true; | ||||||
if (Instruction *UI = GetInstructionForCost(&R)) | ||||||
SeenInstrs.insert(UI); | ||||||
} | ||||||
|
@@ -8828,6 +8832,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I, | |||||
return Recipe; | ||||||
} | ||||||
|
||||||
/// Find all possible partial reductions in the loop and track all of those that | ||||||
/// are valid so recipes can be formed later. | ||||||
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) { | ||||||
// Find all possible partial reductions. | ||||||
SmallVector<std::pair<PartialReductionChain, unsigned>, 1> | ||||||
PartialReductionChains; | ||||||
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) | ||||||
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair = | ||||||
getScaledReduction(Phi, RdxDesc, Range)) | ||||||
PartialReductionChains.push_back(*Pair); | ||||||
|
||||||
// A partial reduction is invalid if any of its extends are used by | ||||||
// something that isn't another partial reduction. This is because the | ||||||
// extends are intended to be lowered along with the reduction itself. | ||||||
Comment on lines
+8846
to
+8848
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a fundamental limitation? Could we just keep the extend for the other users? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've looked into what the assembly looks like if we remove the restriction, and it looks pretty bad so I think removing the restriction should come in a future patch with additional work on the cost model. |
||||||
|
||||||
// Build up a set of partial reduction bin ops for efficient use checking. | ||||||
SmallSet<User *, 4> PartialReductionBinOps; | ||||||
for (const auto &[PartialRdx, _] : PartialReductionChains) | ||||||
PartialReductionBinOps.insert(PartialRdx.BinOp); | ||||||
|
||||||
auto ExtendIsOnlyUsedByPartialReductions = | ||||||
[&PartialReductionBinOps](Instruction *Extend) { | ||||||
return all_of(Extend->users(), [&](const User *U) { | ||||||
return PartialReductionBinOps.contains(U); | ||||||
}); | ||||||
}; | ||||||
|
||||||
// Check if each use of a chain's two extends is a partial reduction | ||||||
// and only add those that don't have non-partial reduction users. | ||||||
for (auto Pair : PartialReductionChains) { | ||||||
PartialReductionChain Chain = Pair.first; | ||||||
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) && | ||||||
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)) | ||||||
ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair)); | ||||||
} | ||||||
} | ||||||
|
||||||
std::optional<std::pair<PartialReductionChain, unsigned>> | ||||||
VPRecipeBuilder::getScaledReduction(PHINode *PHI, | ||||||
const RecurrenceDescriptor &Rdx, | ||||||
VFRange &Range) { | ||||||
// TODO: Allow scaling reductions when predicating. The select at | ||||||
// the end of the loop chooses between the phi value and most recent | ||||||
// reduction result, both of which have different VFs to the active lane | ||||||
// mask when scaling. | ||||||
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent())) | ||||||
return std::nullopt; | ||||||
|
||||||
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr()); | ||||||
if (!Update) | ||||||
return std::nullopt; | ||||||
|
||||||
Value *Op = Update->getOperand(0); | ||||||
if (Op == PHI) | ||||||
Op = Update->getOperand(1); | ||||||
|
||||||
auto *BinOp = dyn_cast<BinaryOperator>(Op); | ||||||
if (!BinOp || !BinOp->hasOneUse()) | ||||||
return std::nullopt; | ||||||
|
||||||
using namespace llvm::PatternMatch; | ||||||
Value *A, *B; | ||||||
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) || | ||||||
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B)))) | ||||||
return std::nullopt; | ||||||
|
||||||
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0)); | ||||||
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1)); | ||||||
|
||||||
// Check that the extends extend from the same type. | ||||||
if (A->getType() != B->getType()) | ||||||
return std::nullopt; | ||||||
|
||||||
TTI::PartialReductionExtendKind OpAExtend = | ||||||
TargetTransformInfo::getPartialReductionExtendKind(ExtA); | ||||||
TTI::PartialReductionExtendKind OpBExtend = | ||||||
TargetTransformInfo::getPartialReductionExtendKind(ExtB); | ||||||
|
||||||
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp); | ||||||
|
||||||
unsigned TargetScaleFactor = | ||||||
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor( | ||||||
A->getType()->getPrimitiveSizeInBits()); | ||||||
|
||||||
if (LoopVectorizationPlanner::getDecisionAndClampRange( | ||||||
[&](ElementCount VF) { | ||||||
InstructionCost Cost = TTI->getPartialReductionCost( | ||||||
Update->getOpcode(), A->getType(), PHI->getType(), VF, | ||||||
OpAExtend, OpBExtend, std::make_optional(BinOp->getOpcode())); | ||||||
return Cost.isValid(); | ||||||
}, | ||||||
Range)) | ||||||
return std::make_pair(Chain, TargetScaleFactor); | ||||||
|
||||||
return std::nullopt; | ||||||
} | ||||||
|
||||||
VPRecipeBase * | ||||||
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, | ||||||
ArrayRef<VPValue *> Operands, | ||||||
|
@@ -8852,9 +8953,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, | |||||
Legal->getReductionVars().find(Phi)->second; | ||||||
assert(RdxDesc.getRecurrenceStartValue() == | ||||||
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader())); | ||||||
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV, | ||||||
CM.isInLoopReduction(Phi), | ||||||
CM.useOrderedReductions(RdxDesc)); | ||||||
|
||||||
// If the PHI is used by a partial reduction, set the scale factor. | ||||||
std::optional<std::pair<PartialReductionChain, unsigned>> Pair = | ||||||
getScaledReductionForInstr(RdxDesc.getLoopExitInstr()); | ||||||
unsigned ScaleFactor = Pair ? Pair->second : 1; | ||||||
PhiRecipe = new VPReductionPHIRecipe( | ||||||
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi), | ||||||
CM.useOrderedReductions(RdxDesc), ScaleFactor); | ||||||
} else { | ||||||
// TODO: Currently fixed-order recurrences are modeled as chains of | ||||||
// first-order recurrences. If there are no users of the intermediate | ||||||
|
@@ -8886,6 +8992,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, | |||||
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr)) | ||||||
return tryToWidenMemory(Instr, Operands, Range); | ||||||
|
||||||
if (getScaledReductionForInstr(Instr)) | ||||||
return tryToCreatePartialReduction(Instr, Operands); | ||||||
|
||||||
if (!shouldWiden(Instr, Range)) | ||||||
return nullptr; | ||||||
|
||||||
|
@@ -8906,6 +9015,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, | |||||
return tryToWiden(Instr, Operands, VPBB); | ||||||
} | ||||||
|
||||||
VPRecipeBase * | ||||||
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction, | ||||||
ArrayRef<VPValue *> Operands) { | ||||||
assert(Operands.size() == 2 && | ||||||
"Unexpected number of operands for partial reduction"); | ||||||
|
||||||
VPValue *BinOp = Operands[0]; | ||||||
VPValue *Phi = Operands[1]; | ||||||
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe())) | ||||||
std::swap(BinOp, Phi); | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is already an assertion in the recipe constructor, since that is the common place for initialising the recipes. |
||||||
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi, | ||||||
Reduction); | ||||||
} | ||||||
|
||||||
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, | ||||||
ElementCount MaxVF) { | ||||||
assert(OrigLoop->isInnermost() && "Inner loop expected."); | ||||||
|
@@ -9223,7 +9347,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { | |||||
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None; | ||||||
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL); | ||||||
|
||||||
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder); | ||||||
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE, | ||||||
Builder); | ||||||
|
||||||
// --------------------------------------------------------------------------- | ||||||
// Pre-construction: record ingredients whose recipes we'll need to further | ||||||
|
@@ -9269,6 +9394,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { | |||||
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty(); | ||||||
return Legal->blockNeedsPredication(BB) || NeedsBlends; | ||||||
}); | ||||||
|
||||||
RecipeBuilder.collectScaledReductions(Range); | ||||||
|
||||||
auto *MiddleVPBB = Plan->getMiddleBlock(); | ||||||
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi(); | ||||||
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also add comment from below here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.