From 561706ff1ed18f9b1924df417dbe1c2a4ff65432 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 29 Jul 2024 10:46:16 +0100 Subject: [PATCH 01/33] [AArch64] Lower add partial reduction to udot This patch introduces lowering of the partial add reduction intrinsic to a udot or sdot for AArch64. --- llvm/include/llvm/CodeGen/TargetLowering.h | 6 + .../SelectionDAG/SelectionDAGBuilder.cpp | 6 + .../Target/AArch64/AArch64ISelLowering.cpp | 77 +++++++++++++ llvm/lib/Target/AArch64/AArch64ISelLowering.h | 2 + .../AArch64/AArch64TargetTransformInfo.cpp | 30 +++++ .../AArch64/AArch64TargetTransformInfo.h | 6 + .../AArch64/partial-reduce-dot-product.ll | 109 ++++++++++++++++++ 7 files changed, 236 insertions(+) create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index eda38cd8a564d..883a2252f7ffe 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -453,6 +453,12 @@ class TargetLoweringBase { return true; } + /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic + /// should be expanded using generic code in SelectionDAGBuilder. + virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const { + return true; + } + /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded /// using generic code in SelectionDAGBuilder. 
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 60dcb11854278..5ddbf9f414d21 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -8005,6 +8005,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, return; } case Intrinsic::experimental_vector_partial_reduce_add: { + + if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) { + visitTargetIntrinsic(I, Intrinsic); + return; + } + SDValue OpNode = getValue(I.getOperand(1)); EVT ReducedTy = EVT::getEVT(I.getType()); EVT FullTy = OpNode.getValueType(); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 215f30128e703..987a7290274e7 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1988,6 +1988,57 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT, return false; } +bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( + const CallInst *CI) const { + const bool TargetLowers = false; + const bool GenericLowers = true; + + auto *I = dyn_cast(CI); + if (!I) + return GenericLowers; + + ScalableVectorType *RetTy = dyn_cast(I->getType()); + + if (!RetTy) + return GenericLowers; + + ScalableVectorType *InputTy = nullptr; + + auto RetScalarTy = RetTy->getScalarType(); + if (RetScalarTy->isIntegerTy(64)) { + InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8); + } else if (RetScalarTy->isIntegerTy(32)) { + InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16); + } + + if (!InputTy) + return GenericLowers; + + Value *InputA; + Value *InputB; + + auto Pattern = m_Intrinsic( + m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))), + m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))); + + if 
(!match(I, Pattern)) + return GenericLowers; + + auto Mul = cast(I->getOperand(1)); + + auto getOpcodeOfOperand = [&](unsigned Idx) { + return cast(Mul->getOperand(Idx))->getOpcode(); + }; + + if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1)) + return GenericLowers; + + if (InputA->getType() != InputTy || InputB->getType() != InputTy) + return GenericLowers; + + return TargetLowers; +} + bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const { if (!Subtarget->isSVEorStreamingSVEAvailable()) return true; @@ -21765,6 +21816,32 @@ static SDValue performIntrinsicCombine(SDNode *N, switch (IID) { default: break; + case Intrinsic::experimental_vector_partial_reduce_add: { + SDLoc DL(N); + + auto NarrowOp = N->getOperand(1); + auto MulOp = N->getOperand(2); + + auto ExtA = MulOp->getOperand(0); + auto ExtB = MulOp->getOperand(1); + + unsigned DotIntrinsicId = Intrinsic::not_intrinsic; + + if (ExtA->getOpcode() == ISD::SIGN_EXTEND) + DotIntrinsicId = Intrinsic::aarch64_sve_sdot; + else if (ExtA->getOpcode() == ISD::ZERO_EXTEND) + DotIntrinsicId = Intrinsic::aarch64_sve_udot; + + assert(DotIntrinsicId != Intrinsic::not_intrinsic && + "Unexpected dot product case encountered."); + + auto A = ExtA->getOperand(0); + auto B = ExtB->getOperand(0); + + auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64); + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(), + {IntrinsicId, NarrowOp, A, B}); + } case Intrinsic::aarch64_neon_vcvtfxs2fp: case Intrinsic::aarch64_neon_vcvtfxu2fp: return tryCombineFixedPointConvert(N, DCI, DAG); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 39d5df0de0eec..9fe95ddaca32c 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -998,6 +998,8 @@ class AArch64TargetLowering : public TargetLowering { bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override; + bool 
shouldExpandPartialReductionIntrinsic(const CallInst *I) const override; + bool shouldExpandCttzElements(EVT VT) const override; /// If a change in streaming mode is required on entry to/return from a diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index dc748290f2e21..5871134e60985 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -3670,6 +3670,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef Tys) { return Cost; } +bool AArch64TTIImpl::isPartialReductionSupported( + const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor, + bool IsInputASignExtended, bool IsInputBSignExtended, + const Instruction *BinOp) const { + if (ReductionInstr->getOpcode() != Instruction::Add) + return false; + + // Check that both extends are of the same type + if (IsInputASignExtended != IsInputBSignExtended) + return false; + + if (!BinOp || BinOp->getOpcode() != Instruction::Mul) + return false; + + // Dot product only supports a scale factor of 4 + if (ScaleFactor != 4) + return false; + + Type *ReductionType = ReductionInstr->getType(); + if (ReductionType->isIntegerTy(32)) { + if (!InputType->isIntegerTy(8)) + return false; + } else if (ReductionType->isIntegerTy(64)) { + if (!InputType->isIntegerTy(16)) + return false; + } + + return true; +} + unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) { return ST->getMaxInterleaveFactor(); } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 4a6457d7a7dbf..af7e8e8e497dd 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -155,6 +155,12 @@ class AArch64TTIImpl : public BasicTTIImplBase { return VF.getKnownMinValue() * ST->getVScaleForTuning(); } + bool isPartialReductionSupported(const 
Instruction *ReductionInstr, + Type *InputType, unsigned ScaleFactor, + bool IsInputASignExtended, + bool IsInputBSignExtended, + const Instruction *BinOp = nullptr) const; + unsigned getMaxInterleaveFactor(ElementCount VF); bool prefersVectorizedAddressing() const; diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll new file mode 100644 index 0000000000000..23b39387fb7a0 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll @@ -0,0 +1,109 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc -mtriple=aarch64-unknwon-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s + +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" +target triple = "aarch64-none-unknown-elf" + +define @dotp( %a, %b) #0 { +; CHECK-LABEL: dotp: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z2.s, #0 // =0x0 +; CHECK-NEXT: udot z2.s, z0.b, z1.b +; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: ret +entry: + %a.wide = zext %a to + %b.wide = zext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( zeroinitializer, %mult) + ret %partial.reduce +} + +define @dotp_wide( %a, %b) #0 { +; CHECK-LABEL: dotp_wide: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z2.d, #0 // =0x0 +; CHECK-NEXT: udot z2.d, z0.h, z1.h +; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: ret +entry: + %a.wide = zext %a to + %b.wide = zext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64( zeroinitializer, %mult) + ret %partial.reduce +} + +define @dotp_sext( %a, %b) #0 { +; CHECK-LABEL: dotp_sext: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z2.s, #0 // =0x0 +; CHECK-NEXT: sdot z2.s, z0.b, z1.b +; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: ret +entry: + %a.wide = sext %a to + %b.wide = 
sext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( zeroinitializer, %mult) + ret %partial.reduce +} + +define @dotp_wide_sext( %a, %b) #0 { +; CHECK-LABEL: dotp_wide_sext: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z2.d, #0 // =0x0 +; CHECK-NEXT: sdot z2.d, z0.h, z1.h +; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: ret +entry: + %a.wide = sext %a to + %b.wide = sext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64( zeroinitializer, %mult) + ret %partial.reduce +} + +define @not_dotp( %a, %b) #0 { +; CHECK-LABEL: not_dotp: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: and z0.h, z0.h, #0xff +; CHECK-NEXT: and z1.h, z1.h, #0xff +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: uunpkhi z2.s, z0.h +; CHECK-NEXT: uunpkhi z3.s, z1.h +; CHECK-NEXT: uunpklo z0.s, z0.h +; CHECK-NEXT: uunpklo z1.s, z1.h +; CHECK-NEXT: mul z2.s, z2.s, z3.s +; CHECK-NEXT: mad z0.s, p0/m, z1.s, z2.s +; CHECK-NEXT: ret +entry: + %a.wide = zext %a to + %b.wide = zext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( zeroinitializer, %mult) + ret %partial.reduce +} + +define @not_dotp_wide( %a, %b) #0 { +; CHECK-LABEL: not_dotp_wide: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: and z0.s, z0.s, #0xffff +; CHECK-NEXT: and z1.s, z1.s, #0xffff +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: uunpkhi z2.d, z0.s +; CHECK-NEXT: uunpkhi z3.d, z1.s +; CHECK-NEXT: uunpklo z0.d, z0.s +; CHECK-NEXT: uunpklo z1.d, z1.s +; CHECK-NEXT: mul z2.d, z2.d, z3.d +; CHECK-NEXT: mad z0.d, p0/m, z1.d, z2.d +; CHECK-NEXT: ret +entry: + %a.wide = zext %a to + %b.wide = zext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64( zeroinitializer, %mult) + ret %partial.reduce +} + +attributes #0 = { 
"target-features"="+sve2" } From 563d025161d0b0e00253daf8943926b8e6ebfc75 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 29 Jul 2024 18:08:52 +0100 Subject: [PATCH 02/33] Remove TargetLowers and GenericLowers --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 987a7290274e7..f24e24bfe2d45 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1990,17 +1990,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT, bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( const CallInst *CI) const { - const bool TargetLowers = false; - const bool GenericLowers = true; auto *I = dyn_cast(CI); if (!I) - return GenericLowers; + return true; ScalableVectorType *RetTy = dyn_cast(I->getType()); if (!RetTy) - return GenericLowers; + return true; ScalableVectorType *InputTy = nullptr; @@ -2012,7 +2010,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( } if (!InputTy) - return GenericLowers; + return true; Value *InputA; Value *InputB; @@ -2022,7 +2020,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))); if (!match(I, Pattern)) - return GenericLowers; + return true; auto Mul = cast(I->getOperand(1)); @@ -2031,12 +2029,12 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( }; if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1)) - return GenericLowers; + return true; if (InputA->getType() != InputTy || InputB->getType() != InputTy) - return GenericLowers; + return true; - return TargetLowers; + return false; } bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const { From e604e4570412adee32b2ce363f8ceea9c7438349 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 29 Jul 2024 18:16:16 +0100 
Subject: [PATCH 03/33] Assert that shouldExpandPartialReductionIntrinsic sees an intrinsic --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index f24e24bfe2d45..8811f2ef94fb0 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1992,8 +1992,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( const CallInst *CI) const { auto *I = dyn_cast(CI); - if (!I) - return true; + assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc"); ScalableVectorType *RetTy = dyn_cast(I->getType()); From 9b23c96f5cd7a5ce10229682e61b1d8c5464b01e Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 29 Jul 2024 18:31:24 +0100 Subject: [PATCH 04/33] Allow non-scalable vector types --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 8811f2ef94fb0..ad38410ba3f27 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1994,18 +1994,17 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( auto *I = dyn_cast(CI); assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc"); - ScalableVectorType *RetTy = dyn_cast(I->getType()); - + VectorType *RetTy = dyn_cast(I->getType()); if (!RetTy) return true; - ScalableVectorType *InputTy = nullptr; + VectorType *InputTy = nullptr; auto RetScalarTy = RetTy->getScalarType(); if (RetScalarTy->isIntegerTy(64)) { - InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8); + InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy()); } else if (RetScalarTy->isIntegerTy(32)) { - InputTy = 
ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16); + InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy()); } if (!InputTy) From 45692dff2a6498d3ebc53a5862ef68239a772ab9 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 29 Jul 2024 19:10:08 +0100 Subject: [PATCH 05/33] Clean up type checking --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index ad38410ba3f27..f96be2a9f55e6 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2001,13 +2001,11 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( VectorType *InputTy = nullptr; auto RetScalarTy = RetTy->getScalarType(); - if (RetScalarTy->isIntegerTy(64)) { + if (RetScalarTy->isIntegerTy(64)) InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy()); - } else if (RetScalarTy->isIntegerTy(32)) { + else if (RetScalarTy->isIntegerTy(32)) InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy()); - } - - if (!InputTy) + else return true; Value *InputA; @@ -2021,7 +2019,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( return true; auto Mul = cast(I->getOperand(1)); - auto getOpcodeOfOperand = [&](unsigned Idx) { return cast(Mul->getOperand(Idx))->getOpcode(); }; From d305452d3cf171b320cdd25720247a4d61fb1a31 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Thu, 1 Aug 2024 11:04:37 +0100 Subject: [PATCH 06/33] Restrict to scalable vector types and clean up type checking --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +- .../lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp 
b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index f96be2a9f55e6..31c3e208356bf 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1995,7 +1995,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc"); VectorType *RetTy = dyn_cast(I->getType()); - if (!RetTy) + if (!RetTy || !RetTy->isScalableTy()) return true; VectorType *InputTy = nullptr; diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 5871134e60985..8cd2ba17b7d79 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -3689,15 +3689,10 @@ bool AArch64TTIImpl::isPartialReductionSupported( return false; Type *ReductionType = ReductionInstr->getType(); - if (ReductionType->isIntegerTy(32)) { - if (!InputType->isIntegerTy(8)) - return false; - } else if (ReductionType->isIntegerTy(64)) { - if (!InputType->isIntegerTy(16)) - return false; - } - return true; + return ((ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) || + (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16))) && + ReductionType->isScalableTy(); } unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) { From 4738a204c93ba2386783afb375b95c806e111522 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Thu, 1 Aug 2024 11:36:50 +0100 Subject: [PATCH 07/33] Simplify instruction matching in shouldExpandPartialReduction --- .../Target/AArch64/AArch64ISelLowering.cpp | 56 +++++++++---------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 31c3e208356bf..c4d00abfc190f 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1998,38 
+1998,36 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( if (!RetTy || !RetTy->isScalableTy()) return true; - VectorType *InputTy = nullptr; - - auto RetScalarTy = RetTy->getScalarType(); - if (RetScalarTy->isIntegerTy(64)) - InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy()); - else if (RetScalarTy->isIntegerTy(32)) - InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy()); - else - return true; - Value *InputA; Value *InputB; + if (match(I, m_Intrinsic( + m_Value(), + m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))), + m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) { + VectorType *InputAType = dyn_cast(InputA->getType()); + VectorType *InputBType = dyn_cast(InputB->getType()); + if (!InputAType || !InputBType) + return true; + ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy()); + ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy()); + if ((RetTy->getScalarType()->isIntegerTy(64) && + InputAType->getElementType()->isIntegerTy(16) && + InputAType->getElementCount() == ExpectedCount8 && + InputAType == InputBType) || + + (RetTy->getScalarType()->isIntegerTy(32) && + InputAType->getElementType()->isIntegerTy(8) && + InputAType->getElementCount() == ExpectedCount16 && + InputAType == InputBType)) { + auto *Mul = cast(I->getOperand(1)); + auto *Mul0 = cast(Mul->getOperand(0)); + auto *Mul1 = cast(Mul->getOperand(1)); + if (Mul0->getOpcode() == Mul1->getOpcode()) + return false; + } + } - auto Pattern = m_Intrinsic( - m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))), - m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))); - - if (!match(I, Pattern)) - return true; - - auto Mul = cast(I->getOperand(1)); - auto getOpcodeOfOperand = [&](unsigned Idx) { - return cast(Mul->getOperand(Idx))->getOpcode(); - }; - - if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1)) - return true; - - if (InputA->getType() != InputTy || 
InputB->getType() != InputTy) - return true; - - return false; + return true; } bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const { From 4dbf99e959674b91c743340560816d926480e2a3 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Fri, 9 Aug 2024 16:38:22 +0100 Subject: [PATCH 08/33] Add fallback in case the nodes aren't as we expect at lowering time --- .../Target/AArch64/AArch64ISelLowering.cpp | 67 ++++++++++++++++--- 1 file changed, 59 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index c4d00abfc190f..a5ea612fb3899 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21810,28 +21810,79 @@ static SDValue performIntrinsicCombine(SDNode *N, case Intrinsic::experimental_vector_partial_reduce_add: { SDLoc DL(N); + bool IsValidDotProduct = false; + auto NarrowOp = N->getOperand(1); auto MulOp = N->getOperand(2); + if (MulOp->getOpcode() == ISD::MUL) + IsValidDotProduct = true; auto ExtA = MulOp->getOperand(0); auto ExtB = MulOp->getOperand(1); + bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND; + bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND; + if (ExtA->getOpcode() == ExtB->getOpcode() && (IsSExt || IsZExt)) + IsValidDotProduct = true; unsigned DotIntrinsicId = Intrinsic::not_intrinsic; - if (ExtA->getOpcode() == ISD::SIGN_EXTEND) + if (IsSExt && IsValidDotProduct) DotIntrinsicId = Intrinsic::aarch64_sve_sdot; - else if (ExtA->getOpcode() == ISD::ZERO_EXTEND) + else if (IsZExt && IsValidDotProduct) DotIntrinsicId = Intrinsic::aarch64_sve_udot; - assert(DotIntrinsicId != Intrinsic::not_intrinsic && + assert((!IsValidDotProduct || DotIntrinsicId != Intrinsic::not_intrinsic) && "Unexpected dot product case encountered."); - auto A = ExtA->getOperand(0); - auto B = ExtB->getOperand(0); + if (IsValidDotProduct) { + auto A = ExtA->getOperand(0); + auto B = ExtB->getOperand(0); + + auto 
IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64); + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(), + {IntrinsicId, NarrowOp, A, B}); + } else { + // If the node doesn't match a dot product, lower to a series of ADDs + // instead. + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + EVT Type0 = Op0->getValueType(0); + EVT Type1 = Op1->getValueType(0); + + // Canonicalise so that Op1 has the larger type + if (Type1.getVectorNumElements() > Type0.getVectorNumElements()) { + std::swap(Op0, Op1); + std::swap(Type0, Type1); + } - auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64); - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(), - {IntrinsicId, NarrowOp, A, B}); + auto Type0Elements = Type0.getVectorNumElements(); + auto Type1Elements = Type1.getVectorNumElements(); + auto Type0ElementSize = + Type0.getVectorElementType().getScalarSizeInBits(); + auto Type1ElementSize = + Type1.getVectorElementType().getScalarSizeInBits(); + + // If the types are equal then a single ADD is fine + if (Type0 == Type1) + return DAG.getNode(ISD::ADD, DL, Type0, {Op0, Op1}); + + // Otherwise, we need to add each subvector together so that the output is + // the intrinsic's return type. 
For example, <4 x i32> + // partial.reduction(<4 x i32> a, <16 x i32> b) becomes a + b[0..3] + + // b[4..7] + b[8..11] + b[12..15] + SDValue Add = Op0; + for (unsigned i = 0; i < Type1Elements / Type0Elements; i++) { + SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1, + DAG.getConstant(i, DL, MVT::i64)); + + if (Type1ElementSize < Type0ElementSize) + Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec); + else if (Type1ElementSize > Type0ElementSize) + Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec); + Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec}); + } + return Add; + } } case Intrinsic::aarch64_neon_vcvtfxs2fp: case Intrinsic::aarch64_neon_vcvtfxu2fp: From c068775b5ddb75a9cd31b6443551c5a0e9cab496 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 12 Aug 2024 11:02:28 +0100 Subject: [PATCH 09/33] Fix logic error with fallback case --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a5ea612fb3899..56327cc074a47 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21810,19 +21810,19 @@ static SDValue performIntrinsicCombine(SDNode *N, case Intrinsic::experimental_vector_partial_reduce_add: { SDLoc DL(N); - bool IsValidDotProduct = false; + bool IsValidDotProduct = true; auto NarrowOp = N->getOperand(1); auto MulOp = N->getOperand(2); - if (MulOp->getOpcode() == ISD::MUL) - IsValidDotProduct = true; + if (MulOp->getOpcode() != ISD::MUL) + IsValidDotProduct = false; auto ExtA = MulOp->getOperand(0); auto ExtB = MulOp->getOperand(1); bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND; bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND; - if (ExtA->getOpcode() == ExtB->getOpcode() && (IsSExt || IsZExt)) - IsValidDotProduct = true; + if (ExtA->getOpcode() != ExtB->getOpcode() || 
(!IsSExt && !IsZExt)) + IsValidDotProduct = false; unsigned DotIntrinsicId = Intrinsic::not_intrinsic; @@ -21844,8 +21844,8 @@ static SDValue performIntrinsicCombine(SDNode *N, } else { // If the node doesn't match a dot product, lower to a series of ADDs // instead. - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); + SDValue Op0 = N->getOperand(1); + SDValue Op1 = N->getOperand(2); EVT Type0 = Op0->getValueType(0); EVT Type1 = Op1->getValueType(0); From 636652d0ad53a8476270ffa01cee73483081ca4a Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Tue, 13 Aug 2024 14:18:53 +0100 Subject: [PATCH 10/33] Pass IntrinsicInst to shouldExpandPartialReductionIntrinsic --- llvm/include/llvm/CodeGen/TargetLowering.h | 3 ++- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 2 +- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 +---- llvm/lib/Target/AArch64/AArch64ISelLowering.h | 3 ++- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 883a2252f7ffe..e17d68d2690c8 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -455,7 +455,8 @@ class TargetLoweringBase { /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic /// should be expanded using generic code in SelectionDAGBuilder. 
- virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const { + virtual bool + shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const { return true; } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 5ddbf9f414d21..05cbe384cb5ed 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -8006,7 +8006,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, } case Intrinsic::experimental_vector_partial_reduce_add: { - if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) { + if (!TLI.shouldExpandPartialReductionIntrinsic(cast(&I))) { visitTargetIntrinsic(I, Intrinsic); return; } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 56327cc074a47..916eccd52e939 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1989,10 +1989,7 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT, } bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( - const CallInst *CI) const { - - auto *I = dyn_cast(CI); - assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc"); + const IntrinsicInst *I) const { VectorType *RetTy = dyn_cast(I->getType()); if (!RetTy || !RetTy->isScalableTy()) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 9fe95ddaca32c..f9d45b02d30e3 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -998,7 +998,8 @@ class AArch64TargetLowering : public TargetLowering { bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override; - bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override; + bool + shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) 
const override; bool shouldExpandCttzElements(EVT VT) const override; From 83015b7e08e5c0ccfa55cf20423292a9e29d2a2a Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Tue, 13 Aug 2024 14:31:22 +0100 Subject: [PATCH 11/33] Remove one-use restriction --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 916eccd52e939..cd70d2e3cdfa7 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1997,10 +1997,10 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( Value *InputA; Value *InputB; - if (match(I, m_Intrinsic( - m_Value(), - m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))), - m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) { + if (match(I, + m_Intrinsic( + m_Value(), m_OneUse(m_Mul(m_ZExtOrSExt(m_Value(InputA)), + m_ZExtOrSExt(m_Value(InputB))))))) { VectorType *InputAType = dyn_cast(InputA->getType()); VectorType *InputBType = dyn_cast(InputB->getType()); if (!InputAType || !InputBType) From ed6efd6e7e705cd8b932b66f9e818a0e6c204884 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Tue, 13 Aug 2024 14:32:10 +0100 Subject: [PATCH 12/33] Remove new line --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index cd70d2e3cdfa7..856d9b0eeadd1 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2011,7 +2011,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( InputAType->getElementType()->isIntegerTy(16) && InputAType->getElementCount() == ExpectedCount8 && InputAType == InputBType) || - (RetTy->getScalarType()->isIntegerTy(32) && InputAType->getElementType()->isIntegerTy(8) && 
InputAType->getElementCount() == ExpectedCount16 && From 63648378be615a28ea75fd53d27eabc0a3658c06 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Tue, 13 Aug 2024 20:21:43 +0100 Subject: [PATCH 13/33] Remove extending/truncating for fallback case --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 856d9b0eeadd1..c150db8d8947d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21853,10 +21853,6 @@ static SDValue performIntrinsicCombine(SDNode *N, auto Type0Elements = Type0.getVectorNumElements(); auto Type1Elements = Type1.getVectorNumElements(); - auto Type0ElementSize = - Type0.getVectorElementType().getScalarSizeInBits(); - auto Type1ElementSize = - Type1.getVectorElementType().getScalarSizeInBits(); // If the types are equal then a single ADD is fine if (Type0 == Type1) @@ -21871,10 +21867,6 @@ static SDValue performIntrinsicCombine(SDNode *N, SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1, DAG.getConstant(i, DL, MVT::i64)); - if (Type1ElementSize < Type0ElementSize) - Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec); - else if (Type1ElementSize > Type0ElementSize) - Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec); Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec}); } return Add; From 9da416b52c88340bbd9611ad016532c58f2e8417 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Tue, 13 Aug 2024 20:27:04 +0100 Subject: [PATCH 14/33] Clean up test target --- llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll index 23b39387fb7a0..0facb2049135f 100644 --- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll +++ 
b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll @@ -1,8 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -; RUN: llc -mtriple=aarch64-unknwon-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s - -target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" -target triple = "aarch64-none-unknown-elf" +; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s define @dotp( %a, %b) #0 { ; CHECK-LABEL: dotp: From 0d231096e4f971f379a56c21e961caae7db3080e Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Wed, 14 Aug 2024 09:42:32 +0100 Subject: [PATCH 15/33] Remove #0 attribute from test --- .../CodeGen/AArch64/partial-reduce-dot-product.ll | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll index 0facb2049135f..16ef219a93c9b 100644 --- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll @@ -1,7 +1,7 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s -define @dotp( %a, %b) #0 { +define @dotp( %a, %b) { ; CHECK-LABEL: dotp: ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: mov z2.s, #0 // =0x0 @@ -16,7 +16,7 @@ entry: ret %partial.reduce } -define @dotp_wide( %a, %b) #0 { +define @dotp_wide( %a, %b) { ; CHECK-LABEL: dotp_wide: ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: mov z2.d, #0 // =0x0 @@ -31,7 +31,7 @@ entry: ret %partial.reduce } -define @dotp_sext( %a, %b) #0 { +define @dotp_sext( %a, %b) { ; CHECK-LABEL: dotp_sext: ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: mov z2.s, #0 // =0x0 @@ -46,7 +46,7 @@ entry: ret %partial.reduce } -define @dotp_wide_sext( %a, %b) #0 { +define @dotp_wide_sext( %a, %b) { ; CHECK-LABEL: dotp_wide_sext: ; CHECK: // %bb.0: // %entry ; 
CHECK-NEXT: mov z2.d, #0 // =0x0 @@ -61,7 +61,7 @@ entry: ret %partial.reduce } -define @not_dotp( %a, %b) #0 { +define @not_dotp( %a, %b) { ; CHECK-LABEL: not_dotp: ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: and z0.h, z0.h, #0xff @@ -82,7 +82,7 @@ entry: ret %partial.reduce } -define @not_dotp_wide( %a, %b) #0 { +define @not_dotp_wide( %a, %b) { ; CHECK-LABEL: not_dotp_wide: ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: and z0.s, z0.s, #0xffff @@ -102,5 +102,3 @@ entry: %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64( zeroinitializer, %mult) ret %partial.reduce } - -attributes #0 = { "target-features"="+sve2" } From bc86de6d93b9166c6ac4ca10cddae0199607fd15 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Wed, 14 Aug 2024 10:55:12 +0100 Subject: [PATCH 16/33] Allow i8 to i64 dot products --- .../Target/AArch64/AArch64ISelLowering.cpp | 34 ++++++++- .../AArch64/partial-reduce-dot-product.ll | 72 +++++++++++++++++++ 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index c150db8d8947d..13e664f5bf27f 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2007,11 +2007,15 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( return true; ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy()); ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy()); + // Check that the input type is 4 times smaller than the output type. If the + // output type is 64 bit then we can accept 8 bit inputs if we do a 32 bit + // dot product and add a zext/sext. 
if ((RetTy->getScalarType()->isIntegerTy(64) && InputAType->getElementType()->isIntegerTy(16) && InputAType->getElementCount() == ExpectedCount8 && InputAType == InputBType) || - (RetTy->getScalarType()->isIntegerTy(32) && + ((RetTy->getScalarType()->isIntegerTy(32) || + RetTy->getScalarType()->isIntegerTy(64)) && InputAType->getElementType()->isIntegerTy(8) && InputAType->getElementCount() == ExpectedCount16 && InputAType == InputBType)) { @@ -21833,10 +21837,34 @@ static SDValue performIntrinsicCombine(SDNode *N, if (IsValidDotProduct) { auto A = ExtA->getOperand(0); auto B = ExtB->getOperand(0); + EVT Type = NarrowOp.getValueType(); + + // 8 bit input to 64 bit output can be done by doing a 32 bit dot product + // and extending the output + bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 && + Type.getScalarSizeInBits() == 64; + SDValue Accumulator = NarrowOp; + if (Extend) { + Type = Type.changeVectorElementType( + EVT::getIntegerVT(*DAG.getContext(), 32)); + // The accumulator is of the wider type so we insert a 0 accumulator and + // add the proper one after extending + Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32, + DAG.getConstant(0, DL, MVT::i32)); + } auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64); - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(), - {IntrinsicId, NarrowOp, A, B}); + auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type, + {IntrinsicId, Accumulator, A, B}); + if (Extend) { + auto Extended = + DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL, + NarrowOp.getValueType(), {DotProduct}); + auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), + {NarrowOp, Extended}); + DotProduct = AccAdd; + } + return DotProduct; } else { // If the node doesn't match a dot product, lower to a series of ADDs // instead. 
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll index 16ef219a93c9b..c1cf9026d693c 100644 --- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll @@ -61,6 +61,78 @@ entry: ret %partial.reduce } +define @dotp_8to64( %a, %b) { +; CHECK-LABEL: dotp_8to64: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z2.s, #0 // =0x0 +; CHECK-NEXT: udot z2.s, z0.b, z1.b +; CHECK-NEXT: uunpklo z0.d, z2.s +; CHECK-NEXT: uunpkhi z1.d, z2.s +; CHECK-NEXT: ret +entry: + %a.wide = zext %a to + %b.wide = zext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( + zeroinitializer, %mult) + ret %partial.reduce +} + +define @dotp_sext_8to64( %a, %b) { +; CHECK-LABEL: dotp_sext_8to64: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z2.s, #0 // =0x0 +; CHECK-NEXT: sdot z2.s, z0.b, z1.b +; CHECK-NEXT: sunpklo z0.d, z2.s +; CHECK-NEXT: sunpkhi z1.d, z2.s +; CHECK-NEXT: ret +entry: + %a.wide = sext %a to + %b.wide = sext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( + zeroinitializer, %mult) + ret %partial.reduce +} + +define @dotp_8to64_accumulator( %a, %b, %acc) { +; CHECK-LABEL: dotp_8to64_accumulator: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z4.s, #0 // =0x0 +; CHECK-NEXT: udot z4.s, z0.b, z1.b +; CHECK-NEXT: uunpklo z0.d, z4.s +; CHECK-NEXT: uunpkhi z1.d, z4.s +; CHECK-NEXT: add z0.d, z2.d, z0.d +; CHECK-NEXT: add z1.d, z3.d, z1.d +; CHECK-NEXT: ret +entry: + %a.wide = zext %a to + %b.wide = zext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( + %acc, %mult) + ret %partial.reduce +} + +define @dotp_sext_8to64_accumulator( %a, %b, %acc) { +; CHECK-LABEL: 
dotp_sext_8to64_accumulator: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z4.s, #0 // =0x0 +; CHECK-NEXT: sdot z4.s, z0.b, z1.b +; CHECK-NEXT: sunpklo z0.d, z4.s +; CHECK-NEXT: sunpkhi z1.d, z4.s +; CHECK-NEXT: add z0.d, z2.d, z0.d +; CHECK-NEXT: add z1.d, z3.d, z1.d +; CHECK-NEXT: ret +entry: + %a.wide = sext %a to + %b.wide = sext %b to + %mult = mul nuw nsw %a.wide, %b.wide + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( + %acc, %mult) + ret %partial.reduce +} + define @not_dotp( %a, %b) { ; CHECK-LABEL: not_dotp: ; CHECK: // %bb.0: // %entry From aa7957faeb3adb8beda7443c4a1b06bddb9b01d4 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Tue, 20 Aug 2024 13:53:11 +0100 Subject: [PATCH 17/33] Remove isPartialReductionSupported --- .../AArch64/AArch64TargetTransformInfo.cpp | 25 ------------------- .../AArch64/AArch64TargetTransformInfo.h | 6 ----- 2 files changed, 31 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 8cd2ba17b7d79..dc748290f2e21 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -3670,31 +3670,6 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef Tys) { return Cost; } -bool AArch64TTIImpl::isPartialReductionSupported( - const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor, - bool IsInputASignExtended, bool IsInputBSignExtended, - const Instruction *BinOp) const { - if (ReductionInstr->getOpcode() != Instruction::Add) - return false; - - // Check that both extends are of the same type - if (IsInputASignExtended != IsInputBSignExtended) - return false; - - if (!BinOp || BinOp->getOpcode() != Instruction::Mul) - return false; - - // Dot product only supports a scale factor of 4 - if (ScaleFactor != 4) - return false; - - Type *ReductionType = ReductionInstr->getType(); - - return 
((ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) || - (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16))) && - ReductionType->isScalableTy(); -} - unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) { return ST->getMaxInterleaveFactor(); } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index af7e8e8e497dd..4a6457d7a7dbf 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -155,12 +155,6 @@ class AArch64TTIImpl : public BasicTTIImplBase { return VF.getKnownMinValue() * ST->getVScaleForTuning(); } - bool isPartialReductionSupported(const Instruction *ReductionInstr, - Type *InputType, unsigned ScaleFactor, - bool IsInputASignExtended, - bool IsInputBSignExtended, - const Instruction *BinOp = nullptr) const; - unsigned getMaxInterleaveFactor(ElementCount VF); bool prefersVectorizedAddressing() const; From a58ac297afd726a536b6687748d9a196473a4618 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Wed, 21 Aug 2024 15:02:17 +0100 Subject: [PATCH 18/33] Share expansion code in SelectionDAG --- llvm/include/llvm/CodeGen/SelectionDAG.h | 4 + .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 30 +++ .../SelectionDAG/SelectionDAGBuilder.cpp | 29 +-- .../Target/AArch64/AArch64ISelLowering.cpp | 217 ++++++++---------- 4 files changed, 130 insertions(+), 150 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 1514d92b36b3c..2235db5d93b5a 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1594,6 +1594,10 @@ class SelectionDAG { /// the target's desired shift amount type. SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op); + /// Expand a partial reduction intrinsic call. + /// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type. 
+ SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL); + /// Expand the specified \c ISD::VAARG node as the Legalize pass would. SDValue expandVAArg(SDNode *Node); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 27675dce70c26..2510c1828c909 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -74,6 +74,7 @@ #include #include #include +#include #include #include #include @@ -2426,6 +2427,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) { return getZExtOrTrunc(Op, SDLoc(Op), ShTy); } +SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL) { + EVT FullTy = Op2.getValueType(); + + unsigned Stride = ReducedTy.getVectorMinNumElements(); + unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride; + + // Collect all of the subvectors + std::deque Subvectors = {Op1}; + for (unsigned I = 0; I < ScaleFactor; I++) { + auto SourceIndex = getVectorIdxConstant(I * Stride, DL); + Subvectors.push_back(getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, + {Op2, SourceIndex})); + } + + // Flatten the subvector tree + while (Subvectors.size() > 1) { + Subvectors.push_back(getNode(ISD::ADD, DL, ReducedTy, + {Subvectors[0], Subvectors[1]})); + Subvectors.pop_front(); + Subvectors.pop_front(); + } + + assert(Subvectors.size() == 1 && + "There should only be one subvector after tree flattening"); + + return Subvectors[0]; + +} + SDValue SelectionDAG::expandVAArg(SDNode *Node) { SDLoc dl(Node); const TargetLowering &TLI = getTargetLoweringInfo(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 05cbe384cb5ed..33de8747fb7e5 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -8011,34 +8011,7 @@ 
void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, return; } - SDValue OpNode = getValue(I.getOperand(1)); - EVT ReducedTy = EVT::getEVT(I.getType()); - EVT FullTy = OpNode.getValueType(); - - unsigned Stride = ReducedTy.getVectorMinNumElements(); - unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride; - - // Collect all of the subvectors - std::deque Subvectors; - Subvectors.push_back(getValue(I.getOperand(0))); - for (unsigned i = 0; i < ScaleFactor; i++) { - auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl); - Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy, - {OpNode, SourceIndex})); - } - - // Flatten the subvector tree - while (Subvectors.size() > 1) { - Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy, - {Subvectors[0], Subvectors[1]})); - Subvectors.pop_front(); - Subvectors.pop_front(); - } - - assert(Subvectors.size() == 1 && - "There should only be one subvector after tree flattening"); - - setValue(&I, Subvectors[0]); + setValue(&I, DAG.expandPartialReductionIntrinsic(EVT::getEVT(I.getType()), getValue(I.getOperand(0)), getValue(I.getOperand(1)), sdl)); return; } case Intrinsic::experimental_cttz_elts: { diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 13e664f5bf27f..df89806ca057e 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1995,37 +1995,12 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( if (!RetTy || !RetTy->isScalableTy()) return true; - Value *InputA; - Value *InputB; - if (match(I, - m_Intrinsic( - m_Value(), m_OneUse(m_Mul(m_ZExtOrSExt(m_Value(InputA)), - m_ZExtOrSExt(m_Value(InputB))))))) { - VectorType *InputAType = dyn_cast(InputA->getType()); - VectorType *InputBType = dyn_cast(InputB->getType()); - if (!InputAType || !InputBType) - return true; - ElementCount ExpectedCount8 = ElementCount::get(8, 
RetTy->isScalableTy()); - ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy()); - // Check that the input type is 4 times smaller than the output type. If the - // output type is 64 bit then we can accept 8 bit inputs if we do a 32 bit - // dot product and add a zext/sext. - if ((RetTy->getScalarType()->isIntegerTy(64) && - InputAType->getElementType()->isIntegerTy(16) && - InputAType->getElementCount() == ExpectedCount8 && - InputAType == InputBType) || - ((RetTy->getScalarType()->isIntegerTy(32) || - RetTy->getScalarType()->isIntegerTy(64)) && - InputAType->getElementType()->isIntegerTy(8) && - InputAType->getElementCount() == ExpectedCount16 && - InputAType == InputBType)) { - auto *Mul = cast(I->getOperand(1)); - auto *Mul0 = cast(Mul->getOperand(0)); - auto *Mul1 = cast(Mul->getOperand(1)); - if (Mul0->getOpcode() == Mul1->getOpcode()) - return false; - } - } + if (RetTy->getScalarType()->isIntegerTy(32) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy())) + return false; + if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy())) + return false; + if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy())) + return false; return true; } @@ -21799,6 +21774,92 @@ static SDValue tryCombineWhileLo(SDNode *N, return SDValue(N, 0); } +SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarget, SelectionDAG &DAG) { + SDLoc DL(N); + + // The narrower of the two operands. 
Used as the accumulator + auto NarrowOp = N->getOperand(1); + auto MulOp = N->getOperand(2); + if (MulOp->getOpcode() != ISD::MUL) + return SDValue(); + + auto ExtA = MulOp->getOperand(0); + auto ExtB = MulOp->getOperand(1); + bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND; + bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND; + if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt)) + return SDValue(); + + auto A = ExtA->getOperand(0); + auto B = ExtB->getOperand(0); + if (A.getValueType() != B.getValueType()) + return SDValue(); + + // The fully-reduced type. Should be a vector of i32 or i64 + EVT FullType = N->getValueType(0); + // The type that is extended to the wide type. Should be an i8 or i16 + EVT ExtendedType = A.getValueType(); + // The wide type with four times as many elements as the reduced type. Should be a vector of i32 or i64, the same as the fully-reduced type + EVT WideType = MulOp.getValueType(); + if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits()) + return SDValue(); + // Dot products operate on chunks of four elements so there must be four times as many elements in the wide type + if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() != 4) + return SDValue(); + switch (FullType.getScalarSizeInBits()) { + case 32: + if (ExtendedType.getScalarSizeInBits() != 8) + return SDValue(); + break; + case 64: + // i8 to i64 can be done with an extended i32 dot product + if (ExtendedType.getScalarSizeInBits() != 8 && ExtendedType.getScalarSizeInBits() != 16) + return SDValue(); + break; + default: + return SDValue(); + } + + unsigned DotIntrinsicId = Intrinsic::not_intrinsic; + + if (IsSExt) + DotIntrinsicId = Intrinsic::aarch64_sve_sdot; + else if (IsZExt) + DotIntrinsicId = Intrinsic::aarch64_sve_udot; + + assert(DotIntrinsicId != Intrinsic::not_intrinsic && + "Unexpected dot product case encountered."); + + EVT Type = NarrowOp.getValueType(); + + // 8 bit input to 64 bit output can be done by 
doing a 32 bit dot product + // and extending the output + bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 && + Type.getScalarSizeInBits() == 64; + SDValue Accumulator = NarrowOp; + if (Extend) { + Type = Type.changeVectorElementType( + EVT::getIntegerVT(*DAG.getContext(), 32)); + // The accumulator is of the wider type so we insert a 0 accumulator and + // add the proper one after extending + Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32, + DAG.getConstant(0, DL, MVT::i32)); + } + + auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64); + auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type, + {IntrinsicId, Accumulator, A, B}); + if (Extend) { + auto Extended = + DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL, + NarrowOp.getValueType(), {DotProduct}); + auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), + {NarrowOp, Extended}); + DotProduct = AccAdd; + } + return DotProduct; +} + static SDValue performIntrinsicCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { @@ -21808,97 +21869,9 @@ static SDValue performIntrinsicCombine(SDNode *N, default: break; case Intrinsic::experimental_vector_partial_reduce_add: { - SDLoc DL(N); - - bool IsValidDotProduct = true; - - auto NarrowOp = N->getOperand(1); - auto MulOp = N->getOperand(2); - if (MulOp->getOpcode() != ISD::MUL) - IsValidDotProduct = false; - - auto ExtA = MulOp->getOperand(0); - auto ExtB = MulOp->getOperand(1); - bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND; - bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND; - if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt)) - IsValidDotProduct = false; - - unsigned DotIntrinsicId = Intrinsic::not_intrinsic; - - if (IsSExt && IsValidDotProduct) - DotIntrinsicId = Intrinsic::aarch64_sve_sdot; - else if (IsZExt && IsValidDotProduct) - DotIntrinsicId = Intrinsic::aarch64_sve_udot; - - assert((!IsValidDotProduct || DotIntrinsicId != 
Intrinsic::not_intrinsic) && - "Unexpected dot product case encountered."); - - if (IsValidDotProduct) { - auto A = ExtA->getOperand(0); - auto B = ExtB->getOperand(0); - EVT Type = NarrowOp.getValueType(); - - // 8 bit input to 64 bit output can be done by doing a 32 bit dot product - // and extending the output - bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 && - Type.getScalarSizeInBits() == 64; - SDValue Accumulator = NarrowOp; - if (Extend) { - Type = Type.changeVectorElementType( - EVT::getIntegerVT(*DAG.getContext(), 32)); - // The accumulator is of the wider type so we insert a 0 accumulator and - // add the proper one after extending - Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32, - DAG.getConstant(0, DL, MVT::i32)); - } - - auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64); - auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type, - {IntrinsicId, Accumulator, A, B}); - if (Extend) { - auto Extended = - DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL, - NarrowOp.getValueType(), {DotProduct}); - auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), - {NarrowOp, Extended}); - DotProduct = AccAdd; - } - return DotProduct; - } else { - // If the node doesn't match a dot product, lower to a series of ADDs - // instead. - SDValue Op0 = N->getOperand(1); - SDValue Op1 = N->getOperand(2); - EVT Type0 = Op0->getValueType(0); - EVT Type1 = Op1->getValueType(0); - - // Canonicalise so that Op1 has the larger type - if (Type1.getVectorNumElements() > Type0.getVectorNumElements()) { - std::swap(Op0, Op1); - std::swap(Type0, Type1); - } - - auto Type0Elements = Type0.getVectorNumElements(); - auto Type1Elements = Type1.getVectorNumElements(); - - // If the types are equal then a single ADD is fine - if (Type0 == Type1) - return DAG.getNode(ISD::ADD, DL, Type0, {Op0, Op1}); - - // Otherwise, we need to add each subvector together so that the output is - // the intrinsic's return type. 
For example, <4 x i32> - // partial.reduction(<4 x i32> a, <16 x i32> b) becomes a + b[0..3] + - // b[4..7] + b[8..11] + b[12..15] - SDValue Add = Op0; - for (unsigned i = 0; i < Type1Elements / Type0Elements; i++) { - SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1, - DAG.getConstant(i, DL, MVT::i64)); - - Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec}); - } - return Add; - } + if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG)) + return Dot; + return DAG.expandPartialReductionIntrinsic(N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N)); } case Intrinsic::aarch64_neon_vcvtfxs2fp: case Intrinsic::aarch64_neon_vcvtfxu2fp: From 5f31079756eb54679f0dd64da0d425084069a929 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Wed, 21 Aug 2024 16:04:52 +0100 Subject: [PATCH 19/33] Check for NEON or SVE --- llvm/include/llvm/CodeGen/SelectionDAG.h | 3 +- .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 45 +++++++------- .../SelectionDAG/SelectionDAGBuilder.cpp | 4 +- .../Target/AArch64/AArch64ISelLowering.cpp | 61 +++++++++++-------- 4 files changed, 65 insertions(+), 48 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 2235db5d93b5a..2c1d4bf259699 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1596,7 +1596,8 @@ class SelectionDAG { /// Expand a partial reduction intrinsic call. /// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type. - SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL); + SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, + SDValue Op2, SDLoc DL); /// Expand the specified \c ISD::VAARG node as the Legalize pass would. 
SDValue expandVAArg(SDNode *Node); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 2510c1828c909..cd09760b1f24e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -2427,33 +2427,34 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) { return getZExtOrTrunc(Op, SDLoc(Op), ShTy); } -SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL) { - EVT FullTy = Op2.getValueType(); +SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy, + SDValue Op1, SDValue Op2, + SDLoc DL) { + EVT FullTy = Op2.getValueType(); - unsigned Stride = ReducedTy.getVectorMinNumElements(); - unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride; + unsigned Stride = ReducedTy.getVectorMinNumElements(); + unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride; - // Collect all of the subvectors - std::deque Subvectors = {Op1}; - for (unsigned I = 0; I < ScaleFactor; I++) { - auto SourceIndex = getVectorIdxConstant(I * Stride, DL); - Subvectors.push_back(getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, - {Op2, SourceIndex})); - } - - // Flatten the subvector tree - while (Subvectors.size() > 1) { - Subvectors.push_back(getNode(ISD::ADD, DL, ReducedTy, - {Subvectors[0], Subvectors[1]})); - Subvectors.pop_front(); - Subvectors.pop_front(); - } + // Collect all of the subvectors + std::deque Subvectors = {Op1}; + for (unsigned I = 0; I < ScaleFactor; I++) { + auto SourceIndex = getVectorIdxConstant(I * Stride, DL); + Subvectors.push_back( + getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex})); + } - assert(Subvectors.size() == 1 && - "There should only be one subvector after tree flattening"); + // Flatten the subvector tree + while (Subvectors.size() > 1) { + Subvectors.push_back( + getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]})); + 
Subvectors.pop_front(); + Subvectors.pop_front(); + } - return Subvectors[0]; + assert(Subvectors.size() == 1 && + "There should only be one subvector after tree flattening"); + return Subvectors[0]; } SDValue SelectionDAG::expandVAArg(SDNode *Node) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 33de8747fb7e5..ce5ef78eba15d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -8011,7 +8011,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, return; } - setValue(&I, DAG.expandPartialReductionIntrinsic(EVT::getEVT(I.getType()), getValue(I.getOperand(0)), getValue(I.getOperand(1)), sdl)); + setValue(&I, DAG.expandPartialReductionIntrinsic( + EVT::getEVT(I.getType()), getValue(I.getOperand(0)), + getValue(I.getOperand(1)), sdl)); return; } case Intrinsic::experimental_cttz_elts: { diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index df89806ca057e..b849ddb2a86d6 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1995,11 +1995,14 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( if (!RetTy || !RetTy->isScalableTy()) return true; - if (RetTy->getScalarType()->isIntegerTy(32) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy())) + if (RetTy->getScalarType()->isIntegerTy(32) && + RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy())) return false; - if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy())) + if (RetTy->getScalarType()->isIntegerTy(64) && + RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy())) return false; - if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(4, 
RetTy->isScalableTy())) + if (RetTy->getScalarType()->isIntegerTy(64) && + RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy())) return false; return true; @@ -21774,7 +21777,13 @@ static SDValue tryCombineWhileLo(SDNode *N, return SDValue(N, 0); } -SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarget, SelectionDAG &DAG) { +SDValue tryLowerPartialReductionToDot(SDNode *N, + const AArch64Subtarget *Subtarget, + SelectionDAG &DAG) { + + if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable()) + return SDValue(); + SDLoc DL(N); // The narrower of the two operands. Used as the accumulator @@ -21799,25 +21808,29 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarg EVT FullType = N->getValueType(0); // The type that is extended to the wide type. Should be an i8 or i16 EVT ExtendedType = A.getValueType(); - // The wide type with four times as many elements as the reduced type. Should be a vector of i32 or i64, the same as the fully-reduced type + // The wide type with four times as many elements as the reduced type. 
Should + // be a vector of i32 or i64, the same as the fully-reduced type EVT WideType = MulOp.getValueType(); if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits()) return SDValue(); - // Dot products operate on chunks of four elements so there must be four times as many elements in the wide type - if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() != 4) + // Dot products operate on chunks of four elements so there must be four times + // as many elements in the wide type + if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() != + 4) return SDValue(); switch (FullType.getScalarSizeInBits()) { - case 32: - if (ExtendedType.getScalarSizeInBits() != 8) - return SDValue(); - break; - case 64: - // i8 to i64 can be done with an extended i32 dot product - if (ExtendedType.getScalarSizeInBits() != 8 && ExtendedType.getScalarSizeInBits() != 16) - return SDValue(); - break; - default: + case 32: + if (ExtendedType.getScalarSizeInBits() != 8) + return SDValue(); + break; + case 64: + // i8 to i64 can be done with an extended i32 dot product + if (ExtendedType.getScalarSizeInBits() != 8 && + ExtendedType.getScalarSizeInBits() != 16) return SDValue(); + break; + default: + return SDValue(); } unsigned DotIntrinsicId = Intrinsic::not_intrinsic; @@ -21838,8 +21851,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarg Type.getScalarSizeInBits() == 64; SDValue Accumulator = NarrowOp; if (Extend) { - Type = Type.changeVectorElementType( - EVT::getIntegerVT(*DAG.getContext(), 32)); + Type = + Type.changeVectorElementType(EVT::getIntegerVT(*DAG.getContext(), 32)); // The accumulator is of the wider type so we insert a 0 accumulator and // add the proper one after extending Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32, @@ -21850,9 +21863,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarg auto DotProduct = 
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type, {IntrinsicId, Accumulator, A, B}); if (Extend) { - auto Extended = - DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL, - NarrowOp.getValueType(), {DotProduct}); + auto Extended = DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, + DL, NarrowOp.getValueType(), {DotProduct}); auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), {NarrowOp, Extended}); DotProduct = AccAdd; @@ -21870,8 +21882,9 @@ static SDValue performIntrinsicCombine(SDNode *N, break; case Intrinsic::experimental_vector_partial_reduce_add: { if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG)) - return Dot; - return DAG.expandPartialReductionIntrinsic(N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N)); + return Dot; + return DAG.expandPartialReductionIntrinsic( + N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N)); } case Intrinsic::aarch64_neon_vcvtfxs2fp: case Intrinsic::aarch64_neon_vcvtfxu2fp: From 2f3a0dc8d581efd0687d0eedf3c346ffc6a60716 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Wed, 28 Aug 2024 09:27:15 +0100 Subject: [PATCH 20/33] Rename expansion function --- llvm/include/llvm/CodeGen/SelectionDAG.h | 4 ++-- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 5 ++--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 6 +++--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 ++-- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 2c1d4bf259699..227616c37e004 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1596,8 +1596,8 @@ class SelectionDAG { /// Expand a partial reduction intrinsic call. /// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type. 
- SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, - SDValue Op2, SDLoc DL); + SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1, + SDValue Op2); /// Expand the specified \c ISD::VAARG node as the Legalize pass would. SDValue expandVAArg(SDNode *Node); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index cd09760b1f24e..d5e61183d0e25 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -2427,9 +2427,8 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) { return getZExtOrTrunc(Op, SDLoc(Op), ShTy); } -SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy, - SDValue Op1, SDValue Op2, - SDLoc DL) { +SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1, + SDValue Op2) { EVT FullTy = Op2.getValueType(); unsigned Stride = ReducedTy.getVectorMinNumElements(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index ce5ef78eba15d..98c2e703c39ed 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -8011,9 +8011,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, return; } - setValue(&I, DAG.expandPartialReductionIntrinsic( - EVT::getEVT(I.getType()), getValue(I.getOperand(0)), - getValue(I.getOperand(1)), sdl)); + setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()), + getValue(I.getOperand(0)), + getValue(I.getOperand(1)))); return; } case Intrinsic::experimental_cttz_elts: { diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index b849ddb2a86d6..a25c09ade370e 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21883,8 +21883,8 @@ static SDValue 
performIntrinsicCombine(SDNode *N, case Intrinsic::experimental_vector_partial_reduce_add: { if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG)) return Dot; - return DAG.expandPartialReductionIntrinsic( - N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N)); + return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2)); } case Intrinsic::aarch64_neon_vcvtfxs2fp: case Intrinsic::aarch64_neon_vcvtfxu2fp: From 00a1be219912ca53f24dda3ee410229fa6286736 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Wed, 28 Aug 2024 10:22:45 +0100 Subject: [PATCH 21/33] Simplify shouldExpandPartialReductionIntrinsic --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a25c09ade370e..9f5d5a61397d1 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1991,21 +1991,12 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT, bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( const IntrinsicInst *I) const { - VectorType *RetTy = dyn_cast(I->getType()); - if (!RetTy || !RetTy->isScalableTy()) + if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add) return true; - if (RetTy->getScalarType()->isIntegerTy(32) && - RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy())) - return false; - if (RetTy->getScalarType()->isIntegerTy(64) && - RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy())) - return false; - if (RetTy->getScalarType()->isIntegerTy(64) && - RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy())) - return false; + EVT VT = EVT::getEVT(I->getType()); - return true; + return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::nxv4i64; } bool 
AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const { From f6c58393ddc03ae7328f7a1fe8236bce55924999 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Thu, 29 Aug 2024 09:36:24 +0100 Subject: [PATCH 22/33] Remove nxv4i64 case --- .../Target/AArch64/AArch64ISelLowering.cpp | 77 ++++++------------- .../AArch64/partial-reduce-dot-product.ll | 72 ----------------- 2 files changed, 23 insertions(+), 126 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 9f5d5a61397d1..7193303c7084a 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21795,35 +21795,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, if (A.getValueType() != B.getValueType()) return SDValue(); - // The fully-reduced type. Should be a vector of i32 or i64 - EVT FullType = N->getValueType(0); - // The type that is extended to the wide type. Should be an i8 or i16 - EVT ExtendedType = A.getValueType(); - // The wide type with four times as many elements as the reduced type. 
Should - // be a vector of i32 or i64, the same as the fully-reduced type - EVT WideType = MulOp.getValueType(); - if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits()) - return SDValue(); - // Dot products operate on chunks of four elements so there must be four times - // as many elements in the wide type - if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() != - 4) - return SDValue(); - switch (FullType.getScalarSizeInBits()) { - case 32: - if (ExtendedType.getScalarSizeInBits() != 8) - return SDValue(); - break; - case 64: - // i8 to i64 can be done with an extended i32 dot product - if (ExtendedType.getScalarSizeInBits() != 8 && - ExtendedType.getScalarSizeInBits() != 16) - return SDValue(); - break; - default: - return SDValue(); - } - unsigned DotIntrinsicId = Intrinsic::not_intrinsic; if (IsSExt) @@ -21834,33 +21805,31 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, assert(DotIntrinsicId != Intrinsic::not_intrinsic && "Unexpected dot product case encountered."); - EVT Type = NarrowOp.getValueType(); + auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64); - // 8 bit input to 64 bit output can be done by doing a 32 bit dot product - // and extending the output - bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 && - Type.getScalarSizeInBits() == 64; - SDValue Accumulator = NarrowOp; - if (Extend) { - Type = - Type.changeVectorElementType(EVT::getIntegerVT(*DAG.getContext(), 32)); - // The accumulator is of the wider type so we insert a 0 accumulator and - // add the proper one after extending - Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32, - DAG.getConstant(0, DL, MVT::i32)); - } + // The fully-reduced type. Should be a vector of i32 or i64 + EVT ReducedType = N->getValueType(0); + // The type that is extended to the wide type. Should be an i8 or i16 + EVT ExtendedType = A.getValueType(); + // The wide type with four times as many elements as the reduced type. 
Should + // be a vector of i32 or i64, the same as the fully-reduced type + EVT WideType = MulOp.getValueType(); + if (WideType.getScalarSizeInBits() != ReducedType.getScalarSizeInBits()) + return SDValue(); - auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64); - auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type, - {IntrinsicId, Accumulator, A, B}); - if (Extend) { - auto Extended = DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, - DL, NarrowOp.getValueType(), {DotProduct}); - auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), - {NarrowOp, Extended}); - DotProduct = AccAdd; - } - return DotProduct; + // Dot products operate on chunks of four elements so there must be four times + // as many elements in the wide type + if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 && + ExtendedType == MVT::nxv16i8) + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32, + {IntrinsicId, NarrowOp, A, B}); + + if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 && + ExtendedType == MVT::nxv8i16) + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i64, + {IntrinsicId, NarrowOp, A, B}); + + return SDValue(); } static SDValue performIntrinsicCombine(SDNode *N, diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll index c1cf9026d693c..16ef219a93c9b 100644 --- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll @@ -61,78 +61,6 @@ entry: ret %partial.reduce } -define @dotp_8to64( %a, %b) { -; CHECK-LABEL: dotp_8to64: -; CHECK: // %bb.0: // %entry -; CHECK-NEXT: mov z2.s, #0 // =0x0 -; CHECK-NEXT: udot z2.s, z0.b, z1.b -; CHECK-NEXT: uunpklo z0.d, z2.s -; CHECK-NEXT: uunpkhi z1.d, z2.s -; CHECK-NEXT: ret -entry: - %a.wide = zext %a to - %b.wide = zext %b to - %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call 
@llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( - zeroinitializer, %mult) - ret %partial.reduce -} - -define @dotp_sext_8to64( %a, %b) { -; CHECK-LABEL: dotp_sext_8to64: -; CHECK: // %bb.0: // %entry -; CHECK-NEXT: mov z2.s, #0 // =0x0 -; CHECK-NEXT: sdot z2.s, z0.b, z1.b -; CHECK-NEXT: sunpklo z0.d, z2.s -; CHECK-NEXT: sunpkhi z1.d, z2.s -; CHECK-NEXT: ret -entry: - %a.wide = sext %a to - %b.wide = sext %b to - %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( - zeroinitializer, %mult) - ret %partial.reduce -} - -define @dotp_8to64_accumulator( %a, %b, %acc) { -; CHECK-LABEL: dotp_8to64_accumulator: -; CHECK: // %bb.0: // %entry -; CHECK-NEXT: mov z4.s, #0 // =0x0 -; CHECK-NEXT: udot z4.s, z0.b, z1.b -; CHECK-NEXT: uunpklo z0.d, z4.s -; CHECK-NEXT: uunpkhi z1.d, z4.s -; CHECK-NEXT: add z0.d, z2.d, z0.d -; CHECK-NEXT: add z1.d, z3.d, z1.d -; CHECK-NEXT: ret -entry: - %a.wide = zext %a to - %b.wide = zext %b to - %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( - %acc, %mult) - ret %partial.reduce -} - -define @dotp_sext_8to64_accumulator( %a, %b, %acc) { -; CHECK-LABEL: dotp_sext_8to64_accumulator: -; CHECK: // %bb.0: // %entry -; CHECK-NEXT: mov z4.s, #0 // =0x0 -; CHECK-NEXT: sdot z4.s, z0.b, z1.b -; CHECK-NEXT: sunpklo z0.d, z4.s -; CHECK-NEXT: sunpkhi z1.d, z4.s -; CHECK-NEXT: add z0.d, z2.d, z0.d -; CHECK-NEXT: add z1.d, z3.d, z1.d -; CHECK-NEXT: ret -entry: - %a.wide = sext %a to - %b.wide = sext %b to - %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( - %acc, %mult) - ret %partial.reduce -} - define @not_dotp( %a, %b) { ; CHECK-LABEL: not_dotp: ; CHECK: // %bb.0: // %entry From da20b2a03ef35f0f1172f0b4826ac3d792671da5 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Thu, 29 Aug 2024 09:47:18 +0100 
Subject: [PATCH 23/33] Add assertion --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 7193303c7084a..a100d033f50d0 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21772,6 +21772,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarget, SelectionDAG &DAG) { + assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && + getIntrinsicID(N) == + Intrinsic::experimental_vector_partial_reduce_add && + "Expected a partial reduction node"); + if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable()) return SDValue(); From 4697fc13d67421bd364ca907659a4896f629d0e3 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Thu, 29 Aug 2024 09:48:31 +0100 Subject: [PATCH 24/33] Fix subtarget check --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a100d033f50d0..e303f6b693da1 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21777,7 +21777,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, Intrinsic::experimental_vector_partial_reduce_add && "Expected a partial reduction node"); - if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable()) + if (!Subtarget->isSVEorStreamingSVEAvailable()) return SDValue(); SDLoc DL(N); @@ -21819,8 +21819,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, // The wide type with four times as many elements as the reduced type. 
Should // be a vector of i32 or i64, the same as the fully-reduced type EVT WideType = MulOp.getValueType(); - if (WideType.getScalarSizeInBits() != ReducedType.getScalarSizeInBits()) - return SDValue(); // Dot products operate on chunks of four elements so there must be four times // as many elements in the wide type From 31b75674f414b740c3770377cb8b498aaa607225 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Thu, 29 Aug 2024 09:50:33 +0100 Subject: [PATCH 25/33] Emit a node instead of an intrinsic --- .../Target/AArch64/AArch64ISelLowering.cpp | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index e303f6b693da1..ded99e5ec8c44 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21800,17 +21800,14 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, if (A.getValueType() != B.getValueType()) return SDValue(); - unsigned DotIntrinsicId = Intrinsic::not_intrinsic; + unsigned Opcode = 0; if (IsSExt) - DotIntrinsicId = Intrinsic::aarch64_sve_sdot; + Opcode = AArch64ISD::SDOT; else if (IsZExt) - DotIntrinsicId = Intrinsic::aarch64_sve_udot; - - assert(DotIntrinsicId != Intrinsic::not_intrinsic && - "Unexpected dot product case encountered."); + Opcode = AArch64ISD::UDOT; - auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64); + assert(Opcode != 0 && "Unexpected dot product case encountered."); // The fully-reduced type. 
Should be a vector of i32 or i64 EVT ReducedType = N->getValueType(0); @@ -21824,13 +21821,13 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, // as many elements in the wide type if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 && ExtendedType == MVT::nxv16i8) - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32, - {IntrinsicId, NarrowOp, A, B}); + return DAG.getNode(Opcode, DL, MVT::nxv4i32, + NarrowOp, A, B); if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 && ExtendedType == MVT::nxv8i16) - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i64, - {IntrinsicId, NarrowOp, A, B}); + return DAG.getNode(Opcode, DL, MVT::nxv2i64, + NarrowOp, A, B); return SDValue(); } From 76296792b2cd94c78daaa03e8416e5bc86346ccb Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Thu, 29 Aug 2024 10:11:18 +0100 Subject: [PATCH 26/33] Pass accumulator from function in tests --- .../AArch64/partial-reduce-dot-product.ll | 68 ++++++++----------- 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll index 16ef219a93c9b..b1354ab210f72 100644 --- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll @@ -1,104 +1,96 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s -define @dotp( %a, %b) { +define @dotp( %acc, %a, %b) { ; CHECK-LABEL: dotp: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: mov z2.s, #0 // =0x0 -; CHECK-NEXT: udot z2.s, z0.b, z1.b -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: udot z0.s, z1.b, z2.b ; CHECK-NEXT: ret entry: %a.wide = zext %a to %b.wide = zext %b to %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( zeroinitializer, %mult) + %partial.reduce = tail 
call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( %acc, %mult) ret %partial.reduce } -define @dotp_wide( %a, %b) { +define @dotp_wide( %acc, %a, %b) { ; CHECK-LABEL: dotp_wide: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: mov z2.d, #0 // =0x0 -; CHECK-NEXT: udot z2.d, z0.h, z1.h -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: udot z0.d, z1.h, z2.h ; CHECK-NEXT: ret entry: %a.wide = zext %a to %b.wide = zext %b to %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64( zeroinitializer, %mult) + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64( %acc, %mult) ret %partial.reduce } -define @dotp_sext( %a, %b) { +define @dotp_sext( %accc, %a, %b) { ; CHECK-LABEL: dotp_sext: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: mov z2.s, #0 // =0x0 -; CHECK-NEXT: sdot z2.s, z0.b, z1.b -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: sdot z0.s, z1.b, z2.b ; CHECK-NEXT: ret entry: %a.wide = sext %a to %b.wide = sext %b to %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( zeroinitializer, %mult) + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( %accc, %mult) ret %partial.reduce } -define @dotp_wide_sext( %a, %b) { +define @dotp_wide_sext( %acc, %a, %b) { ; CHECK-LABEL: dotp_wide_sext: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: mov z2.d, #0 // =0x0 -; CHECK-NEXT: sdot z2.d, z0.h, z1.h -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: sdot z0.d, z1.h, z2.h ; CHECK-NEXT: ret entry: %a.wide = sext %a to %b.wide = sext %b to %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64( zeroinitializer, %mult) + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64( %acc, %mult) ret %partial.reduce } -define @not_dotp( %a, %b) { +define @not_dotp( %acc, 
%a, %b) { ; CHECK-LABEL: not_dotp: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: and z0.h, z0.h, #0xff ; CHECK-NEXT: and z1.h, z1.h, #0xff +; CHECK-NEXT: and z2.h, z2.h, #0xff ; CHECK-NEXT: ptrue p0.s -; CHECK-NEXT: uunpkhi z2.s, z0.h -; CHECK-NEXT: uunpkhi z3.s, z1.h -; CHECK-NEXT: uunpklo z0.s, z0.h -; CHECK-NEXT: uunpklo z1.s, z1.h -; CHECK-NEXT: mul z2.s, z2.s, z3.s -; CHECK-NEXT: mad z0.s, p0/m, z1.s, z2.s +; CHECK-NEXT: uunpklo z3.s, z1.h +; CHECK-NEXT: uunpklo z4.s, z2.h +; CHECK-NEXT: uunpkhi z1.s, z1.h +; CHECK-NEXT: uunpkhi z2.s, z2.h +; CHECK-NEXT: mla z0.s, p0/m, z3.s, z4.s +; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s ; CHECK-NEXT: ret entry: %a.wide = zext %a to %b.wide = zext %b to %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( zeroinitializer, %mult) + %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( %acc, %mult) ret %partial.reduce } -define @not_dotp_wide( %a, %b) { +define @not_dotp_wide( %acc, %a, %b) { ; CHECK-LABEL: not_dotp_wide: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: and z0.s, z0.s, #0xffff ; CHECK-NEXT: and z1.s, z1.s, #0xffff +; CHECK-NEXT: and z2.s, z2.s, #0xffff ; CHECK-NEXT: ptrue p0.d -; CHECK-NEXT: uunpkhi z2.d, z0.s -; CHECK-NEXT: uunpkhi z3.d, z1.s -; CHECK-NEXT: uunpklo z0.d, z0.s -; CHECK-NEXT: uunpklo z1.d, z1.s -; CHECK-NEXT: mul z2.d, z2.d, z3.d -; CHECK-NEXT: mad z0.d, p0/m, z1.d, z2.d +; CHECK-NEXT: uunpklo z3.d, z1.s +; CHECK-NEXT: uunpklo z4.d, z2.s +; CHECK-NEXT: uunpkhi z1.d, z1.s +; CHECK-NEXT: uunpkhi z2.d, z2.s +; CHECK-NEXT: mla z0.d, p0/m, z3.d, z4.d +; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d ; CHECK-NEXT: ret entry: %a.wide = zext %a to %b.wide = zext %b to %mult = mul nuw nsw %a.wide, %b.wide - %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64( zeroinitializer, %mult) + %partial.reduce = tail call 
@llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64( %acc, %mult) ret %partial.reduce } From 830df7624894cca7c00db2263129433cf0e1a9d7 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Fri, 30 Aug 2024 10:00:37 +0100 Subject: [PATCH 27/33] Remove nxv4i64 case from shouldExpandPartialReductionIntrinsic --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index ded99e5ec8c44..9ec25a4074a0a 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1996,7 +1996,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( EVT VT = EVT::getEVT(I->getType()); - return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::nxv4i64; + return VT != MVT::nxv4i32 && VT != MVT::nxv2i64; } bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const { From 97f3f7601061b2d0a8ebc2e59e87ba3f3d778e1f Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 2 Sep 2024 09:59:06 +0100 Subject: [PATCH 28/33] Reword getPartialReduceAdd comment --- llvm/include/llvm/CodeGen/SelectionDAG.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 227616c37e004..7ee8ca18c2c1d 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1594,8 +1594,8 @@ class SelectionDAG { /// the target's desired shift amount type. SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op); - /// Expand a partial reduction intrinsic call. - /// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type. + /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are + /// its operands and ReducedTy is the intrinsic's return type. 
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1, SDValue Op2); From bccc3e0833d68e4b2eb8f9bcab22146f2b467abb Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 2 Sep 2024 09:59:40 +0100 Subject: [PATCH 29/33] Remove blank lines from shouldExpandPartialReductionIntrinsic --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 9ec25a4074a0a..d8b77e3fa9cd2 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1990,12 +1990,10 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT, bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( const IntrinsicInst *I) const { - if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add) return true; EVT VT = EVT::getEVT(I->getType()); - return VT != MVT::nxv4i32 && VT != MVT::nxv2i64; } From 08db4c3e3278caebe58d8fe7de9662f2d38fcc83 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 2 Sep 2024 10:03:42 +0100 Subject: [PATCH 30/33] Rename ExtendedType --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index d8b77e3fa9cd2..7512f18c4b1e2 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21810,7 +21810,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, // The fully-reduced type. Should be a vector of i32 or i64 EVT ReducedType = N->getValueType(0); // The type that is extended to the wide type. Should be an i8 or i16 - EVT ExtendedType = A.getValueType(); + EVT MulSrcType = A.getValueType(); // The wide type with four times as many elements as the reduced type. 
Should // be a vector of i32 or i64, the same as the fully-reduced type EVT WideType = MulOp.getValueType(); @@ -21818,12 +21818,12 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, // Dot products operate on chunks of four elements so there must be four times // as many elements in the wide type if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 && - ExtendedType == MVT::nxv16i8) + MulSrcType == MVT::nxv16i8) return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B); if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 && - ExtendedType == MVT::nxv8i16) + MulSrcType == MVT::nxv8i16) return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B); From acf964b99e05a54765e611171f9b4f70a7517820 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 2 Sep 2024 10:04:27 +0100 Subject: [PATCH 31/33] Remove var declaration comments --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 7512f18c4b1e2..7e9a860e3d1ba 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21807,12 +21807,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, assert(Opcode != 0 && "Unexpected dot product case encountered."); - // The fully-reduced type. Should be a vector of i32 or i64 EVT ReducedType = N->getValueType(0); - // The type that is extended to the wide type. Should be an i8 or i16 EVT MulSrcType = A.getValueType(); - // The wide type with four times as many elements as the reduced type. 
Should - // be a vector of i32 or i64, the same as the fully-reduced type EVT WideType = MulOp.getValueType(); // Dot products operate on chunks of four elements so there must be four times From 50147e0eb2532dfdf6e0dc8e81f6bf4e20cb4ed6 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 2 Sep 2024 10:10:35 +0100 Subject: [PATCH 32/33] Put DAG.getNode all on one line --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 7e9a860e3d1ba..3cdbe90295ccd 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21815,13 +21815,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, // as many elements in the wide type if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) - return DAG.getNode(Opcode, DL, MVT::nxv4i32, - NarrowOp, A, B); + return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B); if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) - return DAG.getNode(Opcode, DL, MVT::nxv2i64, - NarrowOp, A, B); + return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B); return SDValue(); } From 6ed303f2203622d548c7febeeab6a348fbfe4b1b Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 2 Sep 2024 10:30:02 +0100 Subject: [PATCH 33/33] Remove WideType --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 3cdbe90295ccd..bc771c44d0d9d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21809,16 +21809,13 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, EVT ReducedType = N->getValueType(0); EVT MulSrcType = A.getValueType(); 
- EVT WideType = MulOp.getValueType(); // Dot products operate on chunks of four elements so there must be four times // as many elements in the wide type - if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 && - MulSrcType == MVT::nxv16i8) + if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B); - if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 && - MulSrcType == MVT::nxv8i16) + if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B); return SDValue();