diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 288fd3639e5eb..fdd266c396412 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27924,6 +27924,21 @@ bool AArch64TargetLowering::isIntDivCheap(EVT VT, AttributeList Attr) const {
   return OptSize && !VT.isVector();
 }
 
+bool AArch64TargetLowering::canMergeStoresTo(unsigned AddressSpace, EVT MemVT,
+                                             const MachineFunction &MF) const {
+  // Avoid merging stores into fixed-length vectors when Neon is unavailable.
+  // In the future, we could allow this when SVE is available, but currently,
+  // the SVE lowerings for BUILD_VECTOR are limited to a few specific cases (and
+  // the general lowering may introduce stack spills/reloads).
+  if (MemVT.isFixedLengthVector() && !Subtarget->isNeonAvailable())
+    return false;
+
+  // Do not merge to float value size (128 bits) if no implicit float attribute
+  // is set.
+  bool NoFloat = MF.getFunction().hasFnAttribute(Attribute::NoImplicitFloat);
+  return !NoFloat || MemVT.getSizeInBits() <= 64;
+}
+
 bool AArch64TargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {
   // We want inc-of-add for scalars and sub-of-not for vectors.
   return VT.isScalarInteger();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1bae7562f459a..be8ab5ee76a05 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -849,16 +849,7 @@ class AArch64TargetLowering : public TargetLowering {
   bool isIntDivCheap(EVT VT, AttributeList Attr) const override;
 
   bool canMergeStoresTo(unsigned AddressSpace, EVT MemVT,
-                        const MachineFunction &MF) const override {
-    // Do not merge to float value size (128 bytes) if no implicit
-    // float attribute is set.
-
-    bool NoFloat = MF.getFunction().hasFnAttribute(Attribute::NoImplicitFloat);
-
-    if (NoFloat)
-      return (MemVT.getSizeInBits() <= 64);
-    return true;
-  }
+                        const MachineFunction &MF) const override;
 
   bool isCheapToSpeculateCttz(Type *) const override {
     return true;
diff --git a/llvm/test/CodeGen/AArch64/consecutive-stores-of-faddv.ll b/llvm/test/CodeGen/AArch64/consecutive-stores-of-faddv.ll
new file mode 100644
index 0000000000000..64482e15aed81
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/consecutive-stores-of-faddv.ll
@@ -0,0 +1,92 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve,+sme -O3 < %s -o - | FileCheck %s --check-prefixes=CHECK
+
+; Tests consecutive stores of @llvm.aarch64.sve.faddv. Within SDAG, faddv is
+; lowered as a FADDV + EXTRACT_VECTOR_ELT (of lane 0). Stores of extracts can
+; be matched by DAGCombiner::mergeConsecutiveStores(), which we want to avoid in
+; some cases as it can lead to worse codegen.
+
+; TODO: A single `stp s0, s1, [x0]` may be preferred here.
+define void @consecutive_stores_pair(ptr %dest0, <vscale x 4 x float> %vec0, <vscale x 4 x float> %vec1) {
+; CHECK-LABEL: consecutive_stores_pair:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    faddv s1, p0, z1.s
+; CHECK-NEXT:    mov v0.s[1], v1.s[0]
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+  %dest1 = getelementptr inbounds i8, ptr %dest0, i64 4
+  %reduce0 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec0)
+  %reduce1 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec1)
+  store float %reduce0, ptr %dest0, align 4
+  store float %reduce1, ptr %dest1, align 4
+  ret void
+}
+
+define void @consecutive_stores_quadruple(ptr %dest0, <vscale x 4 x float> %vec0, <vscale x 4 x float> %vec1, <vscale x 4 x float> %vec2, <vscale x 4 x float> %vec3) {
+; CHECK-LABEL: consecutive_stores_quadruple:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    faddv s1, p0, z1.s
+; CHECK-NEXT:    faddv s2, p0, z2.s
+; CHECK-NEXT:    mov v0.s[1], v1.s[0]
+; CHECK-NEXT:    faddv s3, p0, z3.s
+; CHECK-NEXT:    mov v2.s[1], v3.s[0]
+; CHECK-NEXT:    stp d0, d2, [x0]
+; CHECK-NEXT:    ret
+  %dest1 = getelementptr inbounds i8, ptr %dest0, i64 4
+  %dest2 = getelementptr inbounds i8, ptr %dest1, i64 4
+  %dest3 = getelementptr inbounds i8, ptr %dest2, i64 4
+  %reduce0 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec0)
+  %reduce1 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec1)
+  %reduce2 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec2)
+  %reduce3 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec3)
+  store float %reduce0, ptr %dest0, align 4
+  store float %reduce1, ptr %dest1, align 4
+  store float %reduce2, ptr %dest2, align 4
+  store float %reduce3, ptr %dest3, align 4
+  ret void
+}
+
+define void @consecutive_stores_pair_streaming_function(ptr %dest0, <vscale x 4 x float> %vec0, <vscale x 4 x float> %vec1) "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: consecutive_stores_pair_streaming_function:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    faddv s1, p0, z1.s
+; CHECK-NEXT:    stp s0, s1, [x0]
+; CHECK-NEXT:    ret
+  %dest1 = getelementptr inbounds i8, ptr %dest0, i64 4
+  %reduce0 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec0)
+  %reduce1 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec1)
+  store float %reduce0, ptr %dest0, align 4
+  store float %reduce1, ptr %dest1, align 4
+  ret void
+}
+
+define void @consecutive_stores_quadruple_streaming_function(ptr %dest0, <vscale x 4 x float> %vec0, <vscale x 4 x float> %vec1, <vscale x 4 x float> %vec2, <vscale x 4 x float> %vec3) "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: consecutive_stores_quadruple_streaming_function:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    faddv s1, p0, z1.s
+; CHECK-NEXT:    faddv s2, p0, z2.s
+; CHECK-NEXT:    stp s0, s1, [x0]
+; CHECK-NEXT:    faddv s3, p0, z3.s
+; CHECK-NEXT:    stp s2, s3, [x0, #8]
+; CHECK-NEXT:    ret
+  %dest1 = getelementptr inbounds i8, ptr %dest0, i64 4
+  %dest2 = getelementptr inbounds i8, ptr %dest1, i64 4
+  %dest3 = getelementptr inbounds i8, ptr %dest2, i64 4
+  %reduce0 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec0)
+  %reduce1 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec1)
+  %reduce2 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec2)
+  %reduce3 = call float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> splat(i1 true), <vscale x 4 x float> %vec3)
+  store float %reduce0, ptr %dest0, align 4
+  store float %reduce1, ptr %dest1, align 4
+  store float %reduce2, ptr %dest2, align 4
+  store float %reduce3, ptr %dest3, align 4
+  ret void
+}