Skip to content

[CVP][LVI] Add support for vectors #97428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions llvm/lib/Analysis/LazyValueInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ LazyValueInfoImpl::solveBlockValueImpl(Value *Val, BasicBlock *BB) {
if (PT && isKnownNonZero(BBI, DL))
return ValueLatticeElement::getNot(ConstantPointerNull::get(PT));

if (BBI->getType()->isIntegerTy()) {
if (BBI->getType()->isIntOrIntVectorTy()) {
if (auto *CI = dyn_cast<CastInst>(BBI))
return solveBlockValueCast(CI, BB);

Expand Down Expand Up @@ -836,6 +836,24 @@ void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
}
}

/// Build a ConstantRange that covers every lane of the fixed-width vector
/// constant \p C, whose type is \p Ty.
///
/// Poison lanes are skipped and contribute nothing to the result, so a vector
/// made up solely of poison lanes yields the empty range. If a lane cannot be
/// extracted, or is anything other than a ConstantInt (and not poison), the
/// full range of the scalar bit width is returned.
static ConstantRange getConstantRangeFromVector(Constant *C,
                                                FixedVectorType *Ty) {
  const unsigned BitWidth = Ty->getScalarSizeInBits();
  ConstantRange Result = ConstantRange::getEmpty(BitWidth);
  for (unsigned Idx = 0, NumElts = Ty->getNumElements(); Idx != NumElts;
       ++Idx) {
    Constant *Lane = C->getAggregateElement(Idx);
    // A poison lane places no constraint on the accumulated range.
    if (Lane && isa<PoisonValue>(Lane))
      continue;
    // A missing lane or a non-integer constant: give up and be conservative.
    auto *LaneInt = dyn_cast_or_null<ConstantInt>(Lane);
    if (!LaneInt)
      return ConstantRange::getFull(BitWidth);
    Result = Result.unionWith(LaneInt->getValue());
  }
  return Result;
}

static ConstantRange toConstantRange(const ValueLatticeElement &Val,
Type *Ty, bool UndefAllowed = false) {
assert(Ty->isIntOrIntVectorTy() && "Must be integer type");
Expand All @@ -844,6 +862,9 @@ static ConstantRange toConstantRange(const ValueLatticeElement &Val,
unsigned BW = Ty->getScalarSizeInBits();
if (Val.isUnknown())
return ConstantRange::getEmpty(BW);
if (Val.isConstant())
if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
return getConstantRangeFromVector(Val.getConstant(), VTy);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (Val.isConstant())
if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
return getConstantRangeFromVector(Val.getConstant(), VTy);
if (Val.isConstant()) {
if (auto *CI = dyn_cast_or_null<ConstantInt>(Val.getConstant()->getSplatValue(/*AllowPoison=*/true)))
return ConstantRange(CI->getValue());
if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
return getConstantRangeFromVector(Val.getConstant(), VTy);
}

Do you have a plan to support constant scalable vector splats?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I've added scalable vectors support.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about just using computeConstantRange to do it?
AFAICT the only case they handle differently is poison elements.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to use computeConstantRange() here, but extracting a common helper would make sense as a followup, as part of making computeConstantRange() handle poison elements.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, looks like you beat me to it.

return ConstantRange::getFull(BW);
}

Expand Down Expand Up @@ -968,7 +989,7 @@ LazyValueInfoImpl::solveBlockValueCast(CastInst *CI, BasicBlock *BB) {
return std::nullopt;
const ConstantRange &LHSRange = *LHSRes;

const unsigned ResultBitWidth = CI->getType()->getIntegerBitWidth();
const unsigned ResultBitWidth = CI->getType()->getScalarSizeInBits();

// NOTE: We're currently limited by the set of operations that ConstantRange
// can evaluate symbolically. Enhancing that set will allow us to analyze
Expand Down Expand Up @@ -1108,7 +1129,7 @@ LazyValueInfoImpl::getValueFromSimpleICmpCondition(CmpInst::Predicate Pred,
const APInt &Offset,
Instruction *CxtI,
bool UseBlockValue) {
ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
ConstantRange RHSRange(RHS->getType()->getScalarSizeInBits(),
/*isFullSet=*/true);
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
RHSRange = ConstantRange(CI->getValue());
Expand Down Expand Up @@ -1728,7 +1749,6 @@ Constant *LazyValueInfo::getConstant(Value *V, Instruction *CxtI) {

ConstantRange LazyValueInfo::getConstantRange(Value *V, Instruction *CxtI,
bool UndefAllowed) {
assert(V->getType()->isIntegerTy());
BasicBlock *BB = CxtI->getParent();
ValueLatticeElement Result =
getOrCreateImpl(BB->getModule()).getValueInBlock(V, BB, CxtI);
Expand Down
53 changes: 10 additions & 43 deletions llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,8 @@ static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT,
}

static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) {
// Only for signed relational comparisons of scalar integers.
if (Cmp->getType()->isVectorTy() ||
!Cmp->getOperand(0)->getType()->isIntegerTy())
// Only for signed relational comparisons of integers.
if (!Cmp->getOperand(0)->getType()->isIntOrIntVectorTy())
return false;

if (!Cmp->isSigned())
Expand Down Expand Up @@ -505,12 +504,8 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI);
// because it is negation-invariant.
static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) {
Value *X = II->getArgOperand(0);
Type *Ty = X->getType();
if (!Ty->isIntegerTy())
return false;

bool IsIntMinPoison = cast<ConstantInt>(II->getArgOperand(1))->isOne();
APInt IntMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
APInt IntMin = APInt::getSignedMinValue(X->getType()->getScalarSizeInBits());
ConstantRange Range = LVI->getConstantRangeAtUse(
II->getOperandUse(0), /*UndefAllowed*/ IsIntMinPoison);

Expand Down Expand Up @@ -679,15 +674,13 @@ static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) {
}

if (auto *WO = dyn_cast<WithOverflowInst>(&CB)) {
if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) {
if (willNotOverflow(WO, LVI))
return processOverflowIntrinsic(WO, LVI);
}
}

if (auto *SI = dyn_cast<SaturatingInst>(&CB)) {
if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) {
if (willNotOverflow(SI, LVI))
return processSaturatingInst(SI, LVI);
}
}

