Skip to content

[NFC][TargetTransformInfo][VectorUtils] Consolidate isVectorIntrinsic... api #117635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 19, 2024
28 changes: 21 additions & 7 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -899,14 +899,20 @@ class TargetTransformInfo {

bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) const;

/// Identifies if the vector form of the intrinsic is overloaded on the type
/// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is
/// -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) const;
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) const;

/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx.
bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) const;

/// Estimate the overhead of scalarizing an instruction. Insert and Extract
/// are set if the demanded result elements need to be inserted and/or
Expand Down Expand Up @@ -2002,8 +2008,11 @@ class TargetTransformInfo::Concept {
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) = 0;
virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) = 0;
virtual bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) = 0;
virtual bool
isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) = 0;
virtual InstructionCost
getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract, TargetCostKind CostKind,
Expand Down Expand Up @@ -2580,9 +2589,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
}

bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) override {
return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) override {
return Impl.isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
}

bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) override {
return Impl.isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
}

InstructionCost getScalarizationOverhead(VectorType *Ty,
Expand Down
11 changes: 8 additions & 3 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,14 @@ class TargetTransformInfoImplBase {
return false;
}

bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) const {
return ScalarOpdIdx == -1;
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) const {
return OpdIdx == -1;
}

bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) const {
return RetIdx == 0;
}

InstructionCost getScalarizationOverhead(VectorType *Ty,
Expand Down
26 changes: 21 additions & 5 deletions llvm/include/llvm/Analysis/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,25 @@ inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
///
/// Note: isTriviallyVectorizable implies isTriviallyScalarizable.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identify if the intrinsic is trivially scalarizable.
/// This method returns true following the same predicates of
/// isTriviallyVectorizable.

/// Note: There are intrinsics where implementing vectorization for the
/// intrinsic is redundant, but we want to implement scalarization of the
/// vector. To prevent the requirement that an intrinsic also implements
/// vectorization we provide this seperate function.
bool isTriviallyScalarizable(Intrinsic::ID ID, const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx);
/// \p TTI is used to consider target specific intrinsics, if no target specific
/// intrinsics will be considered then it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
Expand All @@ -158,9 +172,11 @@ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx.
bool isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx);
/// overloaded at the struct element index \p RetIdx. /// \p TTI is used to
/// consider target specific intrinsics, if no target specific intrinsics
/// will be considered then it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithStructReturnOverloadAtField(
Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);

/// Returns intrinsic ID for call.
/// For the input call instruction it finds mapping intrinsic and returns
Expand Down
11 changes: 8 additions & 3 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return false;
}

bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) const {
return ScalarOpdIdx == -1;
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) const {
return OpdIdx == -1;
}

bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) const {
return RetIdx == 0;
}

/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3447,7 +3447,7 @@ static Constant *ConstantFoldFixedVectorCall(
// Gather a column of constants.
for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) {
// Some intrinsics use a scalar type for certain arguments.
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J)) {
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J, /*TTI=*/nullptr)) {
Lane[J] = Operands[J];
continue;
}
Expand Down
11 changes: 8 additions & 3 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,14 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
}

bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
Intrinsic::ID ID, int ScalarOpdIdx) const {
return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
bool TargetTransformInfo::isTargetIntrinsicWithOverloadTypeAtArg(
Intrinsic::ID ID, int OpdIdx) const {
return TTIImpl->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
}

bool TargetTransformInfo::isTargetIntrinsicWithStructReturnOverloadAtField(
Intrinsic::ID ID, int RetIdx) const {
return TTIImpl->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
}

InstructionCost TargetTransformInfo::getScalarizationOverhead(
Expand Down
34 changes: 30 additions & 4 deletions llvm/lib/Analysis/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,31 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
}
}

bool llvm::isTriviallyScalarizable(Intrinsic::ID ID,
const TargetTransformInfo *TTI) {
if (isTriviallyVectorizable(ID))
return true;

if (TTI && Intrinsic::isTargetIntrinsic(ID))
return TTI->isTargetIntrinsicTriviallyScalarizable(ID);

// TODO: Move frexp to isTriviallyVectorizable.
// https://github.com/llvm/llvm-project/issues/112408
switch (ID) {
case Intrinsic::frexp:
return true;
}
return false;
}

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) {
unsigned ScalarOpdIdx,
const TargetTransformInfo *TTI) {

if (TTI && Intrinsic::isTargetIntrinsic(ID))
return TTI->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);

