[AArch64] Lower partial add reduction to udot or svdot #101010

Merged

merged 33 commits into llvm:main from partial-reduction-lowering on Sep 2, 2024

Conversation

SamTebbs33
Collaborator

This patch introduces lowering of the partial add reduction intrinsic to a udot or svdot for AArch64. It also adds a shouldExpandPartialReductionIntrinsic target hook, which AArch64 returns false from so that the intrinsic is lowered rather than expanded.

@llvmbot
Member

llvmbot commented Jul 29, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Sam Tebbs (SamTebbs33)

Changes

This patch introduces lowering of the partial add reduction intrinsic to a udot or svdot for AArch64. It also adds a shouldExpandPartialReductionIntrinsic target hook, which AArch64 returns false from so that the intrinsic is lowered rather than expanded.


Full diff: https://github.com/llvm/llvm-project/pull/101010.diff

7 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+6)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+6)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+77)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+30)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+6)
  • (added) llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll (+109)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9d9886f4920a29..07d99aec47122a 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,12 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+  /// should be expanded using generic code in SelectionDAGBuilder.
+  virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+    return true;
+  }
+
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
   /// using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1791f1b503379e..c70ab253c1aabc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7985,6 +7985,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
+
+    if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
+    }
+
     SDValue OpNode = getValue(I.getOperand(1));
     EVT ReducedTy = EVT::getEVT(I.getType());
     EVT FullTy = OpNode.getValueType();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d86e52d49000ae..d1ee58668ecbd7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1971,6 +1971,57 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
   return false;
 }
 
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+    const CallInst *CI) const {
+  const bool TargetLowers = false;
+  const bool GenericLowers = true;
+
+  auto *I = dyn_cast<IntrinsicInst>(CI);
+  if (!I)
+    return GenericLowers;
+
+  ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
+
+  if (!RetTy)
+    return GenericLowers;
+
+  ScalableVectorType *InputTy = nullptr;
+
+  auto RetScalarTy = RetTy->getScalarType();
+  if (RetScalarTy->isIntegerTy(64)) {
+    InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+  } else if (RetScalarTy->isIntegerTy(32)) {
+    InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+  }
+
+  if (!InputTy)
+    return GenericLowers;
+
+  Value *InputA;
+  Value *InputB;
+
+  auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+      m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+                                m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
+
+  if (!match(I, Pattern))
+    return GenericLowers;
+
+  auto Mul = cast<Instruction>(I->getOperand(1));
+
+  auto getOpcodeOfOperand = [&](unsigned Idx) {
+    return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
+  };
+
+  if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
+    return GenericLowers;
+
+  if (InputA->getType() != InputTy || InputB->getType() != InputTy)
+    return GenericLowers;
+
+  return TargetLowers;
+}
+
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return true;
@@ -21237,6 +21288,32 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
+  case Intrinsic::experimental_vector_partial_reduce_add: {
+    SDLoc DL(N);
+
+    auto NarrowOp = N->getOperand(1);
+    auto MulOp = N->getOperand(2);
+
+    auto ExtA = MulOp->getOperand(0);
+    auto ExtB = MulOp->getOperand(1);
+
+    unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+    if (ExtA->getOpcode() == ISD::SIGN_EXTEND)
+      DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+    else if (ExtA->getOpcode() == ISD::ZERO_EXTEND)
+      DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+    assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+           "Unexpected dot product case encountered.");
+
+    auto A = ExtA->getOperand(0);
+    auto B = ExtB->getOperand(0);
+
+    auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+                       {IntrinsicId, NarrowOp, A, B});
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 81e15185f985d5..fc79d9766719bc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -991,6 +991,8 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
 
+  bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override;
+
   bool shouldExpandCttzElements(EVT VT) const override;
 
   /// If a change in streaming mode is required on entry to/return from a
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 45148449dfb821..792bd546019192 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3533,6 +3533,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
   return Cost;
 }
 
+bool AArch64TTIImpl::isPartialReductionSupported(
+    const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+    bool IsInputASignExtended, bool IsInputBSignExtended,
+    const Instruction *BinOp) const {
+  if (ReductionInstr->getOpcode() != Instruction::Add)
+    return false;
+
+  // Check that both extends are of the same type
+  if (IsInputASignExtended != IsInputBSignExtended)
+    return false;
+
+  if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
+    return false;
+
+  // Dot product only supports a scale factor of 4
+  if (ScaleFactor != 4)
+    return false;
+
+  Type *ReductionType = ReductionInstr->getType();
+  if (ReductionType->isIntegerTy(32)) {
+    if (!InputType->isIntegerTy(8))
+      return false;
+  } else if (ReductionType->isIntegerTy(64)) {
+    if (!InputType->isIntegerTy(16))
+      return false;
+  }
+
+  return true;
+}
+
 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a9189fd53f40bb..592b452134e778 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -155,6 +155,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return VF.getKnownMinValue() * ST->getVScaleForTuning();
   }
 
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp = nullptr) const;
+
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
   bool prefersVectorizedAddressing() const;
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..23b39387fb7a0c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,109 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-unknwon-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    udot z2.s, z0.b, z1.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.d, #0 // =0x0
+; CHECK-NEXT:    udot z2.d, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    sdot z2.s, z0.b, z1.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.d, #0 // =0x0
+; CHECK-NEXT:    sdot z2.d, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) #0 {
+; CHECK-LABEL: not_dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z0.h, z0.h, #0xff
+; CHECK-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpkhi z2.s, z0.h
+; CHECK-NEXT:    uunpkhi z3.s, z1.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    uunpklo z1.s, z1.h
+; CHECK-NEXT:    mul z2.s, z2.s, z3.s
+; CHECK-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) #0 {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z0.s, z0.s, #0xffff
+; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpkhi z2.d, z0.s
+; CHECK-NEXT:    uunpkhi z3.d, z1.s
+; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    uunpklo z1.d, z1.s
+; CHECK-NEXT:    mul z2.d, z2.d, z3.d
+; CHECK-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+  %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+  %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+attributes #0 = { "target-features"="+sve2" }

@llvmbot
Member

llvmbot commented Jul 29, 2024

@llvm/pr-subscribers-backend-aarch64


Collaborator

@sdesmalen-arm sdesmalen-arm left a comment

Hi @SamTebbs33, I hadn't really followed the other patches before looking at this one, so perhaps some of my questions have already been answered in the other PRs.

Comment on lines 1976 to 1977
const bool TargetLowers = false;
const bool GenericLowers = true;
Collaborator

I personally find this more confusing than helpful; maybe just return true or false directly instead.

Collaborator Author

Done.

Comment on lines 1979 to 1981
auto *I = dyn_cast<IntrinsicInst>(CI);
if (!I)
return GenericLowers;
Collaborator

Should this be an assert? I would expect this only to be called on intrinsic calls to experimental_vector_partial_reduce_add (or similar)

Collaborator Author

I think it should be. It's only called from somewhere where it is an intrinsic so I'll change this to an assertion.

Comment on lines 1995 to 1998
}

if (!InputTy)
return GenericLowers;
Collaborator

Suggested change:
-  }
-
-  if (!InputTy)
-    return GenericLowers;
+  } else
+    return GenericLowers;

Collaborator Author

Done.

ScalableVectorType *InputTy = nullptr;

auto RetScalarTy = RetTy->getScalarType();
if (RetScalarTy->isIntegerTy(64)) {
Collaborator

nit: curly braces are unnecessary here.

Collaborator Author

Done.

Comment on lines 1983 to 1986
ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());

if (!RetTy)
return GenericLowers;
Collaborator

Suggested change:
-  ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
-
-  if (!RetTy)
-    return GenericLowers;
+  if (!isa<ScalableVectorType>(I->getType()))
+    return true;

Collaborator Author

I think this is superseded by what I've done to address your other comment about the scalable vector type.

Comment on lines 2003 to 2005
auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
Collaborator

Is it worth making this check earlier (before checking the types) and then do:

  if (match(I, m_Intrinsic(...))) {
    if ((I->getType()->isIntegerTy(64) && InputA->getType()->isIntegerTy(16)) ||
        (I->getType()->isIntegerTy(32) && InputA->getType()->isIntegerTy(8))) {
      auto *Mul = cast<Instruction>(I->getOperand(1));
      if (cast<Instruction>(Mul->getOperand(0))->getOpcode() ==
          cast<Instruction>(Mul->getOperand(1))->getOpcode())
        return false;
    }
  }

  return true;

That way you don't need to construct any explicit Types, to then match later.

Collaborator Author

That's much cleaner, thanks. Some type construction was required but that's just because of how the ElementCount class works. Done.

if (!I)
return GenericLowers;

ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
Collaborator

Am I right that this is just an artificial limitation at the moment? (i.e. we can make this work for fixed-length vectors too)

Collaborator Author

Yeah you're right. I've removed that limitation now.

@@ -3533,6 +3533,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
return Cost;
}

bool AArch64TTIImpl::isPartialReductionSupported(
const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
Collaborator

What is ScaleFactor?

Collaborator Author

It's the scaling difference between the element size of the inputs and the size of the resulting scalar from the entire reduction. In the tests you'll see that the input types are extended and the resulting scalar is 4 times the size, which mirrors how the dot product instructions work.
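
To make that concrete, here is a scalar model of the 4-way dot-product semantics (an illustrative sketch, not LLVM code; the function name is made up):

  #include <cstdint>

  // Each 32-bit accumulator lane sums four adjacent i8*i8 products, so the
  // input element count is 4x the accumulator element count (ScaleFactor == 4).
  void udot_model(uint32_t acc[4], const uint8_t a[16], const uint8_t b[16]) {
    for (int Lane = 0; Lane < 4; ++Lane)
      for (int I = 0; I < 4; ++I)
        acc[Lane] += uint32_t(a[4 * Lane + I]) * uint32_t(b[4 * Lane + I]);
  }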

Collaborator

Not sure if I'm missing something but isPartialReductionSupported isn't used and thus doesn't belong in this PR?

Collaborator Author

That's correct, thanks for spotting that. It's not used in this PR but is used by the loop vectorizer part of this work, so I'll move it to that PR.

Comment on lines 3555 to 3563
if (ReductionType->isIntegerTy(32)) {
if (!InputType->isIntegerTy(8))
return false;
} else if (ReductionType->isIntegerTy(64)) {
if (!InputType->isIntegerTy(16))
return false;
}

return true;
Collaborator

Suggested change:
-  if (ReductionType->isIntegerTy(32)) {
-    if (!InputType->isIntegerTy(8))
-      return false;
-  } else if (ReductionType->isIntegerTy(64)) {
-    if (!InputType->isIntegerTy(16))
-      return false;
-  }
-
-  return true;
+  return (ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) ||
+         (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16));

Collaborator Author

Done.

if (ScaleFactor != 4)
return false;

Type *ReductionType = ReductionInstr->getType();
Collaborator

This function is not testing that the type must be scalable. Is that on purpose?

Collaborator Author

Originally we were limiting the creation of the intrinsic to when scalable types were used so a check here wasn't necessary. But I will add a limitation to scalable types here as an initial version of this work and we can get it to work for NEON later on.

Collaborator Author

Done.


github-actions bot commented Jul 29, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Collaborator

@davemgreen davemgreen left a comment

Could this remove the shouldExpandPartialReductionIntrinsic and always emit a node that gets expanded as needed, like the rest of the DAG?

@paulwalker-arm
Collaborator

I request that we defer creating a dedicated ISD node for now, with the precedent already set (see get_active_lane_mask and cttz_elts). In this instance I don't think we'll be waiting long for a dedicated ISD node, but my concern is that such a node is unlikely to want to follow the exact semantics of the intrinsic, and so I'd rather take the time to see how the functionality plays out before committing ourselves rather than hastily implementing all the functionality that typically comes from adding a new ISD node.

Collaborator

@davemgreen davemgreen left a comment

I worried I was asking for quite a lot of work when I added the comment, but I don't think we can just rely on the IR matching the DAG like this. It will need to at least check that the DAG nodes match as expected, and probably add a fallback in case they do not. It would be good if we handled larger/other vector sizes too, but that might have to wait until we do add nodes to handle legalization properly.

case Intrinsic::experimental_vector_partial_reduce_add: {
SDLoc DL(N);

auto NarrowOp = N->getOperand(1);
Collaborator

This can't just assume that the node is still how we expect (although that will likely almost always be the case). Can you add checks that the nodes are Mul/zext/sext and add some fallback in case they are not?

Collaborator Author

Thanks Dave, I've now added a fallback that creates a chain of ADDs if the nodes aren't as expected.
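
Roughly, such a fallback splits the wide vector into result-sized subvectors and accumulates them with vector adds; a sketch (function name and placement assumed; this logic later moved into a shared SelectionDAG helper, discussed below):

  // Expand partial.reduce.add(Acc, Wide) as a chain of vector ADDs.
  // Assumes Wide's element count is a multiple of ReducedTy's.
  static SDValue expandToAddChain(SelectionDAG &DAG, SDLoc DL, EVT ReducedTy,
                                  SDValue Acc, SDValue Wide) {
    unsigned Stride = ReducedTy.getVectorMinNumElements();
    unsigned ScaleFactor =
        Wide.getValueType().getVectorMinNumElements() / Stride;
    for (unsigned I = 0; I < ScaleFactor; ++I) {
      SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, Wide,
                                   DAG.getVectorIdxConstant(I * Stride, DL));
      Acc = DAG.getNode(ISD::ADD, DL, ReducedTy, Acc, Subvec);
    }
    return Acc;
  }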

Collaborator Author

Just realised I got a bit of the logic with IsValidDotProduct wrong. I'll fix that up in the next commit.

Comment on lines 1977 to 1978
auto *I = dyn_cast<IntrinsicInst>(CI);
assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc");
Collaborator

This could just be a cast without the assert, but it might be better to pass an IntrinsicInst directly.

Collaborator Author

Yeah that sounds good to me. Done.

InputAType->getElementType()->isIntegerTy(16) &&
InputAType->getElementCount() == ExpectedCount8 &&
InputAType == InputBType) ||

Collaborator

Remove extra line (I would have expected the formatter to remove it).

Collaborator Author

Done.

(RetTy->getScalarType()->isIntegerTy(32) &&
InputAType->getElementType()->isIntegerTy(8) &&
InputAType->getElementCount() == ExpectedCount16 &&
InputAType == InputBType)) {
Collaborator

An i8->i64 reduction could use an i8->i32 udot + zext.

Collaborator Author

Great idea, added that now. The accumulator operand is of the wider type so I made it insert a zeroinitializer and add the accumulator after extending.
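
A rough sketch of that shape, written with the AArch64ISD::UDOT node for brevity (operand order and types assumed; this i64 path was later dropped from the PR to keep it minimal):

  // i8 -> i64: dot into a zeroed 32-bit accumulator, widen the result,
  // then add the original 64-bit accumulator.
  SDValue Zero = DAG.getConstant(0, DL, MVT::nxv4i32);
  SDValue Dot = DAG.getNode(AArch64ISD::UDOT, DL, MVT::nxv4i32, Zero, A, B);
  SDValue Wide = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::nxv4i64, Dot);
  SDValue Res = DAG.getNode(ISD::ADD, DL, MVT::nxv4i64, Wide, Accumulator);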

auto ExtB = MulOp->getOperand(1);
bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
Collaborator

Check the types too?

Collaborator Author

Can the types be altered between shouldExpandPartialReductionIntrinsic being called and lowering happening? If not then I'd expect them to be fine here even if the nodes themselves have changed.

@@ -0,0 +1,109 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64-unknwon-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
Collaborator

We can usually remove -O3

Collaborator

unknwon->unknown, or I often just use -mtriple=aarch64 nowadays.

Collaborator Author

Ah thanks!

ret <vscale x 2 x i64> %partial.reduce
}

attributes #0 = { "target-features"="+sve2" }
Collaborator

This can be removed as it is in the run line.

Collaborator Author

Done.

Comment on lines 4 to 5
target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
target triple = "aarch64-none-unknown-elf"
Collaborator

This doesn't match the run line, and datalayout is often unnecessary in llc tests.

Comment on lines 1988 to 1989
m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) {
Collaborator

It might be better to produce a udot even if there are multiple uses of the inner zext/sext. In the long run we will likely need them to be generated from known-bits too.

Collaborator Author

Done.

Comment on lines 21350 to 21353
if (Type1ElementSize < Type0ElementSize)
Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec);
else if (Type1ElementSize > Type0ElementSize)
Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec);
Collaborator

I don't think this should be extending/truncating the type. From what I understand they should be the same size.

Collaborator Author

Yeah you're right, removed now.

Comment on lines 1984 to 1990
m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
m_Value(), m_OneUse(m_Mul(m_ZExtOrSExt(m_Value(InputA)),
m_ZExtOrSExt(m_Value(InputB))))))) {
VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
if (!InputAType || !InputBType)
return true;
Collaborator

I don't think this function needs to be this complex. Is it not possible to base the decision purely on the result type (e.g. legal scalable vectors)?

I ask because you're putting significant effort into matching DOT instructions, which is not that unreasonable given the PR's title, but the true intent of this patch is to enable better code generation for the partial reduction intrinsic, of which DOT is just one possible destination.

For me the complexity can stay in the target specific DAG combine, which will evolve over time, with the only cost being perhaps a duplication of the common lowering code. This duplication is easily solved by moving it into a dedicated SelectionDAG function that can be called by both the builder and the target specific DAG combine.

Collaborator Author

Thanks Paul, I like the approach of sharing the expansion code in SelectionDAG. Let me know what you think of the new implementation.

@@ -1590,6 +1590,11 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);

/// Expand a partial reduction intrinsic call.
/// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type.
SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1,
Collaborator

The name lacks sufficient context (i.e. the "add" part is very important) but at the same time this nicely abstracts the intrinsic side of things so perhaps "getPartialReduceAdd()" is more representative?

On similar lines what do you think to mirroring the operand order of similar getNode() calls? (i.e. DL, VT, operands....)

Collaborator Author

getPartialReduceAdd() definitely fits the naming scheme of the rest of the functions in this file. Sounds good to me.

Comment on lines 1977 to 1978
VectorType *RetTy = dyn_cast<VectorType>(I->getType());
if (!RetTy || !RetTy->isScalableTy())
Collaborator

The langref says both operands must be vectors so this can be just cast<VectorType> with the !RetTy dropped.

That said, I think you'd be better off calling EVT::getEVT() because then the following code can just be:
return VT == MVT::nxv4i32 || VT == MVT::nxv2i64 .....

Collaborator Author

That's a lot cleaner, thanks. Done.

Comment on lines 1974 to 1975
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
const IntrinsicInst *I) const {
Collaborator

If the intent is for this function to support other partial reduction intrinsics then you should be rejecting intrinsics whose ID is not Intrinsic::experimental_vector_partial_reduce_add.

Collaborator Author

Done.
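
Taken together with the EVT::getEVT() suggestion above, the hook might end up roughly like this (a sketch; the committed version may differ):

  bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
      const IntrinsicInst *I) const {
    if (I->getIntrinsicID() !=
        Intrinsic::experimental_vector_partial_reduce_add)
      return true;
    // Expand unless the result type maps directly to an SVE dot product.
    EVT VT = EVT::getEVT(I->getType());
    return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
  }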

SDValue tryLowerPartialReductionToDot(SDNode *N,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {

Collaborator

@paulwalker-arm paulwalker-arm Aug 27, 2024

Please add an assert to verify N is what you expect it to be.

In general you shouldn't assume knowledge from shouldExpandPartialReductionIntrinsic and should reverify that the input types are supported by the following code.

Collaborator Author

Done.

Comment on lines 131 to 132
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
<vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
Collaborator

I've not looked at how prevalent this is, but the intrinsic name is not consistent with the operand types.

Collaborator Author

Ah that would be a mistake I made when copying the 8 to 64 tests. I've removed them as part of the commit that removes the nxv4i64 handling.

Comment on lines 8 to 14
; CHECK-NEXT: udot z2.s, z0.b, z1.b
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
Collaborator

Most of the tests pass in zero as the first operand, but the DAG combine doesn't have any special handling for zero, so it seems irrelevant to the PR. If this is true then all such zeros should be removed, with that operand just being passed into the test function.

Collaborator Author

Good idea, done.

Comment on lines 21310 to 21314
if (IsSExt)
DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
else if (IsZExt)
DotIntrinsicId = Intrinsic::aarch64_sve_udot;

Collaborator

Rather than emit an intrinsic, which is cumbersome to create, can you use AArch64ISD::SDOT/UDOT directly?

Collaborator Author

Done.

Comment on lines 21338 to 21339
auto Extended = DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND,
DL, NarrowOp.getValueType(), {DotProduct});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need the {} wrapping the operands because there's enough getNode functions for the number of operands you have here. The same likely goes for the other places you call getNode.

Collaborator Author

Done.

Comment on lines 21330 to 21331
Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
DAG.getConstant(0, DL, MVT::i32));
Collaborator

You should be able to use DAG.getConstant(0, DL, MVT::nxv4i32); here.

Collaborator Author

Nice, I didn't know that. It's not needed now that the nxv4i64 case has been removed.

@@ -21229,6 +21249,101 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}

SDValue tryLowerPartialReductionToDot(SDNode *N,
Collaborator

I think you're trying to do too much in this one PR, which is making it hard to see if the complexity is required. Please can we start by just handling the cases that have a direct mapping to DOT instructions.

Specifically, please remove the nxv4i64 result type handling because this seems to be the primary source of complexity. With that gone I think you'll be able to talk purely about EVTs and do less element count based maths. For example, I think you be able to implement the function more akin to:

validate input and get pre-extend operands

if (result_type == nxv4i32 && input_type == nxv16i8)
  return getNode(...);

if (result_type == nxv2i64 && input_type == nxv8i16)
  return getNode(...);

return SDValue();

This will give us a foundation for future PRs to build on (i.e. to support non-legal types, 2-way dot products, usdot, and figure out what ISD node semantics we need to best handle the intrinsic).

Collaborator Author

Done.
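
Put together, the simplified combine ends up roughly as below (a sketch assembled from the review discussion, not the verbatim committed code):

  static SDValue tryLowerPartialReductionToDot(SDNode *N,
                                               const AArch64Subtarget *Subtarget,
                                               SelectionDAG &DAG) {
    assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
           N->getConstantOperandVal(0) ==
               Intrinsic::experimental_vector_partial_reduce_add &&
           "Expected a partial reduction node");
    if (!Subtarget->isSVEorStreamingSVEAvailable())
      return SDValue();

    SDLoc DL(N);
    SDValue Acc = N->getOperand(1);
    SDValue Mul = N->getOperand(2);
    if (Mul.getOpcode() != ISD::MUL)
      return SDValue();

    SDValue ExtA = Mul.getOperand(0), ExtB = Mul.getOperand(1);
    bool IsSExt = ExtA.getOpcode() == ISD::SIGN_EXTEND;
    bool IsZExt = ExtA.getOpcode() == ISD::ZERO_EXTEND;
    if (ExtA.getOpcode() != ExtB.getOpcode() || (!IsSExt && !IsZExt))
      return SDValue();

    SDValue A = ExtA.getOperand(0), B = ExtB.getOperand(0);
    if (A.getValueType() != B.getValueType())
      return SDValue();

    // Only the cases with a direct DOT mapping.
    EVT ReducedType = N->getValueType(0);
    EVT ExtendedType = A.getValueType();
    unsigned Opcode = IsSExt ? AArch64ISD::SDOT : AArch64ISD::UDOT;
    if ((ReducedType == MVT::nxv4i32 && ExtendedType == MVT::nxv16i8) ||
        (ReducedType == MVT::nxv2i64 && ExtendedType == MVT::nxv8i16))
      return DAG.getNode(Opcode, DL, ReducedType, Acc, A, B);

    return SDValue();
  }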

@SamTebbs33 SamTebbs33 force-pushed the partial-reduction-lowering branch from 651c200 to 7629679 on August 29, 2024 10:40
Collaborator

@paulwalker-arm paulwalker-arm left a comment

Mostly superficial comments with the removal of WideType (assuming you agree that it is redundant) being the only one I really care about. Otherwise the patch looks good to me.

Comment on lines 21822 to 21823
if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 &&
ExtendedType == MVT::nxv16i8)
Collaborator

Is WideType required here? Given that by definition it must have the same element count as ExtendedType, it should be enough to say if (ReducedType == MVT::nxv4i32 && ExtendedType == MVT::nxv16i8).

Collaborator Author

If I remove this check, what guarantees that the mul has the proper type?

Collaborator

DAG combines can and should assume the DAG is well formed.

In this instance you can be certain the element type of the mul will match the element type of the partial.add's result type (i.e. ReducedType) and you can be certain the mul will have the same number of elements as its operands, which by extension means the same number of elements as its pre-extended operands (i.e. ExtendedType).

Collaborator Author

That makes sense, thank you. Done.

Comment on lines 21827 to 21828
if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 &&
ExtendedType == MVT::nxv8i16)
Collaborator

As above, WideType might be redundant.

Collaborator Author

Done.

@SamTebbs33 SamTebbs33 merged commit 44cfbef into llvm:main Sep 2, 2024
5 of 8 checks passed
fhahn pushed a commit to fhahn/llvm-project that referenced this pull request Jan 21, 2025
This patch introduces lowering of the partial add reduction intrinsic to
a udot or svdot for AArch64. This also involves adding a
`shouldExpandPartialReductionIntrinsic` target hook, which AArch64 will
return false from in the cases that it can be lowered.
Labels: backend:AArch64, llvm:SelectionDAG

5 participants