bool Changed = false;
Expand Down Expand Up @@ -761,11 +754,10 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR,
const ConstantRange &RCR) {
assert(Instr->getOpcode() == Instruction::SDiv ||
Instr->getOpcode() == Instruction::SRem);
assert(!Instr->getType()->isVectorTy());

// Find the smallest power of two bitwidth that's sufficient to hold Instr's
// operands.
unsigned OrigWidth = Instr->getType()->getIntegerBitWidth();
unsigned OrigWidth = Instr->getType()->getScalarSizeInBits();

// What is the smallest bit width that can accommodate the entire value ranges
// of both of the operands?
Expand All @@ -788,7 +780,7 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR,

++NumSDivSRemsNarrowed;
IRBuilder<> B{Instr};
auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth);
auto *TruncTy = Instr->getType()->getWithNewBitWidth(NewWidth);
auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy,
Instr->getName() + ".lhs.trunc");
auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy,
Expand All @@ -809,7 +801,6 @@ static bool expandUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
Type *Ty = Instr->getType();
assert(Instr->getOpcode() == Instruction::UDiv ||
Instr->getOpcode() == Instruction::URem);
assert(!Ty->isVectorTy());
bool IsRem = Instr->getOpcode() == Instruction::URem;

Value *X = Instr->getOperand(0);
Expand Down Expand Up @@ -892,7 +883,6 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
const ConstantRange &YCR) {
assert(Instr->getOpcode() == Instruction::UDiv ||
Instr->getOpcode() == Instruction::URem);
assert(!Instr->getType()->isVectorTy());

// Find the smallest power of two bitwidth that's sufficient to hold Instr's
// operands.
Expand All @@ -905,12 +895,12 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,

// NewWidth might be greater than OrigWidth if OrigWidth is not a power of
// two.
if (NewWidth >= Instr->getType()->getIntegerBitWidth())
if (NewWidth >= Instr->getType()->getScalarSizeInBits())
return false;

++NumUDivURemsNarrowed;
IRBuilder<> B{Instr};
auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth);
auto *TruncTy = Instr->getType()->getWithNewBitWidth(NewWidth);
auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy,
Instr->getName() + ".lhs.trunc");
auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy,
Expand All @@ -929,9 +919,6 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) {
assert(Instr->getOpcode() == Instruction::UDiv ||
Instr->getOpcode() == Instruction::URem);
if (Instr->getType()->isVectorTy())
return false;

ConstantRange XCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(0),
/*UndefAllowed*/ false);
// Allow undef for RHS, as we can assume it is division by zero UB.
Expand All @@ -946,7 +933,6 @@ static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) {
static bool processSRem(BinaryOperator *SDI, const ConstantRange &LCR,
const ConstantRange &RCR, LazyValueInfo *LVI) {
assert(SDI->getOpcode() == Instruction::SRem);
assert(!SDI->getType()->isVectorTy());

if (LCR.abs().icmp(CmpInst::ICMP_ULT, RCR.abs())) {
SDI->replaceAllUsesWith(SDI->getOperand(0));
Expand Down Expand Up @@ -1006,7 +992,6 @@ static bool processSRem(BinaryOperator *SDI, const ConstantRange &LCR,
static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR,
const ConstantRange &RCR, LazyValueInfo *LVI) {
assert(SDI->getOpcode() == Instruction::SDiv);
assert(!SDI->getType()->isVectorTy());

// Check whether the division folds to a constant.
ConstantRange DivCR = LCR.sdiv(RCR);
Expand Down Expand Up @@ -1064,9 +1049,6 @@ static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR,
static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) {
assert(Instr->getOpcode() == Instruction::SDiv ||
Instr->getOpcode() == Instruction::SRem);
if (Instr->getType()->isVectorTy())
return false;

ConstantRange LCR =
LVI->getConstantRangeAtUse(Instr->getOperandUse(0), /*AllowUndef*/ false);
// Allow undef for RHS, as we can assume it is division by zero UB.
Expand All @@ -1085,12 +1067,9 @@ static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) {
}

static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
if (SDI->getType()->isVectorTy())
return false;

ConstantRange LRange =
LVI->getConstantRangeAtUse(SDI->getOperandUse(0), /*UndefAllowed*/ false);
unsigned OrigWidth = SDI->getType()->getIntegerBitWidth();
unsigned OrigWidth = SDI->getType()->getScalarSizeInBits();
ConstantRange NegOneOrZero =
ConstantRange(APInt(OrigWidth, (uint64_t)-1, true), APInt(OrigWidth, 1));
if (NegOneOrZero.contains(LRange)) {
Expand All @@ -1117,9 +1096,6 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
}

static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
if (SDI->getType()->isVectorTy())
return false;

const Use &Base = SDI->getOperandUse(0);
if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false)
.isAllNonNegative())
Expand All @@ -1138,9 +1114,6 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
}