switch (ID) {
case Intrinsic::abs:
case Intrinsic::vp_abs:
Expand All @@ -142,7 +164,7 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");

if (TTI && Intrinsic::isTargetIntrinsic(ID))
return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
return TTI->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);

if (VPCastIntrinsic::isVPCast(ID))
return OpdIdx == -1 || OpdIdx == 0;
Expand All @@ -167,8 +189,12 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
}
}

bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) {
bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(
Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI) {

if (TTI && Intrinsic::isTargetIntrinsic(ID))
return TTI->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);

switch (ID) {
case Intrinsic::frexp:
return RetIdx == 0 || RetIdx == 1;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/ReplaceWithVeclib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
auto *ArgTy = Arg.value()->getType();
bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
/*TTI=*/nullptr);
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index(), /*TTI=*/nullptr)) {
ScalarArgTypes.push_back(ArgTy);
if (IsOloadTy)
OloadTys.push_back(ArgTy);
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
}
}

bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) {
bool DirectXTTIImpl::isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) {
switch (ID) {
case Intrinsic::dx_asdouble:
return ScalarOpdIdx == 0;
return OpdIdx == 0;
default:
return ScalarOpdIdx == -1;
return OpdIdx == -1;
}
}

Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx);
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx);
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
};
} // namespace llvm

Expand Down
24 changes: 4 additions & 20 deletions llvm/lib/Transforms/Scalar/Scalarizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,6 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {

bool visit(Function &F);

bool isTriviallyScalarizable(Intrinsic::ID ID);

// InstVisitor methods. They return true if the instruction was scalarized,
// false if nothing changed.
bool visitInstruction(Instruction &I) { return false; }
Expand Down Expand Up @@ -683,19 +681,6 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
return true;
}

bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
if (isTriviallyVectorizable(ID))
return true;
// TODO: Move frexp to isTriviallyVectorizable.
// https://github.com/llvm/llvm-project/issues/112408
switch (ID) {
case Intrinsic::frexp:
return true;
}
return Intrinsic::isTargetIntrinsic(ID) &&
TTI->isTargetIntrinsicTriviallyScalarizable(ID);
}

/// If a call to a vector typed intrinsic function, split into a scalar call per
/// element if possible for the intrinsic.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
Expand All @@ -715,7 +700,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {

Intrinsic::ID ID = F->getIntrinsicID();

if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
return false;

// unsigned NumElems = VT->getNumElements();
Expand Down Expand Up @@ -743,7 +728,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
// will only scalarize when the struct elements have the same bitness.
if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
return false;
if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I))
if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
Tys.push_back(CurrVS->SplitTy);
}
}
Expand Down Expand Up @@ -794,8 +779,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
Tys[0] = VS->RemainderTy;

for (unsigned J = 0; J != NumArgs; ++J) {
if (isVectorIntrinsicWithScalarOpAtArg(ID, J) ||
TTI->isTargetIntrinsicWithScalarOpAtArg(ID, J)) {
if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
ScalarCallOps.push_back(ScalarOperands[J]);
} else {
ScalarCallOps.push_back(Scattered[J][I]);
Expand Down Expand Up @@ -1089,7 +1073,7 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
if (!F)
return false;
Intrinsic::ID ID = F->getIntrinsicID();
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
return false;
// Note: Fall through means Operand is a`CallInst` and it is defined in
// `isTriviallyScalarizable`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
auto *SE = PSE.getSE();
Intrinsic::ID IntrinID = getVectorIntrinsicIDForCall(CI, TLI);
for (unsigned Idx = 0; Idx < CI->arg_size(); ++Idx)
if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx)) {
if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx, TTI)) {
if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(Idx)),
TheLoop)) {
reportVectorizationFailure("Found unvectorizable intrinsic",
Expand Down
Loading
Loading