Skip to content

Commit 5ffbdd9

Browse files
committed
[RISCV] Handle .vx pseudos in hasAllNBitUsers (#67419)
Vector pseudos with scalar operands only use the lower SEW bits (or less in the case of shifts and clips). This patch accounts for this in hasAllNBitUsers for both SDNodes in RISCVISelDAGToDAG. We also need to handle this in RISCVOptWInstrs otherwise we introduce slliw instructions that are less compressible than their original slli counterpart. This is a reland of aff6ffc with the refactoring omitted.
1 parent b4a8999 commit 5ffbdd9

11 files changed

+1030
-794
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2753,6 +2753,148 @@ bool RISCVDAGToDAGISel::selectSHXADD_UWOp(SDValue N, unsigned ShAmt,
27532753
return false;
27542754
}
27552755

2756+
static bool vectorPseudoHasAllNBitUsers(SDNode *User, unsigned UserOpNo,
2757+
unsigned Bits,
2758+
const TargetInstrInfo *TII) {
2759+
const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
2760+
RISCVVPseudosTable::getPseudoInfo(User->getMachineOpcode());
2761+
2762+
if (!PseudoInfo)
2763+
return false;
2764+
2765+
const MCInstrDesc &MCID = TII->get(User->getMachineOpcode());
2766+
const uint64_t TSFlags = MCID.TSFlags;
2767+
if (!RISCVII::hasSEWOp(TSFlags))
2768+
return false;
2769+
assert(RISCVII::hasVLOp(TSFlags));
2770+
2771+
bool HasGlueOp = User->getGluedNode() != nullptr;
2772+
unsigned ChainOpIdx = User->getNumOperands() - HasGlueOp - 1;
2773+
bool HasChainOp = User->getOperand(ChainOpIdx).getValueType() == MVT::Other;
2774+
bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(TSFlags);
2775+
unsigned VLIdx =
2776+
User->getNumOperands() - HasVecPolicyOp - HasChainOp - HasGlueOp - 2;
2777+
const unsigned Log2SEW = User->getConstantOperandVal(VLIdx + 1);
2778+
2779+
if (UserOpNo == VLIdx)
2780+
return false;
2781+
2782+
// TODO: Handle Zvbb instructions
2783+
switch (PseudoInfo->BaseInstr) {
2784+
default:
2785+
return false;
2786+
2787+
// 11.6. Vector Single-Width Shift Instructions
2788+
case RISCV::VSLL_VX:
2789+
case RISCV::VSRL_VX:
2790+
case RISCV::VSRA_VX:
2791+
// 12.4. Vector Single-Width Scaling Shift Instructions
2792+
case RISCV::VSSRL_VX:
2793+
case RISCV::VSSRA_VX:
2794+
// Only the low lg2(SEW) bits of the shift-amount value are used.
2795+
if (Bits < Log2SEW)
2796+
return false;
2797+
break;
2798+
2799+
// 11.7 Vector Narrowing Integer Right Shift Instructions
2800+
case RISCV::VNSRL_WX:
2801+
case RISCV::VNSRA_WX:
2802+
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
2803+
case RISCV::VNCLIPU_WX:
2804+
case RISCV::VNCLIP_WX:
2805+
// Only the low lg2(2*SEW) bits of the shift-amount value are used.
2806+
if (Bits < Log2SEW + 1)
2807+
return false;
2808+
break;
2809+
2810+
// 11.1. Vector Single-Width Integer Add and Subtract
2811+
case RISCV::VADD_VX:
2812+
case RISCV::VSUB_VX:
2813+
case RISCV::VRSUB_VX:
2814+
// 11.2. Vector Widening Integer Add/Subtract
2815+
case RISCV::VWADDU_VX:
2816+
case RISCV::VWSUBU_VX:
2817+
case RISCV::VWADD_VX:
2818+
case RISCV::VWSUB_VX:
2819+
case RISCV::VWADDU_WX:
2820+
case RISCV::VWSUBU_WX:
2821+
case RISCV::VWADD_WX:
2822+
case RISCV::VWSUB_WX:
2823+
// 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
2824+
case RISCV::VADC_VXM:
2825+
case RISCV::VADC_VIM:
2826+
case RISCV::VMADC_VXM:
2827+
case RISCV::VMADC_VIM:
2828+
case RISCV::VMADC_VX:
2829+
case RISCV::VSBC_VXM:
2830+
case RISCV::VMSBC_VXM:
2831+
case RISCV::VMSBC_VX:
2832+
// 11.5 Vector Bitwise Logical Instructions
2833+
case RISCV::VAND_VX:
2834+
case RISCV::VOR_VX:
2835+
case RISCV::VXOR_VX:
2836+
// 11.8. Vector Integer Compare Instructions
2837+
case RISCV::VMSEQ_VX:
2838+
case RISCV::VMSNE_VX:
2839+
case RISCV::VMSLTU_VX:
2840+
case RISCV::VMSLT_VX:
2841+
case RISCV::VMSLEU_VX:
2842+
case RISCV::VMSLE_VX:
2843+
case RISCV::VMSGTU_VX:
2844+
case RISCV::VMSGT_VX:
2845+
// 11.9. Vector Integer Min/Max Instructions
2846+
case RISCV::VMINU_VX:
2847+
case RISCV::VMIN_VX:
2848+
case RISCV::VMAXU_VX:
2849+
case RISCV::VMAX_VX:
2850+
// 11.10. Vector Single-Width Integer Multiply Instructions
2851+
case RISCV::VMUL_VX:
2852+
case RISCV::VMULH_VX:
2853+
case RISCV::VMULHU_VX:
2854+
case RISCV::VMULHSU_VX:
2855+
// 11.11. Vector Integer Divide Instructions
2856+
case RISCV::VDIVU_VX:
2857+
case RISCV::VDIV_VX:
2858+
case RISCV::VREMU_VX:
2859+
case RISCV::VREM_VX:
2860+
// 11.12. Vector Widening Integer Multiply Instructions
2861+
case RISCV::VWMUL_VX:
2862+
case RISCV::VWMULU_VX:
2863+
case RISCV::VWMULSU_VX:
2864+
// 11.13. Vector Single-Width Integer Multiply-Add Instructions
2865+
case RISCV::VMACC_VX:
2866+
case RISCV::VNMSAC_VX:
2867+
case RISCV::VMADD_VX:
2868+
case RISCV::VNMSUB_VX:
2869+
// 11.14. Vector Widening Integer Multiply-Add Instructions
2870+
case RISCV::VWMACCU_VX:
2871+
case RISCV::VWMACC_VX:
2872+
case RISCV::VWMACCSU_VX:
2873+
case RISCV::VWMACCUS_VX:
2874+
// 11.15. Vector Integer Merge Instructions
2875+
case RISCV::VMERGE_VXM:
2876+
// 11.16. Vector Integer Move Instructions
2877+
case RISCV::VMV_V_X:
2878+
// 12.1. Vector Single-Width Saturating Add and Subtract
2879+
case RISCV::VSADDU_VX:
2880+
case RISCV::VSADD_VX:
2881+
case RISCV::VSSUBU_VX:
2882+
case RISCV::VSSUB_VX:
2883+
// 12.2. Vector Single-Width Averaging Add and Subtract
2884+
case RISCV::VAADDU_VX:
2885+
case RISCV::VAADD_VX:
2886+
case RISCV::VASUBU_VX:
2887+
case RISCV::VASUB_VX:
2888+
// 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
2889+
case RISCV::VSMUL_VX:
2890+
// 16.1. Integer Scalar Move Instructions
2891+
case RISCV::VMV_S_X:
2892+
if (Bits < (1 << Log2SEW))
2893+
return false;
2894+
}
2895+
return true;
2896+
}
2897+
27562898
// Return true if all users of this SDNode* only consume the lower \p Bits.
27572899
// This can be used to form W instructions for add/sub/mul/shl even when the
27582900
// root isn't a sext_inreg. This can allow the ADDW/SUBW/MULW/SLLIW to CSE if
@@ -2784,6 +2926,8 @@ bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits,
27842926
// TODO: Add more opcodes?
27852927
switch (User->getMachineOpcode()) {
27862928
default:
2929+
if (vectorPseudoHasAllNBitUsers(User, UI.getOperandNo(), Bits, TII))
2930+
break;
27872931
return false;
27882932
case RISCV::ADDW:
27892933
case RISCV::ADDIW:

llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,141 @@ FunctionPass *llvm::createRISCVOptWInstrsPass() {
7878
return new RISCVOptWInstrs();
7979
}
8080

81+
static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
82+
unsigned Bits) {
83+
const MachineInstr &MI = *UserOp.getParent();
84+
const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
85+
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
86+
87+
if (!PseudoInfo)
88+
return false;
89+
90+
const MCInstrDesc &MCID = MI.getDesc();
91+
const uint64_t TSFlags = MI.getDesc().TSFlags;
92+
if (!RISCVII::hasSEWOp(TSFlags))
93+
return false;
94+
assert(RISCVII::hasVLOp(TSFlags));
95+
const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
96+
97+
if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
98+
return false;
99+
100+
// TODO: Handle Zvbb instructions
101+
switch (PseudoInfo->BaseInstr) {
102+
default:
103+
return false;
104+
105+
// 11.6. Vector Single-Width Shift Instructions
106+
case RISCV::VSLL_VX:
107+
case RISCV::VSRL_VX:
108+
case RISCV::VSRA_VX:
109+
// 12.4. Vector Single-Width Scaling Shift Instructions
110+
case RISCV::VSSRL_VX:
111+
case RISCV::VSSRA_VX:
112+
// Only the low lg2(SEW) bits of the shift-amount value are used.
113+
if (Bits < Log2SEW)
114+
return false;
115+
break;
116+
117+
// 11.7 Vector Narrowing Integer Right Shift Instructions
118+
case RISCV::VNSRL_WX:
119+
case RISCV::VNSRA_WX:
120+
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
121+
case RISCV::VNCLIPU_WX:
122+
case RISCV::VNCLIP_WX:
123+
// Only the low lg2(2*SEW) bits of the shift-amount value are used.
124+
if (Bits < Log2SEW + 1)
125+
return false;
126+
break;
127+
128+
// 11.1. Vector Single-Width Integer Add and Subtract
129+
case RISCV::VADD_VX:
130+
case RISCV::VSUB_VX:
131+
case RISCV::VRSUB_VX:
132+
// 11.2. Vector Widening Integer Add/Subtract
133+
case RISCV::VWADDU_VX:
134+
case RISCV::VWSUBU_VX:
135+
case RISCV::VWADD_VX:
136+
case RISCV::VWSUB_VX:
137+
case RISCV::VWADDU_WX:
138+
case RISCV::VWSUBU_WX:
139+
case RISCV::VWADD_WX:
140+
case RISCV::VWSUB_WX:
141+
// 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
142+
case RISCV::VADC_VXM:
143+
case RISCV::VADC_VIM:
144+
case RISCV::VMADC_VXM:
145+
case RISCV::VMADC_VIM:
146+
case RISCV::VMADC_VX:
147+
case RISCV::VSBC_VXM:
148+
case RISCV::VMSBC_VXM:
149+
case RISCV::VMSBC_VX:
150+
// 11.5 Vector Bitwise Logical Instructions
151+
case RISCV::VAND_VX:
152+
case RISCV::VOR_VX:
153+
case RISCV::VXOR_VX:
154+
// 11.8. Vector Integer Compare Instructions
155+
case RISCV::VMSEQ_VX:
156+
case RISCV::VMSNE_VX:
157+
case RISCV::VMSLTU_VX:
158+
case RISCV::VMSLT_VX:
159+
case RISCV::VMSLEU_VX:
160+
case RISCV::VMSLE_VX:
161+
case RISCV::VMSGTU_VX:
162+
case RISCV::VMSGT_VX:
163+
// 11.9. Vector Integer Min/Max Instructions
164+
case RISCV::VMINU_VX:
165+
case RISCV::VMIN_VX:
166+
case RISCV::VMAXU_VX:
167+
case RISCV::VMAX_VX:
168+
// 11.10. Vector Single-Width Integer Multiply Instructions
169+
case RISCV::VMUL_VX:
170+
case RISCV::VMULH_VX:
171+
case RISCV::VMULHU_VX:
172+
case RISCV::VMULHSU_VX:
173+
// 11.11. Vector Integer Divide Instructions
174+
case RISCV::VDIVU_VX:
175+
case RISCV::VDIV_VX:
176+
case RISCV::VREMU_VX:
177+
case RISCV::VREM_VX:
178+
// 11.12. Vector Widening Integer Multiply Instructions
179+
case RISCV::VWMUL_VX:
180+
case RISCV::VWMULU_VX:
181+
case RISCV::VWMULSU_VX:
182+
// 11.13. Vector Single-Width Integer Multiply-Add Instructions
183+
case RISCV::VMACC_VX:
184+
case RISCV::VNMSAC_VX:
185+
case RISCV::VMADD_VX:
186+
case RISCV::VNMSUB_VX:
187+
// 11.14. Vector Widening Integer Multiply-Add Instructions
188+
case RISCV::VWMACCU_VX:
189+
case RISCV::VWMACC_VX:
190+
case RISCV::VWMACCSU_VX:
191+
case RISCV::VWMACCUS_VX:
192+
// 11.15. Vector Integer Merge Instructions
193+
case RISCV::VMERGE_VXM:
194+
// 11.16. Vector Integer Move Instructions
195+
case RISCV::VMV_V_X:
196+
// 12.1. Vector Single-Width Saturating Add and Subtract
197+
case RISCV::VSADDU_VX:
198+
case RISCV::VSADD_VX:
199+
case RISCV::VSSUBU_VX:
200+
case RISCV::VSSUB_VX:
201+
// 12.2. Vector Single-Width Averaging Add and Subtract
202+
case RISCV::VAADDU_VX:
203+
case RISCV::VAADD_VX:
204+
case RISCV::VASUBU_VX:
205+
case RISCV::VASUB_VX:
206+
// 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
207+
case RISCV::VSMUL_VX:
208+
// 16.1. Integer Scalar Move Instructions
209+
case RISCV::VMV_S_X:
210+
if (Bits < (1 << Log2SEW))
211+
return false;
212+
}
213+
return true;
214+
}
215+
81216
// Checks if all users only demand the lower \p OrigBits of the original
82217
// instruction's result.
83218
// TODO: handle multiple interdependent transformations
@@ -108,6 +243,8 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI,
108243

109244
switch (UserMI->getOpcode()) {
110245
default:
246+
if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
247+
break;
111248
return false;
112249

113250
case RISCV::ADDIW:

llvm/test/CodeGen/RISCV/rvv/constant-folding.ll

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2-
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s \
3-
; RUN: | FileCheck %s --check-prefixes=CHECK,RV32
4-
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s \
5-
; RUN: | FileCheck %s --check-prefixes=CHECK,RV64
2+
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
3+
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
64

75
; These tests check that the scalable-vector version of this series of
86
; instructions does not get into an infinite DAGCombine loop. This was
@@ -14,26 +12,15 @@
1412
; a constant SPLAT_VECTOR didn't follow suit.
1513

1614
define <2 x i16> @fixedlen(<2 x i32> %x) {
17-
; RV32-LABEL: fixedlen:
18-
; RV32: # %bb.0:
19-
; RV32-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
20-
; RV32-NEXT: vsrl.vi v8, v8, 16
21-
; RV32-NEXT: lui a0, 1048568
22-
; RV32-NEXT: vand.vx v8, v8, a0
23-
; RV32-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
24-
; RV32-NEXT: vnsrl.wi v8, v8, 0
25-
; RV32-NEXT: ret
26-
;
27-
; RV64-LABEL: fixedlen:
28-
; RV64: # %bb.0:
29-
; RV64-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
30-
; RV64-NEXT: vsrl.vi v8, v8, 16
31-
; RV64-NEXT: lui a0, 131071
32-
; RV64-NEXT: slli a0, a0, 3
33-
; RV64-NEXT: vand.vx v8, v8, a0
34-
; RV64-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
35-
; RV64-NEXT: vnsrl.wi v8, v8, 0
36-
; RV64-NEXT: ret
15+
; CHECK-LABEL: fixedlen:
16+
; CHECK: # %bb.0:
17+
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
18+
; CHECK-NEXT: vsrl.vi v8, v8, 16
19+
; CHECK-NEXT: lui a0, 1048568
20+
; CHECK-NEXT: vand.vx v8, v8, a0
21+
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
22+
; CHECK-NEXT: vnsrl.wi v8, v8, 0
23+
; CHECK-NEXT: ret
3724
%v41 = insertelement <2 x i32> poison, i32 16, i32 0
3825
%v42 = shufflevector <2 x i32> %v41, <2 x i32> poison, <2 x i32> zeroinitializer
3926
%v43 = lshr <2 x i32> %x, %v42

0 commit comments

Comments
 (0)