Skip to content

Commit a52cf9a

Browse files
committed
[NVPTX] Preserve v16i8 vector loads when legalizing
This is done by lowering v16i8 loads into LoadV4 operations with i32 results instead of letting ReplaceLoadVector split it into smaller loads during legalization. This is done at dag-combine1 time, so that vector operations with i8 elements can be optimised away instead of being needlessly split during legalization, which involves storing to the stack and loading it back.
1 parent b15b846 commit a52cf9a

File tree

2 files changed

+166
-2
lines changed

2 files changed

+166
-2
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,8 +693,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
693693
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
694694

695695
// We have some custom DAG combine patterns for these nodes
696-
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
697-
ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT,
696+
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
697+
ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
698698
ISD::VSELECT});
699699

700700
// setcc for f16x2 and bf16x2 needs special handling to prevent
@@ -5471,6 +5471,45 @@ static SDValue PerformVSELECTCombine(SDNode *N,
54715471
return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
54725472
}
54735473

5474+
static SDValue PerformLOADCombine(SDNode *N,
5475+
TargetLowering::DAGCombinerInfo &DCI) {
5476+
SelectionDAG &DAG = DCI.DAG;
5477+
LoadSDNode *LD = cast<LoadSDNode>(N);
5478+
5479+
// Lower a v16i8 load into a LoadV4 operation with i32 results instead of
5480+
// letting ReplaceLoadVector split it into smaller loads during legalization.
5481+
// This is done at dag-combine1 time, so that vector operations with i8
5482+
// elements can be optimised away instead of being needlessly split during
5483+
// legalization, which involves storing to the stack and loading it back.
5484+
EVT VT = N->getValueType(0);
5485+
if (VT != MVT::v16i8)
5486+
return SDValue();
5487+
5488+
SDLoc DL(N);
5489+
5490+
// Create a v4i32 vector load operation, effectively <4 x v4i8>.
5491+
unsigned Opc = NVPTXISD::LoadV4;
5492+
EVT NewVT = MVT::v4i32;
5493+
EVT EltVT = NewVT.getVectorElementType();
5494+
unsigned NumElts = NewVT.getVectorNumElements();
5495+
EVT RetVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other};
5496+
SDVTList RetVTList = DAG.getVTList(RetVTs);
5497+
SmallVector<SDValue, 8> Ops(N->ops());
5498+
Ops.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
5499+
SDValue NewLoad = DAG.getMemIntrinsicNode(Opc, DL, RetVTList, Ops, NewVT,
5500+
LD->getMemOperand());
5501+
SDValue NewChain = NewLoad.getValue(NumElts);
5502+
5503+
// Create a vector of the same type returned by the original load.
5504+
SmallVector<SDValue, 4> Elts;
5505+
for (unsigned i = 0; i < NumElts; i++)
5506+
Elts.push_back(NewLoad.getValue(i));
5507+
return DCI.DAG.getMergeValues(
5508+
{DCI.DAG.getBitcast(VT, DCI.DAG.getBuildVector(NewVT, DL, Elts)),
5509+
NewChain},
5510+
DL);
5511+
}
5512+
54745513
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
54755514
DAGCombinerInfo &DCI) const {
54765515
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5490,6 +5529,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
54905529
return PerformREMCombine(N, DCI, OptLevel);
54915530
case ISD::SETCC:
54925531
return PerformSETCCCombine(N, DCI);
5532+
case ISD::LOAD:
5533+
return PerformLOADCombine(N, DCI);
54935534
case NVPTXISD::StoreRetval:
54945535
case NVPTXISD::StoreRetvalV2:
54955536
case NVPTXISD::StoreRetvalV4:

llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,126 @@ define float @ff(ptr %p) {
5252
%sum = fadd float %sum3, %v4
5353
ret float %sum
5454
}
55+
56+
define void @combine_v16i8(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
57+
; ENABLED-LABEL: combine_v16i8
58+
; ENABLED: ld.v4.u32
59+
%val0 = load i8, ptr %ptr1, align 16
60+
%ptr1.1 = getelementptr inbounds i8, ptr %ptr1, i64 1
61+
%val1 = load i8, ptr %ptr1.1, align 1
62+
%ptr1.2 = getelementptr inbounds i8, ptr %ptr1, i64 2
63+
%val2 = load i8, ptr %ptr1.2, align 2
64+
%ptr1.3 = getelementptr inbounds i8, ptr %ptr1, i64 3
65+
%val3 = load i8, ptr %ptr1.3, align 1
66+
%ptr1.4 = getelementptr inbounds i8, ptr %ptr1, i64 4
67+
%val4 = load i8, ptr %ptr1.4, align 4
68+
%ptr1.5 = getelementptr inbounds i8, ptr %ptr1, i64 5
69+
%val5 = load i8, ptr %ptr1.5, align 1
70+
%ptr1.6 = getelementptr inbounds i8, ptr %ptr1, i64 6
71+
%val6 = load i8, ptr %ptr1.6, align 2
72+
%ptr1.7 = getelementptr inbounds i8, ptr %ptr1, i64 7
73+
%val7 = load i8, ptr %ptr1.7, align 1
74+
%ptr1.8 = getelementptr inbounds i8, ptr %ptr1, i64 8
75+
%val8 = load i8, ptr %ptr1.8, align 8
76+
%ptr1.9 = getelementptr inbounds i8, ptr %ptr1, i64 9
77+
%val9 = load i8, ptr %ptr1.9, align 1
78+
%ptr1.10 = getelementptr inbounds i8, ptr %ptr1, i64 10
79+
%val10 = load i8, ptr %ptr1.10, align 2
80+
%ptr1.11 = getelementptr inbounds i8, ptr %ptr1, i64 11
81+
%val11 = load i8, ptr %ptr1.11, align 1
82+
%ptr1.12 = getelementptr inbounds i8, ptr %ptr1, i64 12
83+
%val12 = load i8, ptr %ptr1.12, align 4
84+
%ptr1.13 = getelementptr inbounds i8, ptr %ptr1, i64 13
85+
%val13 = load i8, ptr %ptr1.13, align 1
86+
%ptr1.14 = getelementptr inbounds i8, ptr %ptr1, i64 14
87+
%val14 = load i8, ptr %ptr1.14, align 2
88+
%ptr1.15 = getelementptr inbounds i8, ptr %ptr1, i64 15
89+
%val15 = load i8, ptr %ptr1.15, align 1
90+
%lane0 = zext i8 %val0 to i32
91+
%lane1 = zext i8 %val1 to i32
92+
%lane2 = zext i8 %val2 to i32
93+
%lane3 = zext i8 %val3 to i32
94+
%lane4 = zext i8 %val4 to i32
95+
%lane5 = zext i8 %val5 to i32
96+
%lane6 = zext i8 %val6 to i32
97+
%lane7 = zext i8 %val7 to i32
98+
%lane8 = zext i8 %val8 to i32
99+
%lane9 = zext i8 %val9 to i32
100+
%lane10 = zext i8 %val10 to i32
101+
%lane11 = zext i8 %val11 to i32
102+
%lane12 = zext i8 %val12 to i32
103+
%lane13 = zext i8 %val13 to i32
104+
%lane14 = zext i8 %val14 to i32
105+
%lane15 = zext i8 %val15 to i32
106+
%red.1 = add i32 %lane0, %lane1
107+
%red.2 = add i32 %red.1, %lane2
108+
%red.3 = add i32 %red.2, %lane3
109+
%red.4 = add i32 %red.3, %lane4
110+
%red.5 = add i32 %red.4, %lane5
111+
%red.6 = add i32 %red.5, %lane6
112+
%red.7 = add i32 %red.6, %lane7
113+
%red.8 = add i32 %red.7, %lane8
114+
%red.9 = add i32 %red.8, %lane9
115+
%red.10 = add i32 %red.9, %lane10
116+
%red.11 = add i32 %red.10, %lane11
117+
%red.12 = add i32 %red.11, %lane12
118+
%red.13 = add i32 %red.12, %lane13
119+
%red.14 = add i32 %red.13, %lane14
120+
%red = add i32 %red.14, %lane15
121+
store i32 %red, ptr %ptr2, align 4
122+
ret void
123+
}
124+
125+
define void @combine_v8i16(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
126+
; ENABLED-LABEL: combine_v8i16
127+
; ENABLED: ld.v4.b32
128+
%val0 = load i16, ptr %ptr1, align 16
129+
%ptr1.1 = getelementptr inbounds i16, ptr %ptr1, i64 1
130+
%val1 = load i16, ptr %ptr1.1, align 2
131+
%ptr1.2 = getelementptr inbounds i16, ptr %ptr1, i64 2
132+
%val2 = load i16, ptr %ptr1.2, align 4
133+
%ptr1.3 = getelementptr inbounds i16, ptr %ptr1, i64 3
134+
%val3 = load i16, ptr %ptr1.3, align 2
135+
%ptr1.4 = getelementptr inbounds i16, ptr %ptr1, i64 4
136+
%val4 = load i16, ptr %ptr1.4, align 4
137+
%ptr1.5 = getelementptr inbounds i16, ptr %ptr1, i64 5
138+
%val5 = load i16, ptr %ptr1.5, align 2
139+
%ptr1.6 = getelementptr inbounds i16, ptr %ptr1, i64 6
140+
%val6 = load i16, ptr %ptr1.6, align 4
141+
%ptr1.7 = getelementptr inbounds i16, ptr %ptr1, i64 7
142+
%val7 = load i16, ptr %ptr1.7, align 2
143+
%lane0 = zext i16 %val0 to i32
144+
%lane1 = zext i16 %val1 to i32
145+
%lane2 = zext i16 %val2 to i32
146+
%lane3 = zext i16 %val3 to i32
147+
%lane4 = zext i16 %val4 to i32
148+
%lane5 = zext i16 %val5 to i32
149+
%lane6 = zext i16 %val6 to i32
150+
%lane7 = zext i16 %val7 to i32
151+
%red.1 = add i32 %lane0, %lane1
152+
%red.2 = add i32 %red.1, %lane2
153+
%red.3 = add i32 %red.2, %lane3
154+
%red.4 = add i32 %red.3, %lane4
155+
%red.5 = add i32 %red.4, %lane5
156+
%red.6 = add i32 %red.5, %lane6
157+
%red = add i32 %red.6, %lane7
158+
store i32 %red, ptr %ptr2, align 4
159+
ret void
160+
}
161+
162+
define void @combine_v4i32(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
163+
; ENABLED-LABEL: combine_v4i32
164+
; ENABLED: ld.v4.u32
165+
%val0 = load i32, ptr %ptr1, align 16
166+
%ptr1.1 = getelementptr inbounds i32, ptr %ptr1, i64 1
167+
%val1 = load i32, ptr %ptr1.1, align 4
168+
%ptr1.2 = getelementptr inbounds i32, ptr %ptr1, i64 2
169+
%val2 = load i32, ptr %ptr1.2, align 8
170+
%ptr1.3 = getelementptr inbounds i32, ptr %ptr1, i64 3
171+
%val3 = load i32, ptr %ptr1.3, align 4
172+
%red.1 = add i32 %val0, %val1
173+
%red.2 = add i32 %red.1, %val2
174+
%red = add i32 %red.2, %val3
175+
store i32 %red, ptr %ptr2, align 4
176+
ret void
177+
}

0 commit comments

Comments
 (0)