diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d3cd9b1671e1b..45114b85e25d8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8668,6 +8668,7 @@ using SDByteProvider = ByteProvider<SDNode *>;
 static std::optional<SDByteProvider>
 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
                       std::optional<uint64_t> VectorIndex,
+                      SmallPtrSetImpl<SDNode *> &ExtractElements,
                       unsigned StartingIndex = 0) {
   // Typical i64 by i8 pattern requires recursion up to 8 calls depth
   if (Depth == 10)
@@ -8694,12 +8695,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
 
   switch (Op.getOpcode()) {
   case ISD::OR: {
-    auto LHS =
-        calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
+    auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
+                                     VectorIndex, ExtractElements);
     if (!LHS)
       return std::nullopt;
-    auto RHS =
-        calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
+    auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1,
+                                     VectorIndex, ExtractElements);
     if (!RHS)
       return std::nullopt;
 
@@ -8726,7 +8727,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
     return Index < ByteShift
                ? SDByteProvider::getConstantZero()
               : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
-                                       Depth + 1, VectorIndex, Index);
+                                       Depth + 1, VectorIndex, ExtractElements,
+                                       Index);
   }
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
@@ -8743,11 +8745,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
                       SDByteProvider::getConstantZero())
                : std::nullopt;
     return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
-                                 StartingIndex);
+                                 ExtractElements, StartingIndex);
   }
   case ISD::BSWAP:
     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
-                                 Depth + 1, VectorIndex, StartingIndex);
+                                 Depth + 1, VectorIndex, ExtractElements,
+                                 StartingIndex);
   case ISD::EXTRACT_VECTOR_ELT: {
     auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
     if (!OffsetOp)
@@ -8772,8 +8775,9 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
     if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
       return std::nullopt;
 
+    ExtractElements.insert(Op.getNode());
     return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
-                                 VectorIndex, StartingIndex);
+                                 VectorIndex, ExtractElements, StartingIndex);
   }
   case ISD::LOAD: {
     auto L = cast<LoadSDNode>(Op.getNode());
@@ -9110,6 +9114,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   SDValue Chain;
 
   SmallPtrSet<LoadSDNode *, 8> Loads;
+  SmallPtrSet<SDNode *, 8> ExtractElements;
   std::optional<SDByteProvider> FirstByteProvider;
   int64_t FirstOffset = INT64_MAX;
 
@@ -9119,7 +9124,9 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   unsigned ZeroExtendedBytes = 0;
   for (int i = ByteWidth - 1; i >= 0; --i) {
     auto P =
-        calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
-                              /*StartingIndex*/ i);
+        calculateByteProvider(SDValue(N, 0), i, 0,
+                              /*VectorIndex*/ std::nullopt, ExtractElements,
+                              /*StartingIndex*/ i);
     if (!P)
       return SDValue();
@@ -9245,6 +9252,14 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   if (!Allowed || !Fast)
     return SDValue();
 
+  // calculateByteProvider() allows multi-use for vector loads. Ensure that
+  // all uses are in vector element extracts that are part of the pattern.
+  for (LoadSDNode *L : Loads)
+    if (L->getMemoryVT().isVector())
+      for (auto It = L->use_begin(); It != L->use_end(); ++It)
+        if (It.getUse().getResNo() == 0 && !ExtractElements.contains(*It))
+          return SDValue();
+
   SDValue NewLoad =
       DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N),
                     VT, Chain, FirstLoad->getBasePtr(),
diff --git a/llvm/test/CodeGen/AArch64/load-combine.ll b/llvm/test/CodeGen/AArch64/load-combine.ll
index 57f61e5303ecf..b30ee45aa4d1a 100644
--- a/llvm/test/CodeGen/AArch64/load-combine.ll
+++ b/llvm/test/CodeGen/AArch64/load-combine.ll
@@ -606,10 +606,12 @@ define void @short_vector_to_i32_unused_high_i8(ptr %in, ptr %out, ptr %p) {
 ; CHECK-LABEL: short_vector_to_i32_unused_high_i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ldr s0, [x0]
-; CHECK-NEXT:    ldrh w9, [x0]
 ; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    umov w8, v0.h[2]
-; CHECK-NEXT:    orr w8, w9, w8, lsl #16
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    umov w9, v0.h[0]
+; CHECK-NEXT:    umov w10, v0.h[2]
+; CHECK-NEXT:    bfi w9, w8, #8, #8
+; CHECK-NEXT:    orr w8, w9, w10, lsl #16
 ; CHECK-NEXT:    str w8, [x1]
 ; CHECK-NEXT:    ret
   %ld = load <4 x i8>, ptr %in, align 4
diff --git a/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll b/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
index c27e44609c527..9692108280182 100644
--- a/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
+++ b/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
@@ -205,12 +205,14 @@ define i64 @load_3xi16_combine(ptr addrspace(1) %p) #0 {
 ; GCN-LABEL: load_3xi16_combine:
 ; GCN:       ; %bb.0:
 ; GCN-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GCN-NEXT:    global_load_dword v2, v[0:1], off
-; GCN-NEXT:    global_load_ushort v3, v[0:1], off offset:4
+; GCN-NEXT:    global_load_dword v3, v[0:1], off
+; GCN-NEXT:    global_load_ushort v2, v[0:1], off offset:4
+; GCN-NEXT:    s_mov_b32 s4, 0xffff
 ; GCN-NEXT:    s_waitcnt vmcnt(1)
-; GCN-NEXT:    v_mov_b32_e32 v0, v2
+; GCN-NEXT:    v_and_b32_e32 v0, 0xffff0000, v3
+; GCN-NEXT:    v_and_or_b32 v0, v3, s4, v0
 ; GCN-NEXT:    s_waitcnt vmcnt(0)
-; GCN-NEXT:    v_mov_b32_e32 v1, v3
+; GCN-NEXT:    v_mov_b32_e32 v1, v2
 ; GCN-NEXT:    s_setpc_b64 s[30:31]
   %gep.p = getelementptr i16, ptr addrspace(1) %p, i32 1
   %gep.2p = getelementptr i16, ptr addrspace(1) %p, i32 2
diff --git a/llvm/test/CodeGen/X86/load-combine.ll b/llvm/test/CodeGen/X86/load-combine.ll
index 7e4e11fcc75c2..530e17a0b0f09 100644
--- a/llvm/test/CodeGen/X86/load-combine.ll
+++ b/llvm/test/CodeGen/X86/load-combine.ll
@@ -1283,26 +1283,33 @@ define i32 @zext_load_i32_by_i8_bswap_shl_16(ptr %arg) {
   ret i32 %tmp8
 }
 
-; FIXME: This is a miscompile.
 define i32 @pr80911_vector_load_multiuse(ptr %ptr, ptr %clobber) nounwind {
 ; CHECK-LABEL: pr80911_vector_load_multiuse:
 ; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushl %edi
 ; CHECK-NEXT:    pushl %esi
-; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %edx
-; CHECK-NEXT:    movl (%edx), %esi
-; CHECK-NEXT:    movzwl (%edx), %eax
-; CHECK-NEXT:    movl $0, (%ecx)
-; CHECK-NEXT:    movl %esi, (%edx)
+; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %esi
+; CHECK-NEXT:    movzbl (%esi), %ecx
+; CHECK-NEXT:    movzbl 1(%esi), %eax
+; CHECK-NEXT:    movzwl 2(%esi), %edi
+; CHECK-NEXT:    movl $0, (%edx)
+; CHECK-NEXT:    movw %di, 2(%esi)
+; CHECK-NEXT:    movb %al, 1(%esi)
+; CHECK-NEXT:    movb %cl, (%esi)
+; CHECK-NEXT:    shll $8, %eax
+; CHECK-NEXT:    orl %ecx, %eax
 ; CHECK-NEXT:    popl %esi
+; CHECK-NEXT:    popl %edi
 ; CHECK-NEXT:    retl
 ;
 ; CHECK64-LABEL: pr80911_vector_load_multiuse:
 ; CHECK64:       # %bb.0:
-; CHECK64-NEXT:    movzwl (%rdi), %eax
+; CHECK64-NEXT:    movaps (%rdi), %xmm0
 ; CHECK64-NEXT:    movl $0, (%rsi)
-; CHECK64-NEXT:    movl (%rdi), %ecx
-; CHECK64-NEXT:    movl %ecx, (%rdi)
+; CHECK64-NEXT:    movss %xmm0, (%rdi)
+; CHECK64-NEXT:    movaps %xmm0, -{{[0-9]+}}(%rsp)
+; CHECK64-NEXT:    movzwl -{{[0-9]+}}(%rsp), %eax
 ; CHECK64-NEXT:    retq
   %load = load <4 x i8>, ptr %ptr, align 16
   store i32 0, ptr %clobber
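
For reference, a minimal IR sketch of the shape the new ExtractElements guard rejects (function name and exact operations are illustrative, not the upstream test): the <4 x i8> load feeds both the byte-assembling or chain and a full-vector store, so MatchLoadCombine must now refuse to rewrite the or into a single scalar load.

; Hypothetical reduced example: the vector load has a use (the store back to
; %ptr) outside the extract/zext/shl/or chain, so the combine bails out
; instead of folding the or into a narrow scalar load.
define i32 @multiuse_sketch(ptr %ptr, ptr %clobber) {
  %vec = load <4 x i8>, ptr %ptr, align 16
  store i32 0, ptr %clobber                  ; clobbering store between load and uses
  store <4 x i8> %vec, ptr %ptr              ; non-extract use of the loaded vector
  %e0 = extractelement <4 x i8> %vec, i32 0
  %e1 = extractelement <4 x i8> %vec, i32 1
  %z0 = zext i8 %e0 to i32
  %z1 = zext i8 %e1 to i32
  %s1 = shl i32 %z1, 8
  %or = or i32 %z0, %s1                      ; bytes 0-1 of %vec assembled into an i32
  ret i32 %or
}

With only the extractelement uses present, every user of the load lands in ExtractElements while the byte providers are matched and the combine still fires; the extra store is what trips the new bail-out.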