@@ -2610,7 +2610,7 @@ static Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
26102610
26112611// / Get the stride of a pointer access in a loop. Looks for symbolic
26122612// / strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
2613- static Value *getStrideFromPointer (Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
2613+ static const SCEV *getStrideFromPointer (Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
26142614 auto *PtrTy = dyn_cast<PointerType>(Ptr->getType ());
26152615 if (!PtrTy || PtrTy->isAggregateType ())
26162616 return nullptr ;
@@ -2664,28 +2664,27 @@ static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
26642664 }
26652665 }
26662666
2667- // Strip off casts.
2668- Type *StripedOffRecurrenceCast = nullptr ;
2669- if (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(V)) {
2670- StripedOffRecurrenceCast = C->getType ();
2671- V = C->getOperand ();
2672- }
2667+ // Note that the restrictions after this loop-invariant check are only
2668+ // profitability restrictions.
2669+ if (!SE->isLoopInvariant (V, Lp))
2670+ return nullptr ;
26732671
26742672 // Look for the loop invariant symbolic value.
26752673 const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
2676- if (!U)
2677- return nullptr ;
2674+ if (!U) {
2675+ const auto *C = dyn_cast<SCEVIntegralCastExpr>(V);
2676+ if (!C)
2677+ return nullptr ;
2678+ U = dyn_cast<SCEVUnknown>(C->getOperand ());
2679+ if (!U)
2680+ return nullptr ;
26782681
2679- Value *Stride = U->getValue ();
2680- if (!Lp->isLoopInvariant (Stride))
2681- return nullptr ;
2682-
2683- // If we have stripped off the recurrence cast we have to make sure that we
2684- // return the value that is used in this loop so that we can replace it later.
2685- if (StripedOffRecurrenceCast)
2686- Stride = getUniqueCastUse (Stride, Lp, StripedOffRecurrenceCast);
2682+ // Match legacy behavior - this is not needed for correctness
2683+ if (!getUniqueCastUse (U->getValue (), Lp, V->getType ()))
2684+ return nullptr ;
2685+ }
26872686
2688- return Stride ;
2687+ return V ;
26892688}
26902689
26912690void LoopAccessInfo::collectStridedAccess (Value *MemAccess) {
@@ -2699,13 +2698,13 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
26992698 // computation of an interesting IV - but we chose not to as we
27002699 // don't have a cost model here, and broadening the scope exposes
27012700 // far too many unprofitable cases.
2702- Value *Stride = getStrideFromPointer (Ptr, PSE->getSE (), TheLoop);
2703- if (!Stride )
2701+ const SCEV *StrideExpr = getStrideFromPointer (Ptr, PSE->getSE (), TheLoop);
2702+ if (!StrideExpr )
27042703 return ;
27052704
27062705 LLVM_DEBUG (dbgs () << " LAA: Found a strided access that is a candidate for "
27072706 " versioning:" );
2708- LLVM_DEBUG (dbgs () << " Ptr: " << *Ptr << " Stride: " << *Stride << " \n " );
2707+ LLVM_DEBUG (dbgs () << " Ptr: " << *Ptr << " Stride: " << *StrideExpr << " \n " );
27092708
27102709 if (!SpeculateUnitStride) {
27112710 LLVM_DEBUG (dbgs () << " Chose not to due to -laa-speculate-unit-stride\n " );
@@ -2725,7 +2724,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
27252724 // of various possible stride specializations, considering the alternatives
27262725 // of using gather/scatters (if available).
27272726
2728- const SCEV *StrideExpr = PSE->getSCEV (Stride);
27292727 const SCEV *BETakenCount = PSE->getBackedgeTakenCount ();
27302728
27312729 // Match the types so we can compare the stride and the BETakenCount.
@@ -2756,8 +2754,10 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
27562754
27572755 // Strip back off the integer cast, and check that our result is a
27582756 // SCEVUnknown as we expect.
2759- Value *StrideVal = stripIntegerCast (Stride);
2760- SymbolicStrides[Ptr] = cast<SCEVUnknown>(PSE->getSCEV (StrideVal));
2757+ const SCEV *StrideBase = StrideExpr;
2758+ if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
2759+ StrideBase = C->getOperand ();
2760+ SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
27612761}
27622762
27632763LoopAccessInfo::LoopAccessInfo (Loop *L, ScalarEvolution *SE,
0 commit comments