Skip to content

Commit 160c237

Browse files
committed
[VectorCombine][X86] foldShuffleOfCastops - fold shuffle(cast(x),cast(y)) -> cast(shuffle(x,y)) iff cost efficient
Based off the existing foldShuffleOfBinops fold. Fixes #67803.
1 parent 4d8a3f5 commit 160c237

File tree

3 files changed

+86
-37
lines changed

3 files changed

+86
-37
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

+59
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class VectorCombine {
112112
bool foldSingleElementStore(Instruction &I);
113113
bool scalarizeLoadExtract(Instruction &I);
114114
bool foldShuffleOfBinops(Instruction &I);
115+
bool foldShuffleOfCastops(Instruction &I);
115116
bool foldShuffleFromReductions(Instruction &I);
116117
bool foldTruncFromReductions(Instruction &I);
117118
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
@@ -1432,6 +1433,63 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
14321433
return true;
14331434
}
14341435

1436+
/// Try to convert "shuffle (castop), (castop)" with a shared castop operand into
1437+
/// "castop (shuffle)".
1438+
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
1439+
Value *V0, *V1;
1440+
ArrayRef<int> Mask;
1441+
if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
1442+
m_Mask(Mask))))
1443+
return false;
1444+
1445+
auto *C0 = dyn_cast<CastInst>(V0);
1446+
auto *C1 = dyn_cast<CastInst>(V1);
1447+
if (!C0 || !C1)
1448+
return false;
1449+
1450+
// TODO: Handle shuffle(zext_nneg(x), sext(y)) folds.
1451+
Instruction::CastOps Opcode = C0->getOpcode();
1452+
if (Opcode == Instruction::BitCast || Opcode != C1->getOpcode() ||
1453+
C0->getSrcTy() != C1->getSrcTy())
1454+
return false;
1455+
1456+
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1457+
auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
1458+
auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
1459+
if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
1460+
return false;
1461+
assert(CastDstTy->getElementCount() == CastSrcTy->getElementCount() &&
1462+
"Unexpected src/dst element counts");
1463+
1464+
auto *NewShuffleDstTy =
1465+
FixedVectorType::get(CastSrcTy->getScalarType(), Mask.size());
1466+
1467+
// Try to replace a castop with a shuffle if the shuffle is not costly.
1468+
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1469+
1470+
InstructionCost OldCost =
1471+
TTI.getCastInstrCost(Opcode, CastDstTy, CastSrcTy,
1472+
TTI::CastContextHint::None, CostKind) +
1473+
TTI.getCastInstrCost(Opcode, CastDstTy, CastSrcTy,
1474+
TTI::CastContextHint::None, CostKind);
1475+
OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
1476+
CastDstTy, Mask, CostKind);
1477+
1478+
InstructionCost NewCost = TTI.getShuffleCost(
1479+
TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, Mask, CostKind);
1480+
NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
1481+
TTI::CastContextHint::None, CostKind);
1482+
if (NewCost > OldCost)
1483+
return false;
1484+
1485+
Value *Shuf =
1486+
Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0), Mask);
1487+
Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
1488+
// TODO Copy IR flags?
1489+
replaceValue(I, *Cast);
1490+
return true;
1491+
}
1492+
14351493
/// Given a commutative reduction, the order of the input lanes does not alter
14361494
/// the results. We can use this to remove certain shuffles feeding the
14371495
/// reduction, removing the need to shuffle at all.
@@ -1986,6 +2044,7 @@ bool VectorCombine::run() {
19862044
break;
19872045
case Instruction::ShuffleVector:
19882046
MadeChange |= foldShuffleOfBinops(I);
2047+
MadeChange |= foldShuffleOfCastops(I);
19892048
MadeChange |= foldSelectShuffle(I);
19902049
break;
19912050
case Instruction::BitCast:

llvm/test/Transforms/PhaseOrdering/X86/pr67803.ll

+1-5
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@ define <4 x i64> @PR67803(<4 x i64> %x, <4 x i64> %y, <4 x i64> %a, <4 x i64> %b
99
; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i64> [[X:%.*]] to <8 x i32>
1010
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i64> [[Y:%.*]] to <8 x i32>
1111
; CHECK-NEXT: [[TMP2:%.*]] = icmp sgt <8 x i32> [[TMP0]], [[TMP1]]
12-
; CHECK-NEXT: [[CMP_I21:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
13-
; CHECK-NEXT: [[SEXT_I22:%.*]] = sext <4 x i1> [[CMP_I21]] to <4 x i32>
14-
; CHECK-NEXT: [[CMP_I:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
15-
; CHECK-NEXT: [[SEXT_I:%.*]] = sext <4 x i1> [[CMP_I]] to <4 x i32>
16-
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x i32> [[SEXT_I22]], <4 x i32> [[SEXT_I]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
12+
; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i32>
1713
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <4 x i64> [[A:%.*]] to <32 x i8>
1814
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <32 x i8> [[TMP5]], <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
1915
; CHECK-NEXT: [[TMP7:%.*]] = bitcast <4 x i64> [[B:%.*]] to <32 x i8>

llvm/test/Transforms/VectorCombine/X86/shuffle-of-casts.ll

+26-32
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66

77
define <16 x i32> @concat_zext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
88
; CHECK-LABEL: @concat_zext_v8i16_v16i32(
9-
; CHECK-NEXT: [[X0:%.*]] = zext <8 x i16> [[A0:%.*]] to <8 x i32>
10-
; CHECK-NEXT: [[X1:%.*]] = zext <8 x i16> [[A1:%.*]] to <8 x i32>
11-
; CHECK-NEXT: [[R:%.*]] = shufflevector <8 x i32> [[X0]], <8 x i32> [[X1]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
9+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i16> [[A0:%.*]], <8 x i16> [[A1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
10+
; CHECK-NEXT: [[R:%.*]] = zext <16 x i16> [[TMP1]] to <16 x i32>
1211
; CHECK-NEXT: ret <16 x i32> [[R]]
1312
;
1413
%x0 = zext <8 x i16> %a0 to <8 x i32>
@@ -19,9 +18,8 @@ define <16 x i32> @concat_zext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
1918

2019
define <16 x i32> @concat_sext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
2120
; CHECK-LABEL: @concat_sext_v8i16_v16i32(
22-
; CHECK-NEXT: [[X0:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
23-
; CHECK-NEXT: [[X1:%.*]] = sext <8 x i16> [[A1:%.*]] to <8 x i32>
24-
; CHECK-NEXT: [[R:%.*]] = shufflevector <8 x i32> [[X0]], <8 x i32> [[X1]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
21+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i16> [[A0:%.*]], <8 x i16> [[A1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
22+
; CHECK-NEXT: [[R:%.*]] = sext <16 x i16> [[TMP1]] to <16 x i32>
2523
; CHECK-NEXT: ret <16 x i32> [[R]]
2624
;
2725
%x0 = sext <8 x i16> %a0 to <8 x i32>
@@ -32,9 +30,8 @@ define <16 x i32> @concat_sext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
3230

3331
define <8 x i32> @concat_sext_v4i1_v8i32(<4 x i1> %a0, <4 x i1> %a1) {
3432
; CHECK-LABEL: @concat_sext_v4i1_v8i32(
35-
; CHECK-NEXT: [[X0:%.*]] = sext <4 x i1> [[A0:%.*]] to <4 x i32>
36-
; CHECK-NEXT: [[X1:%.*]] = sext <4 x i1> [[A1:%.*]] to <4 x i32>
37-
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x i32> [[X0]], <4 x i32> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
33+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i1> [[A0:%.*]], <4 x i1> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
34+
; CHECK-NEXT: [[R:%.*]] = sext <8 x i1> [[TMP1]] to <8 x i32>
3835
; CHECK-NEXT: ret <8 x i32> [[R]]
3936
;
4037
%x0 = sext <4 x i1> %a0 to <4 x i32>
@@ -45,9 +42,8 @@ define <8 x i32> @concat_sext_v4i1_v8i32(<4 x i1> %a0, <4 x i1> %a1) {
4542

4643
define <8 x i16> @concat_trunc_v4i32_v8i16(<4 x i32> %a0, <4 x i32> %a1) {
4744
; CHECK-LABEL: @concat_trunc_v4i32_v8i16(
48-
; CHECK-NEXT: [[X0:%.*]] = trunc <4 x i32> [[A0:%.*]] to <4 x i16>
49-
; CHECK-NEXT: [[X1:%.*]] = trunc <4 x i32> [[A1:%.*]] to <4 x i16>
50-
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x i16> [[X0]], <4 x i16> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
45+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[A0:%.*]], <4 x i32> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
46+
; CHECK-NEXT: [[R:%.*]] = trunc <8 x i32> [[TMP1]] to <8 x i16>
5147
; CHECK-NEXT: ret <8 x i16> [[R]]
5248
;
5349
%x0 = trunc <4 x i32> %a0 to <4 x i16>
@@ -58,9 +54,8 @@ define <8 x i16> @concat_trunc_v4i32_v8i16(<4 x i32> %a0, <4 x i32> %a1) {
5854

5955
define <8 x ptr> @concat_inttoptr_v4i32_v8iptr(<4 x i32> %a0, <4 x i32> %a1) {
6056
; CHECK-LABEL: @concat_inttoptr_v4i32_v8iptr(
61-
; CHECK-NEXT: [[X0:%.*]] = inttoptr <4 x i32> [[A0:%.*]] to <4 x ptr>
62-
; CHECK-NEXT: [[X1:%.*]] = inttoptr <4 x i32> [[A1:%.*]] to <4 x ptr>
63-
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x ptr> [[X0]], <4 x ptr> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
57+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[A0:%.*]], <4 x i32> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
58+
; CHECK-NEXT: [[R:%.*]] = inttoptr <8 x i32> [[TMP1]] to <8 x ptr>
6459
; CHECK-NEXT: ret <8 x ptr> [[R]]
6560
;
6661
%x0 = inttoptr <4 x i32> %a0 to <4 x ptr>
@@ -71,9 +66,8 @@ define <8 x ptr> @concat_inttoptr_v4i32_v8iptr(<4 x i32> %a0, <4 x i32> %a1) {
7166

7267
define <16 x i64> @concat_ptrtoint_v8i16_v16i32(<8 x ptr> %a0, <8 x ptr> %a1) {
7368
; CHECK-LABEL: @concat_ptrtoint_v8i16_v16i32(
74-
; CHECK-NEXT: [[X0:%.*]] = ptrtoint <8 x ptr> [[A0:%.*]] to <8 x i64>
75-
; CHECK-NEXT: [[X1:%.*]] = ptrtoint <8 x ptr> [[A1:%.*]] to <8 x i64>
76-
; CHECK-NEXT: [[R:%.*]] = shufflevector <8 x i64> [[X0]], <8 x i64> [[X1]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
69+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x ptr> [[A0:%.*]], <8 x ptr> [[A1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
70+
; CHECK-NEXT: [[R:%.*]] = ptrtoint <16 x ptr> [[TMP1]] to <16 x i64>
7771
; CHECK-NEXT: ret <16 x i64> [[R]]
7872
;
7973
%x0 = ptrtoint <8 x ptr> %a0 to <8 x i64>
@@ -83,11 +77,16 @@ define <16 x i64> @concat_ptrtoint_v8i16_v16i32(<8 x ptr> %a0, <8 x ptr> %a1) {
8377
}
8478

8579
define <8 x double> @concat_fpext_v4f32_v8f64(<4 x float> %a0, <4 x float> %a1) {
86-
; CHECK-LABEL: @concat_fpext_v4f32_v8f64(
87-
; CHECK-NEXT: [[X0:%.*]] = fpext <4 x float> [[A0:%.*]] to <4 x double>
88-
; CHECK-NEXT: [[X1:%.*]] = fpext <4 x float> [[A1:%.*]] to <4 x double>
89-
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x double> [[X0]], <4 x double> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
90-
; CHECK-NEXT: ret <8 x double> [[R]]
80+
; SSE-LABEL: @concat_fpext_v4f32_v8f64(
81+
; SSE-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
82+
; SSE-NEXT: [[R:%.*]] = fpext <8 x float> [[TMP1]] to <8 x double>
83+
; SSE-NEXT: ret <8 x double> [[R]]
84+
;
85+
; AVX-LABEL: @concat_fpext_v4f32_v8f64(
86+
; AVX-NEXT: [[X0:%.*]] = fpext <4 x float> [[A0:%.*]] to <4 x double>
87+
; AVX-NEXT: [[X1:%.*]] = fpext <4 x float> [[A1:%.*]] to <4 x double>
88+
; AVX-NEXT: [[R:%.*]] = shufflevector <4 x double> [[X0]], <4 x double> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
89+
; AVX-NEXT: ret <8 x double> [[R]]
9190
;
9291
%x0 = fpext <4 x float> %a0 to <4 x double>
9392
%x1 = fpext <4 x float> %a1 to <4 x double>
@@ -112,9 +111,8 @@ define <16 x float> @concat_fptrunc_v8f64_v16f32(<8 x double> %a0, <8 x double>
112111

113112
define <16 x i32> @rconcat_sext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
114113
; CHECK-LABEL: @rconcat_sext_v8i16_v16i32(
115-
; CHECK-NEXT: [[X0:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
116-
; CHECK-NEXT: [[X1:%.*]] = sext <8 x i16> [[A1:%.*]] to <8 x i32>
117-
; CHECK-NEXT: [[R:%.*]] = shufflevector <8 x i32> [[X0]], <8 x i32> [[X1]], <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
114+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i16> [[A0:%.*]], <8 x i16> [[A1:%.*]], <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
115+
; CHECK-NEXT: [[R:%.*]] = sext <16 x i16> [[TMP1]] to <16 x i32>
118116
; CHECK-NEXT: ret <16 x i32> [[R]]
119117
;
120118
%x0 = sext <8 x i16> %a0 to <8 x i32>
@@ -127,9 +125,8 @@ define <16 x i32> @rconcat_sext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
127125

128126
define <8 x double> @interleave_fpext_v4f32_v8f64(<4 x float> %a0, <4 x float> %a1) {
129127
; CHECK-LABEL: @interleave_fpext_v4f32_v8f64(
130-
; CHECK-NEXT: [[X0:%.*]] = fpext <4 x float> [[A0:%.*]] to <4 x double>
131-
; CHECK-NEXT: [[X1:%.*]] = fpext <4 x float> [[A1:%.*]] to <4 x double>
132-
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x double> [[X0]], <4 x double> [[X1]], <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
128+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
129+
; CHECK-NEXT: [[R:%.*]] = fpext <8 x float> [[TMP1]] to <8 x double>
133130
; CHECK-NEXT: ret <8 x double> [[R]]
134131
;
135132
%x0 = fpext <4 x float> %a0 to <4 x double>
@@ -184,6 +181,3 @@ define <16 x i32> @concat_sext_zext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
184181
%r = shufflevector <8 x i32> %x0, <8 x i32> %x1, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
185182
ret <16 x i32> %r
186183
}
187-
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
188-
; AVX: {{.*}}
189-
; SSE: {{.*}}

0 commit comments

Comments
 (0)