diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1ebc62f984390..568aeae2260f1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5818,6 +5818,15 @@ LoopVectorizationCostModel::getReductionPatternCost(
   if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
       RetI->user_back()->getOpcode() == Instruction::Add) {
     RetI = RetI->user_back();
+  } else if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
+             ((match(I, m_ZExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_ZExt(m_Value())))) ||
+              (match(I, m_SExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_SExt(m_Value()))))) &&
+             RetI->user_back()->user_back()->getOpcode() == Instruction::Add) {
+    // This looks through ext(mul(ext, ext)), making sure that the extensions
+    // are the same sign.
+    RetI = RetI->user_back()->user_back();
   }
 
   // Test if the found instruction is a reduction, and if not return an invalid
@@ -7316,7 +7325,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     // Also include the operands of instructions in the chain, as the cost-model
     // may mark extends as free.
     //
-    // For ARM, some of the instruction can folded into the reducion
+    // For ARM, some of the instructions can be folded into the reduction
     // instruction. So we need to mark all folded instructions free.
     // For example: We can fold reduce(mul(ext(A), ext(B))) into one
    // instruction.
@@ -7324,6 +7333,10 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
       for (Value *Op : ChainOp->operands()) {
         if (auto *I = dyn_cast<Instruction>(Op)) {
           ChainOpsAndOperands.insert(I);
+          if (IsZExtOrSExt(I->getOpcode())) {
+            ChainOpsAndOperands.insert(I);
+            I = dyn_cast<Instruction>(I->getOperand(0));
+          }
           if (I->getOpcode() == Instruction::Mul) {
             auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
             auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index c115c91cff896..a4f96adccb64b 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1722,10 +1722,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1733,26 +1733,26 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <8 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <8 x i32> [[TMP11]] to <8 x i64>
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP12]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
 ; CHECK:       middle.block:
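
For reference, a minimal scalar loop with the shape the new code matches is sketched below (illustrative only; the function and value names are invented and do not come from the patch or its tests). The first hunk lets getReductionPatternCost look through reductions of the form add(ext(mul(ext(A), ext(B)))) when both inner extends have the same signedness, and the precomputeCosts hunk then marks the folded extends as free; with the whole chain costed as one reduction, test_fir_q15 above now vectorizes at VF 8 instead of VF 4.

; Illustrative sketch only; names are invented, not taken from the patch.
define i64 @sketch_same_sign_mla(ptr %x, ptr %y, i32 %n) {
entry:
  br label %loop

loop:                                             ; add(sext(mul(sext, sext)))
  %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
  %sum = phi i64 [ 0, %entry ], [ %sum.next, %loop ]
  %gx = getelementptr inbounds i16, ptr %x, i32 %iv
  %gy = getelementptr inbounds i16, ptr %y, i32 %iv
  %xv = load i16, ptr %gx, align 2
  %yv = load i16, ptr %gy, align 2
  %xe = sext i16 %xv to i32                       ; inner extends, same sign
  %ye = sext i16 %yv to i32
  %mul = mul nsw i32 %ye, %xe                     ; mul(ext, ext)
  %mule = sext i32 %mul to i64                    ; outer ext, looked through
  %sum.next = add i64 %sum, %mule                 ; reduction add
  %iv.next = add nuw i32 %iv, 1
  %done = icmp eq i32 %iv.next, %n
  br i1 %done, label %exit, label %loop

exit:
  ret i64 %sum.next
}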