Skip to content

[LV][SVE] Recognize potential DOT sequences and use a wider VF #69587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,10 @@ class TargetTransformInfo {
/// Return true if the target supports masked expand load.
bool isLegalMaskedExpandLoad(Type *DataType) const;

/// Return true if the types are legal for dot-product instructions on the
/// target (extend -> multiply -> accumulate).
/// \p DataType is the type of the values before they are extended and
/// \p ExtType is the type they are extended to before being multiplied.
bool isLegalDotProd(Type *DataType, Type *ExtType) const;

/// Return true if this is an alternating opcode pattern that can be lowered
/// to a single instruction on the target. In X86 this is for the addsub
/// instruction which corresponds to a Shuffle + Fadd + FSub pattern in IR.
Expand Down Expand Up @@ -1787,6 +1791,7 @@ class TargetTransformInfo::Concept {
Align Alignment) = 0;
virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
virtual bool isLegalDotProd(Type *DataType, Type *ExtType) = 0;
virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
unsigned Opcode1,
const SmallBitVector &OpcodeMask) const = 0;
Expand Down Expand Up @@ -2267,6 +2272,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
/// Forwards the masked-expand-load legality query to Impl and returns its
/// answer unchanged.
bool isLegalMaskedExpandLoad(Type *DataType) override {
return Impl.isLegalMaskedExpandLoad(DataType);
}
/// Forwards the dot-product legality query to Impl, passing both types
/// through untouched, and returns its answer unchanged.
bool isLegalDotProd(Type *DataType, Type *ExtType) override {
return Impl.isLegalDotProd(DataType, ExtType);
}
bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
const SmallBitVector &OpcodeMask) const override {
return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ class TargetTransformInfoImplBase {

/// Default answer: always false, regardless of \p DataType.
bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }

/// Default answer: always false, i.e. no (DataType, ExtType) pair is reported
/// as legal for a dot-product instruction.
bool isLegalDotProd(Type *DataType, Type *ExtType) const { return false; }

/// Default answer: always false.
bool enableOrderedReductions() const { return false; }

/// Default answer: always false, regardless of \p DataType and \p IsSigned.
bool hasDivRemOp(Type *DataType, bool IsSigned) const { return false; }
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
return TTIImpl->isLegalMaskedExpandLoad(DataType);
}

// Dot-product legality query: delegates to the implementation held in TTIImpl
// and returns its result unchanged.
bool TargetTransformInfo::isLegalDotProd(Type *DataType, Type *ExtType) const {
return TTIImpl->isLegalDotProd(DataType, ExtType);
}

// Ordered-reductions query: delegates to the implementation held in TTIImpl
// and returns its result unchanged.
bool TargetTransformInfo::enableOrderedReductions() const {
return TTIImpl->enableOrderedReductions();
}
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {

unsigned getMaxInterleaveFactor(ElementCount VF);

// Dot-product legality check; currently only answers true when SVE is
// available.
// TODO: NEON should be able to support this as well (after 8.3 or so?).
// The input element type has to be i8 or i16, and the extended type must be
// no wider than the accumulator of the dot-product instruction (32 bits for
// i8 inputs, 64 bits for i16 inputs) so that no data is lost.
bool isLegalDotProd(Type *DataType, Type *ExtType) const {
  // Without SVE nothing is reported as legal.
  if (!ST->hasSVE())
    return false;

  // i8 inputs: the extended type must fit in 32 bits.
  if (DataType->isIntegerTy(8))
    return ExtType->getPrimitiveSizeInBits() <= 32;

  // i16 inputs: the extended type must fit in 64 bits.
  if (DataType->isIntegerTy(16))
    return ExtType->getPrimitiveSizeInBits() <= 64;

  // Any other input type is rejected.
  return false;
}

bool prefersVectorizedAddressing() const;

InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
Expand Down
38 changes: 38 additions & 0 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
Expand Down Expand Up @@ -1921,6 +1922,9 @@ class LoopVectorizationCostModel {

/// All element types found in the loop.
SmallPtrSet<Type *, 16> ElementTypesInLoop;

/// Extends used as part of a dot-product chain; these are 'free'.
SmallPtrSet<Value *, 2> DotExtends;
};
} // end namespace llvm

Expand Down Expand Up @@ -5580,6 +5584,7 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() {
}

void LoopVectorizationCostModel::collectElementTypesForWidening() {
using namespace llvm::PatternMatch;
ElementTypesInLoop.clear();
// For each block.
for (BasicBlock *BB : TheLoop->blocks()) {
Expand Down Expand Up @@ -5607,6 +5612,34 @@ void LoopVectorizationCostModel::collectElementTypesForWidening() {
RdxDesc.getRecurrenceType(),
TargetTransformInfo::ReductionFlags()))
continue;
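// Prototype dot-product detection: an add reduction whose step is a
// single-use multiply of two extended values may map onto a dot-product
// instruction if the target reports the types as legal.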
if (RdxDesc.getRecurrenceKind() == RecurKind::Add) {
Instruction *Sum = RdxDesc.getLoopExitInstr();
Value *Accum = Legal->getReductionVars().find(PN)->first;

if (!Accum->hasOneUse() || !Sum->hasNUses(2))
continue;

Value *Step = (Sum->getOperand(0) == Accum) ? Sum->getOperand(1)
: Sum->getOperand(0);
Value *ValA = nullptr, *ValB = nullptr;

if (match(Step,
m_OneUse(m_Mul(m_ZExtOrSExt(m_OneUse(m_Value(ValA))),
m_ZExtOrSExt(m_OneUse(m_Value(ValB)))))) &&
(ValA->getType() == ValB->getType()) &&
TTI.isLegalDotProd(ValA->getType(), Step->getType())) {
Instruction *I = cast<Instruction>(Step);

// Make sure the extends are only used by the multiply.
if (I->getOperand(0)->hasOneUser() &&
I->getOperand(1)->hasOneUser()) {
DotExtends.insert(I->getOperand(0));
DotExtends.insert(I->getOperand(1));
continue;
}
}
}
T = RdxDesc.getRecurrenceType();
}

Expand Down Expand Up @@ -7351,6 +7384,11 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
CCH = ComputeCCH(Load);
}

// Extensions used in dot product calculations are 'free', since the
// dot instruction performs that operation internally before multiplying
if (DotExtends.contains(I))
return 0;

// We optimize the truncation of induction variables having constant
// integer steps. The cost of these truncations is the same as the scalar
// operation.
Expand Down
Loading