[NVPTX] Preserve v16i8 vector loads when legalizing #67322

Closed
wants to merge 1 commit into from
45 changes: 43 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -693,8 +693,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
  setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);

  // We have some custom DAG combine patterns for these nodes
-  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
-                       ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT,
+  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
+                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
                        ISD::VSELECT});

  // setcc for f16x2 and bf16x2 needs special handling to prevent
@@ -5471,6 +5471,45 @@ static SDValue PerformVSELECTCombine(SDNode *N,
  return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
}

static SDValue PerformLOADCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
  SelectionDAG &DAG = DCI.DAG;
  LoadSDNode *LD = cast<LoadSDNode>(N);

  // Lower a v16i8 load into a LoadV4 operation with i32 results instead of
  // letting ReplaceLoadVector split it into smaller loads during legalization.
  // This is done at dag-combine1 time, so that vector operations with i8
  // elements can be optimised away instead of being needlessly split during
  // legalization, which involves storing to the stack and loading it back.
  EVT VT = N->getValueType(0);
  if (VT != MVT::v16i8)
Member:

Do we want to generalize it to v8i8, too?

Contributor (Author):

I'm not sure if that would work as well as v16i8, since I don't think there is a ld.v2.b32 instruction we could use. It would mean having to create two NVPTXISD::LoadV* nodes here and duplicating some code from ReplaceLoadVector.

By the way, I have also tried to do this change in ReplaceLoadVector instead of adding a DAG combine for LOAD nodes. I backtracked as this was creating stack operations. I didn't check again after your recent commit was merged, but maybe that works better now.

Member:

> I don't think there is a ld.v2.b32 instruction we could use

V2 ld/st variants do exist:

def _v2_avar : NVPTXInst<

The code is easily parametrizable by NumElts. (A sketch of that generalization follows the end of this file's diff below.)

    return SDValue();

  SDLoc DL(N);

  // Create a v4i32 vector load operation, effectively <4 x v4i8>.
  unsigned Opc = NVPTXISD::LoadV4;
  EVT NewVT = MVT::v4i32;
  EVT EltVT = NewVT.getVectorElementType();
  unsigned NumElts = NewVT.getVectorNumElements();
  EVT RetVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other};
  SDVTList RetVTList = DAG.getVTList(RetVTs);
  SmallVector<SDValue, 8> Ops(N->ops());
  Ops.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
  SDValue NewLoad = DAG.getMemIntrinsicNode(Opc, DL, RetVTList, Ops, NewVT,
                                            LD->getMemOperand());
  SDValue NewChain = NewLoad.getValue(NumElts);

  // Create a vector of the same type returned by the original load.
  SmallVector<SDValue, 4> Elts;
  for (unsigned i = 0; i < NumElts; i++)
    Elts.push_back(NewLoad.getValue(i));
  return DCI.DAG.getMergeValues(
      {DCI.DAG.getBitcast(VT, DCI.DAG.getBuildVector(NewVT, DL, Elts)),
       NewChain},
      DL);
}

SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                               DAGCombinerInfo &DCI) const {
  CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5490,6 +5529,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
    return PerformREMCombine(N, DCI, OptLevel);
  case ISD::SETCC:
    return PerformSETCCCombine(N, DCI);
  case ISD::LOAD:
    return PerformLOADCombine(N, DCI);
  case NVPTXISD::StoreRetval:
  case NVPTXISD::StoreRetvalV2:
  case NVPTXISD::StoreRetvalV4:
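[Editorial note] To make the review discussion above concrete, here is a minimal sketch of how PerformLOADCombine might be parametrized by NumElts so that v8i8 loads are widened through NVPTXISD::LoadV2 (ld.v2.b32) the same way v16i8 loads are widened through NVPTXISD::LoadV4. This is not part of the patch: the function name PerformLOADCombineGeneralized and the v8i8 -> v2i32 mapping are illustrative assumptions only.

// Hypothetical sketch (not in this PR): choose the widened i32 vector type and
// the matching NVPTXISD::LoadV* opcode from the original load's type instead
// of hard-coding v16i8 -> LoadV4.
static SDValue
PerformLOADCombineGeneralized(SDNode *N,
                              TargetLowering::DAGCombinerInfo &DCI) {
  SelectionDAG &DAG = DCI.DAG;
  LoadSDNode *LD = cast<LoadSDNode>(N);

  EVT VT = N->getValueType(0);
  unsigned Opc;
  EVT NewVT;
  if (VT == MVT::v16i8) {
    Opc = NVPTXISD::LoadV4;  // selects ld.v4.b32
    NewVT = MVT::v4i32;
  } else if (VT == MVT::v8i8) {
    Opc = NVPTXISD::LoadV2;  // assumed to select ld.v2.b32
    NewVT = MVT::v2i32;
  } else {
    return SDValue();
  }

  SDLoc DL(N);
  EVT EltVT = NewVT.getVectorElementType();
  unsigned NumElts = NewVT.getVectorNumElements();

  // One i32 result per widened element, plus the chain.
  SmallVector<EVT, 5> RetVTs(NumElts, EltVT);
  RetVTs.push_back(MVT::Other);
  SDVTList RetVTList = DAG.getVTList(RetVTs);

  SmallVector<SDValue, 8> Ops(N->ops());
  Ops.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
  SDValue NewLoad = DAG.getMemIntrinsicNode(Opc, DL, RetVTList, Ops, NewVT,
                                            LD->getMemOperand());
  SDValue NewChain = NewLoad.getValue(NumElts);

  // Rebuild a value of the original vector type from the scalar results.
  SmallVector<SDValue, 4> Elts;
  for (unsigned I = 0; I < NumElts; ++I)
    Elts.push_back(NewLoad.getValue(I));
  return DAG.getMergeValues(
      {DAG.getBitcast(VT, DAG.getBuildVector(NewVT, DL, Elts)), NewChain}, DL);
}

Whether the v8i8 case actually avoids the stack traffic the author mentions would still need to be checked against ReplaceLoadVector and the alignment requirements of ld.v2.b32.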
123 changes: 123 additions & 0 deletions llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
@@ -52,3 +52,126 @@ define float @ff(ptr %p) {
%sum = fadd float %sum3, %v4
ret float %sum
}

define void @combine_v16i8(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
; ENABLED-LABEL: combine_v16i8
; ENABLED: ld.v4.u32
%val0 = load i8, ptr %ptr1, align 16
%ptr1.1 = getelementptr inbounds i8, ptr %ptr1, i64 1
%val1 = load i8, ptr %ptr1.1, align 1
%ptr1.2 = getelementptr inbounds i8, ptr %ptr1, i64 2
%val2 = load i8, ptr %ptr1.2, align 2
%ptr1.3 = getelementptr inbounds i8, ptr %ptr1, i64 3
%val3 = load i8, ptr %ptr1.3, align 1
%ptr1.4 = getelementptr inbounds i8, ptr %ptr1, i64 4
%val4 = load i8, ptr %ptr1.4, align 4
%ptr1.5 = getelementptr inbounds i8, ptr %ptr1, i64 5
%val5 = load i8, ptr %ptr1.5, align 1
%ptr1.6 = getelementptr inbounds i8, ptr %ptr1, i64 6
%val6 = load i8, ptr %ptr1.6, align 2
%ptr1.7 = getelementptr inbounds i8, ptr %ptr1, i64 7
%val7 = load i8, ptr %ptr1.7, align 1
%ptr1.8 = getelementptr inbounds i8, ptr %ptr1, i64 8
%val8 = load i8, ptr %ptr1.8, align 8
%ptr1.9 = getelementptr inbounds i8, ptr %ptr1, i64 9
%val9 = load i8, ptr %ptr1.9, align 1
%ptr1.10 = getelementptr inbounds i8, ptr %ptr1, i64 10
%val10 = load i8, ptr %ptr1.10, align 2
%ptr1.11 = getelementptr inbounds i8, ptr %ptr1, i64 11
%val11 = load i8, ptr %ptr1.11, align 1
%ptr1.12 = getelementptr inbounds i8, ptr %ptr1, i64 12
%val12 = load i8, ptr %ptr1.12, align 4
%ptr1.13 = getelementptr inbounds i8, ptr %ptr1, i64 13
%val13 = load i8, ptr %ptr1.13, align 1
%ptr1.14 = getelementptr inbounds i8, ptr %ptr1, i64 14
%val14 = load i8, ptr %ptr1.14, align 2
%ptr1.15 = getelementptr inbounds i8, ptr %ptr1, i64 15
%val15 = load i8, ptr %ptr1.15, align 1
%lane0 = zext i8 %val0 to i32
%lane1 = zext i8 %val1 to i32
%lane2 = zext i8 %val2 to i32
%lane3 = zext i8 %val3 to i32
%lane4 = zext i8 %val4 to i32
%lane5 = zext i8 %val5 to i32
%lane6 = zext i8 %val6 to i32
%lane7 = zext i8 %val7 to i32
%lane8 = zext i8 %val8 to i32
%lane9 = zext i8 %val9 to i32
%lane10 = zext i8 %val10 to i32
%lane11 = zext i8 %val11 to i32
%lane12 = zext i8 %val12 to i32
%lane13 = zext i8 %val13 to i32
%lane14 = zext i8 %val14 to i32
%lane15 = zext i8 %val15 to i32
%red.1 = add i32 %lane0, %lane1
%red.2 = add i32 %red.1, %lane2
%red.3 = add i32 %red.2, %lane3
%red.4 = add i32 %red.3, %lane4
%red.5 = add i32 %red.4, %lane5
%red.6 = add i32 %red.5, %lane6
%red.7 = add i32 %red.6, %lane7
%red.8 = add i32 %red.7, %lane8
%red.9 = add i32 %red.8, %lane9
%red.10 = add i32 %red.9, %lane10
%red.11 = add i32 %red.10, %lane11
%red.12 = add i32 %red.11, %lane12
%red.13 = add i32 %red.12, %lane13
%red.14 = add i32 %red.13, %lane14
%red = add i32 %red.14, %lane15
store i32 %red, ptr %ptr2, align 4
ret void
}

define void @combine_v8i16(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
; ENABLED-LABEL: combine_v8i16
; ENABLED: ld.v4.b32
%val0 = load i16, ptr %ptr1, align 16
%ptr1.1 = getelementptr inbounds i16, ptr %ptr1, i64 1
%val1 = load i16, ptr %ptr1.1, align 2
%ptr1.2 = getelementptr inbounds i16, ptr %ptr1, i64 2
%val2 = load i16, ptr %ptr1.2, align 4
%ptr1.3 = getelementptr inbounds i16, ptr %ptr1, i64 3
%val3 = load i16, ptr %ptr1.3, align 2
%ptr1.4 = getelementptr inbounds i16, ptr %ptr1, i64 4
%val4 = load i16, ptr %ptr1.4, align 4
%ptr1.5 = getelementptr inbounds i16, ptr %ptr1, i64 5
%val5 = load i16, ptr %ptr1.5, align 2
%ptr1.6 = getelementptr inbounds i16, ptr %ptr1, i64 6
%val6 = load i16, ptr %ptr1.6, align 4
%ptr1.7 = getelementptr inbounds i16, ptr %ptr1, i64 7
%val7 = load i16, ptr %ptr1.7, align 2
%lane0 = zext i16 %val0 to i32
%lane1 = zext i16 %val1 to i32
%lane2 = zext i16 %val2 to i32
%lane3 = zext i16 %val3 to i32
%lane4 = zext i16 %val4 to i32
%lane5 = zext i16 %val5 to i32
%lane6 = zext i16 %val6 to i32
%lane7 = zext i16 %val7 to i32
%red.1 = add i32 %lane0, %lane1
%red.2 = add i32 %red.1, %lane2
%red.3 = add i32 %red.2, %lane3
%red.4 = add i32 %red.3, %lane4
%red.5 = add i32 %red.4, %lane5
%red.6 = add i32 %red.5, %lane6
%red = add i32 %red.6, %lane7
store i32 %red, ptr %ptr2, align 4
ret void
}

define void @combine_v4i32(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
; ENABLED-LABEL: combine_v4i32
; ENABLED: ld.v4.u32
%val0 = load i32, ptr %ptr1, align 16
%ptr1.1 = getelementptr inbounds i32, ptr %ptr1, i64 1
%val1 = load i32, ptr %ptr1.1, align 4
%ptr1.2 = getelementptr inbounds i32, ptr %ptr1, i64 2
%val2 = load i32, ptr %ptr1.2, align 8
%ptr1.3 = getelementptr inbounds i32, ptr %ptr1, i64 3
%val3 = load i32, ptr %ptr1.3, align 4
%red.1 = add i32 %val0, %val1
%red.2 = add i32 %red.1, %val2
%red = add i32 %red.2, %val3
store i32 %red, ptr %ptr2, align 4
ret void
}