diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 36da2e7b40efa..5d5a9188d1f47 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -693,8 +693,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
 
   // We have some custom DAG combine patterns for these nodes
-  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
-                       ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT,
+  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
+                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
                        ISD::VSELECT});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
@@ -5471,6 +5471,45 @@ static SDValue PerformVSELECTCombine(SDNode *N,
   return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
 }
 
+static SDValue PerformLOADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+
+  // Lower a v16i8 load into a LoadV4 operation with i32 results instead of
+  // letting ReplaceLoadVector split it into smaller loads during legalization.
+  // This is done at dag-combine1 time, so that vector operations with i8
+  // elements can be optimised away instead of being needlessly split during
+  // legalization, which involves storing to the stack and loading it back.
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v16i8)
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // Create a v4i32 vector load operation, effectively <4 x v4i8>.
+  unsigned Opc = NVPTXISD::LoadV4;
+  EVT NewVT = MVT::v4i32;
+  EVT EltVT = NewVT.getVectorElementType();
+  unsigned NumElts = NewVT.getVectorNumElements();
+  EVT RetVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other};
+  SDVTList RetVTList = DAG.getVTList(RetVTs);
+  SmallVector<SDValue> Ops(N->ops());
+  Ops.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+  SDValue NewLoad = DAG.getMemIntrinsicNode(Opc, DL, RetVTList, Ops, NewVT,
+                                            LD->getMemOperand());
+  SDValue NewChain = NewLoad.getValue(NumElts);
+
+  // Create a vector of the same type returned by the original load.
+  SmallVector<SDValue> Elts;
+  for (unsigned i = 0; i < NumElts; i++)
+    Elts.push_back(NewLoad.getValue(i));
+  return DCI.DAG.getMergeValues(
+      {DCI.DAG.getBitcast(VT, DCI.DAG.getBuildVector(NewVT, DL, Elts)),
+       NewChain},
+      DL);
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5490,6 +5529,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformREMCombine(N, DCI, OptLevel);
   case ISD::SETCC:
     return PerformSETCCCombine(N, DCI);
+  case ISD::LOAD:
+    return PerformLOADCombine(N, DCI);
   case NVPTXISD::StoreRetval:
   case NVPTXISD::StoreRetvalV2:
   case NVPTXISD::StoreRetvalV4:
diff --git a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
index 4f13b6d9d1a8a..868a06e2a850c 100644
--- a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
+++ b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
@@ -52,3 +52,126 @@ define float @ff(ptr %p) {
   %sum = fadd float %sum3, %v4
   ret float %sum
 }
+
+define void @combine_v16i8(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
+; ENABLED-LABEL: combine_v16i8
+; ENABLED: ld.v4.u32
+  %val0 = load i8, ptr %ptr1, align 16
+  %ptr1.1 = getelementptr inbounds i8, ptr %ptr1, i64 1
+  %val1 = load i8, ptr %ptr1.1, align 1
+  %ptr1.2 = getelementptr inbounds i8, ptr %ptr1, i64 2
+  %val2 = load i8, ptr %ptr1.2, align 2
+  %ptr1.3 = getelementptr inbounds i8, ptr %ptr1, i64 3
+  %val3 = load i8, ptr %ptr1.3, align 1
+  %ptr1.4 = getelementptr inbounds i8, ptr %ptr1, i64 4
+  %val4 = load i8, ptr %ptr1.4, align 4
+  %ptr1.5 = getelementptr inbounds i8, ptr %ptr1, i64 5
+  %val5 = load i8, ptr %ptr1.5, align 1
+  %ptr1.6 = getelementptr inbounds i8, ptr %ptr1, i64 6
+  %val6 = load i8, ptr %ptr1.6, align 2
+  %ptr1.7 = getelementptr inbounds i8, ptr %ptr1, i64 7
+  %val7 = load i8, ptr %ptr1.7, align 1
+  %ptr1.8 = getelementptr inbounds i8, ptr %ptr1, i64 8
+  %val8 = load i8, ptr %ptr1.8, align 8
+  %ptr1.9 = getelementptr inbounds i8, ptr %ptr1, i64 9
+  %val9 = load i8, ptr %ptr1.9, align 1
+  %ptr1.10 = getelementptr inbounds i8, ptr %ptr1, i64 10
+  %val10 = load i8, ptr %ptr1.10, align 2
+  %ptr1.11 = getelementptr inbounds i8, ptr %ptr1, i64 11
+  %val11 = load i8, ptr %ptr1.11, align 1
+  %ptr1.12 = getelementptr inbounds i8, ptr %ptr1, i64 12
+  %val12 = load i8, ptr %ptr1.12, align 4
+  %ptr1.13 = getelementptr inbounds i8, ptr %ptr1, i64 13
+  %val13 = load i8, ptr %ptr1.13, align 1
+  %ptr1.14 = getelementptr inbounds i8, ptr %ptr1, i64 14
+  %val14 = load i8, ptr %ptr1.14, align 2
+  %ptr1.15 = getelementptr inbounds i8, ptr %ptr1, i64 15
+  %val15 = load i8, ptr %ptr1.15, align 1
+  %lane0 = zext i8 %val0 to i32
+  %lane1 = zext i8 %val1 to i32
+  %lane2 = zext i8 %val2 to i32
+  %lane3 = zext i8 %val3 to i32
+  %lane4 = zext i8 %val4 to i32
+  %lane5 = zext i8 %val5 to i32
+  %lane6 = zext i8 %val6 to i32
+  %lane7 = zext i8 %val7 to i32
+  %lane8 = zext i8 %val8 to i32
+  %lane9 = zext i8 %val9 to i32
+  %lane10 = zext i8 %val10 to i32
+  %lane11 = zext i8 %val11 to i32
+  %lane12 = zext i8 %val12 to i32
+  %lane13 = zext i8 %val13 to i32
+  %lane14 = zext i8 %val14 to i32
+  %lane15 = zext i8 %val15 to i32
+  %red.1 = add i32 %lane0, %lane1
+  %red.2 = add i32 %red.1, %lane2
+  %red.3 = add i32 %red.2, %lane3
+  %red.4 = add i32 %red.3, %lane4
+  %red.5 = add i32 %red.4, %lane5
+  %red.6 = add i32 %red.5, %lane6
+  %red.7 = add i32 %red.6, %lane7
+  %red.8 = add i32 %red.7, %lane8
+  %red.9 = add i32 %red.8, %lane9
+  %red.10 = add i32 %red.9, %lane10
+  %red.11 = add i32 %red.10, %lane11
+  %red.12 = add i32 %red.11, %lane12
+  %red.13 = add i32 %red.12, %lane13
+  %red.14 = add i32 %red.13, %lane14
+  %red = add i32 %red.14, %lane15
+  store i32 %red, ptr %ptr2, align 4
+  ret void
+}
+
+define void @combine_v8i16(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
+; ENABLED-LABEL: combine_v8i16
+; ENABLED: ld.v4.b32
+  %val0 = load i16, ptr %ptr1, align 16
+  %ptr1.1 = getelementptr inbounds i16, ptr %ptr1, i64 1
+  %val1 = load i16, ptr %ptr1.1, align 2
+  %ptr1.2 = getelementptr inbounds i16, ptr %ptr1, i64 2
+  %val2 = load i16, ptr %ptr1.2, align 4
+  %ptr1.3 = getelementptr inbounds i16, ptr %ptr1, i64 3
+  %val3 = load i16, ptr %ptr1.3, align 2
+  %ptr1.4 = getelementptr inbounds i16, ptr %ptr1, i64 4
+  %val4 = load i16, ptr %ptr1.4, align 4
+  %ptr1.5 = getelementptr inbounds i16, ptr %ptr1, i64 5
+  %val5 = load i16, ptr %ptr1.5, align 2
+  %ptr1.6 = getelementptr inbounds i16, ptr %ptr1, i64 6
+  %val6 = load i16, ptr %ptr1.6, align 4
+  %ptr1.7 = getelementptr inbounds i16, ptr %ptr1, i64 7
+  %val7 = load i16, ptr %ptr1.7, align 2
+  %lane0 = zext i16 %val0 to i32
+  %lane1 = zext i16 %val1 to i32
+  %lane2 = zext i16 %val2 to i32
+  %lane3 = zext i16 %val3 to i32
+  %lane4 = zext i16 %val4 to i32
+  %lane5 = zext i16 %val5 to i32
+  %lane6 = zext i16 %val6 to i32
+  %lane7 = zext i16 %val7 to i32
+  %red.1 = add i32 %lane0, %lane1
+  %red.2 = add i32 %red.1, %lane2
+  %red.3 = add i32 %red.2, %lane3
+  %red.4 = add i32 %red.3, %lane4
+  %red.5 = add i32 %red.4, %lane5
+  %red.6 = add i32 %red.5, %lane6
+  %red = add i32 %red.6, %lane7
+  store i32 %red, ptr %ptr2, align 4
+  ret void
+}
+
+define void @combine_v4i32(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
+; ENABLED-LABEL: combine_v4i32
+; ENABLED: ld.v4.u32
+  %val0 = load i32, ptr %ptr1, align 16
+  %ptr1.1 = getelementptr inbounds i32, ptr %ptr1, i64 1
+  %val1 = load i32, ptr %ptr1.1, align 4
+  %ptr1.2 = getelementptr inbounds i32, ptr %ptr1, i64 2
+  %val2 = load i32, ptr %ptr1.2, align 8
+  %ptr1.3 = getelementptr inbounds i32, ptr %ptr1, i64 3
+  %val3 = load i32, ptr %ptr1.3, align 4
+  %red.1 = add i32 %val0, %val1
+  %red.2 = add i32 %red.1, %val2
+  %red = add i32 %red.2, %val3
+  store i32 %red, ptr %ptr2, align 4
+  ret void
+}
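
Reviewer note, not part of the patch: a minimal standalone reproducer for the
new combine is sketched below. The RUN line, cpu/triple choice, and function
name are illustrative assumptions rather than taken from this change, but the
ld.v4.u32 expectation matches the combine_v16i8 test above; a whole-vector
<16 x i8> load should now select a single ld.v4.u32 instead of being split
into smaller loads during legalization.

; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -O3 | FileCheck %s
; CHECK-LABEL: direct_v16i8
; CHECK: ld.v4.u32
define <16 x i8> @direct_v16i8(ptr align 16 %p) {
  ; The v16i8 load hits PerformLOADCombine at dag-combine1 and is rewritten
  ; into an NVPTXISD::LoadV4 with four i32 results, then bitcast back to
  ; <16 x i8>, so legalization never round-trips it through the stack.
  %v = load <16 x i8>, ptr %p, align 16
  ret <16 x i8> %v
}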