static bool processPossibleNonNeg(PossiblyNonNegInst *I, LazyValueInfo *LVI) {
if (I->getType()->isVectorTy())
return false;

if (I->hasNonNeg())
return false;

Expand All @@ -1164,9 +1137,6 @@ static bool processUIToFP(UIToFPInst *UIToFP, LazyValueInfo *LVI) {
}

static bool processSIToFP(SIToFPInst *SIToFP, LazyValueInfo *LVI) {
if (SIToFP->getType()->isVectorTy())
return false;

const Use &Base = SIToFP->getOperandUse(0);
if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false)
.isAllNonNegative())
Expand All @@ -1187,9 +1157,6 @@ static bool processSIToFP(SIToFPInst *SIToFP, LazyValueInfo *LVI) {
static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
using OBO = OverflowingBinaryOperator;

if (BinOp->getType()->isVectorTy())
return false;

bool NSW = BinOp->hasNoSignedWrap();
bool NUW = BinOp->hasNoUnsignedWrap();
if (NSW && NUW)
Expand Down
4 changes: 1 addition & 3 deletions llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1246,13 +1246,11 @@ define i1 @non_const_range_minmax(i8 %a, i8 %b) {
ret i1 %cmp1
}

; FIXME: Also support vectors.
define <2 x i1> @non_const_range_minmax_vec(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @non_const_range_minmax_vec(
; CHECK-NEXT: [[A2:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> <i8 10, i8 10>)
; CHECK-NEXT: [[B2:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[B:%.*]], <2 x i8> <i8 11, i8 11>)
; CHECK-NEXT: [[CMP1:%.*]] = icmp ult <2 x i8> [[A2]], [[B2]]
; CHECK-NEXT: ret <2 x i1> [[CMP1]]
; CHECK-NEXT: ret <2 x i1> <i1 true, i1 true>
;
%a2 = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %a, <2 x i8> <i8 10, i8 10>)
%b2 = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %b, <2 x i8> <i8 11, i8 11>)
Expand Down
Loading
Loading