[AArch64] Lower partial add reduction to udot or svdot #101010
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64
Author: Sam Tebbs (SamTebbs33)
Changes: This patch introduces lowering of the partial add reduction intrinsic to a udot or svdot for AArch64. This also involves adding a shouldExpandPartialReductionIntrinsic target hook, which AArch64 will return false from in the cases that it can be lowered.
Full diff: https://github.com/llvm/llvm-project/pull/101010.diff
7 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9d9886f4920a29..07d99aec47122a 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,12 @@ class TargetLoweringBase {
return true;
}
+ /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+ /// should be expanded using generic code in SelectionDAGBuilder.
+ virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+ return true;
+ }
+
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
/// using generic code in SelectionDAGBuilder.
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1791f1b503379e..c70ab253c1aabc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7985,6 +7985,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
+
+ if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
+ visitTargetIntrinsic(I, Intrinsic);
+ return;
+ }
+
SDValue OpNode = getValue(I.getOperand(1));
EVT ReducedTy = EVT::getEVT(I.getType());
EVT FullTy = OpNode.getValueType();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d86e52d49000ae..d1ee58668ecbd7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1971,6 +1971,57 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
return false;
}
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+ const CallInst *CI) const {
+ const bool TargetLowers = false;
+ const bool GenericLowers = true;
+
+ auto *I = dyn_cast<IntrinsicInst>(CI);
+ if (!I)
+ return GenericLowers;
+
+ ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
+
+ if (!RetTy)
+ return GenericLowers;
+
+ ScalableVectorType *InputTy = nullptr;
+
+ auto RetScalarTy = RetTy->getScalarType();
+ if (RetScalarTy->isIntegerTy(64)) {
+ InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+ } else if (RetScalarTy->isIntegerTy(32)) {
+ InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+ }
+
+ if (!InputTy)
+ return GenericLowers;
+
+ Value *InputA;
+ Value *InputB;
+
+ auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+ m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+ m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
+
+ if (!match(I, Pattern))
+ return GenericLowers;
+
+ auto Mul = cast<Instruction>(I->getOperand(1));
+
+ auto getOpcodeOfOperand = [&](unsigned Idx) {
+ return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
+ };
+
+ if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
+ return GenericLowers;
+
+ if (InputA->getType() != InputTy || InputB->getType() != InputTy)
+ return GenericLowers;
+
+ return TargetLowers;
+}
+
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
if (!Subtarget->isSVEorStreamingSVEAvailable())
return true;
@@ -21237,6 +21288,32 @@ static SDValue performIntrinsicCombine(SDNode *N,
switch (IID) {
default:
break;
+ case Intrinsic::experimental_vector_partial_reduce_add: {
+ SDLoc DL(N);
+
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+
+ unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+ if (ExtA->getOpcode() == ISD::SIGN_EXTEND)
+ DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+ else if (ExtA->getOpcode() == ISD::ZERO_EXTEND)
+ DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+ assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+ "Unexpected dot product case encountered.");
+
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+
+ auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+ {IntrinsicId, NarrowOp, A, B});
+ }
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 81e15185f985d5..fc79d9766719bc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -991,6 +991,8 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
+ bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override;
+
bool shouldExpandCttzElements(EVT VT) const override;
/// If a change in streaming mode is required on entry to/return from a
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 45148449dfb821..792bd546019192 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3533,6 +3533,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
return Cost;
}
+bool AArch64TTIImpl::isPartialReductionSupported(
+ const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+ bool IsInputASignExtended, bool IsInputBSignExtended,
+ const Instruction *BinOp) const {
+ if (ReductionInstr->getOpcode() != Instruction::Add)
+ return false;
+
+ // Check that both extends are of the same type
+ if (IsInputASignExtended != IsInputBSignExtended)
+ return false;
+
+ if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
+ return false;
+
+ // Dot product only supports a scale factor of 4
+ if (ScaleFactor != 4)
+ return false;
+
+ Type *ReductionType = ReductionInstr->getType();
+ if (ReductionType->isIntegerTy(32)) {
+ if (!InputType->isIntegerTy(8))
+ return false;
+ } else if (ReductionType->isIntegerTy(64)) {
+ if (!InputType->isIntegerTy(16))
+ return false;
+ }
+
+ return true;
+}
+
unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
return ST->getMaxInterleaveFactor();
}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a9189fd53f40bb..592b452134e778 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -155,6 +155,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
return VF.getKnownMinValue() * ST->getVScaleForTuning();
}
+ bool isPartialReductionSupported(const Instruction *ReductionInstr,
+ Type *InputType, unsigned ScaleFactor,
+ bool IsInputASignExtended,
+ bool IsInputBSignExtended,
+ const Instruction *BinOp = nullptr) const;
+
unsigned getMaxInterleaveFactor(ElementCount VF);
bool prefersVectorizedAddressing() const;
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..23b39387fb7a0c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,109 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-unknwon-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: udot z2.s, z0.b, z1.b
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.d, #0 // =0x0
+; CHECK-NEXT: udot z2.d, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp_sext:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: sdot z2.s, z0.b, z1.b
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.d, #0 // =0x0
+; CHECK-NEXT: sdot z2.d, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) #0 {
+; CHECK-LABEL: not_dotp:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z0.h, z0.h, #0xff
+; CHECK-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: uunpkhi z2.s, z0.h
+; CHECK-NEXT: uunpkhi z3.s, z1.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: uunpklo z1.s, z1.h
+; CHECK-NEXT: mul z2.s, z2.s, z3.s
+; CHECK-NEXT: mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+ %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+ %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> zeroinitializer, <vscale x 8 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) #0 {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z0.s, z0.s, #0xffff
+; CHECK-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpkhi z2.d, z0.s
+; CHECK-NEXT: uunpkhi z3.d, z1.s
+; CHECK-NEXT: uunpklo z0.d, z0.s
+; CHECK-NEXT: uunpklo z1.d, z1.s
+; CHECK-NEXT: mul z2.d, z2.d, z3.d
+; CHECK-NEXT: mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+ %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+ %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+attributes #0 = { "target-features"="+sve2" }
Hi @SamTebbs33 I hadn't really followed the other patches before looking at this one, so perhaps some of my questions have already been answered in the other PRs.
const bool TargetLowers = false; | ||
const bool GenericLowers = true; |
I personally find this more confusing than helpful, maybe just return true or false directly instead.
Done.
auto *I = dyn_cast<IntrinsicInst>(CI); | ||
if (!I) | ||
return GenericLowers; |
Should this be an assert? I would expect this only to be called on intrinsic calls to experimental_vector_partial_reduce_add (or similar).
I think it should be. It's only called from somewhere where it is an intrinsic so I'll change this to an assertion.
} | ||
|
||
if (!InputTy) | ||
return GenericLowers; |
} | |
if (!InputTy) | |
return GenericLowers; | |
else | |
return GenericLowers; |
Done.
ScalableVectorType *InputTy = nullptr; | ||
|
||
auto RetScalarTy = RetTy->getScalarType(); | ||
if (RetScalarTy->isIntegerTy(64)) { |
nit: curly braces are unnecessary here.
Done.
ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType()); | ||
|
||
if (!RetTy) | ||
return GenericLowers; |
ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType()); | |
if (!RetTy) | |
return GenericLowers; | |
if (!isa<ScalableVectorType>(I->getType())) | |
return true; |
I think this is superseded by what I've done to address your other comment about the scalable vector type.
auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>( | ||
m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))), | ||
m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))); |
Is it worth making this check earlier (before checking the types) and then do:
if (match(I, m_Intrinsic(...))) {
  if ((I->getType()->isIntegerType(64) && InputA->getType()->isIntegerType(16)) ||
      (I->getType()->isIntegerType(32) && InputA->getType()->isIntegerType(8))) {
    auto *Mul = cast<Instruction>(I->getOperand(1));
    if (Mul->getOperand(0)->getOpcode() == Mul->getOperand(1)->getOpcode())
      return false;
  }
}
return true;
That way you don't need to construct any explicit Types, to then match later.
That's much cleaner, thanks. Some type construction was required but that's just because of how the ElementCount class works. Done.
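For context, a rough sketch of the kind of ElementCount-based comparison being referred to; ExpectedCount8 and ExpectedCount16 mirror names quoted later in this review, while the surrounding variables are assumed to be in scope rather than taken verbatim from the patch:

// Sketch only: build the expected scalable element counts once and compare
// them against the pre-extend operand type, instead of constructing whole
// vector types up front.
ElementCount ExpectedCount8 = ElementCount::getScalable(8);
ElementCount ExpectedCount16 = ElementCount::getScalable(16);
bool IsWideDot = RetTy->getScalarType()->isIntegerTy(64) &&
                 InputAType->getElementType()->isIntegerTy(16) &&
                 InputAType->getElementCount() == ExpectedCount8;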
if (!I) | ||
return GenericLowers; | ||
|
||
ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType()); |
Am I right that this is just an artificial limitation at the moment? (i.e. we can make this work for fixed-length vectors too)
Yeah you're right. I've removed that limitation now.
@@ -3533,6 +3533,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) { | |||
return Cost; | |||
} | |||
|
|||
bool AArch64TTIImpl::isPartialReductionSupported( | |||
const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor, |
What is ScaleFactor?
It's the scaling difference between the element size of the inputs and the size of the resulting scalar from the entire reduction. In the tests you'll see that the input types are extended and the resulting scalar is 4 times the size, which mirrors how the dot product instructions work.
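As an illustration only, a query for the i8-to-i32 case would pass a scale factor of 32 / 8 = 4; AddInstr, MulInstr, Ctx and TTI below are placeholder names for values at a hypothetical call site, not names from the patch:

// Hypothetical call site: i8 inputs multiplied and accumulated into an i32
// reduction, so the scale between input and result element size is 4.
bool Supported = TTI.isPartialReductionSupported(
    /*ReductionInstr=*/AddInstr, /*InputType=*/Type::getInt8Ty(Ctx),
    /*ScaleFactor=*/4, /*IsInputASignExtended=*/false,
    /*IsInputBSignExtended=*/false, /*BinOp=*/MulInstr);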
Not sure if I'm missing something but isPartialReductionSupported isn't used and thus doesn't belong in this PR?
That's correct, thanks for spotting that. It's not used in this PR but is used by the loop vectorizer part of this work, so I'll move it to that PR.
if (ReductionType->isIntegerTy(32)) { | ||
if (!InputType->isIntegerTy(8)) | ||
return false; | ||
} else if (ReductionType->isIntegerTy(64)) { | ||
if (!InputType->isIntegerTy(16)) | ||
return false; | ||
} | ||
|
||
return true; |
if (ReductionType->isIntegerTy(32)) { | |
if (!InputType->isIntegerTy(8)) | |
return false; | |
} else if (ReductionType->isIntegerTy(64)) { | |
if (!InputType->isIntegerTy(16)) | |
return false; | |
} | |
return true; | |
return (ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) ||
       (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16));
Done.
if (ScaleFactor != 4) | ||
return false; | ||
|
||
Type *ReductionType = ReductionInstr->getType(); |
This function is not testing that the type must be scalable. Is that on purpose?
Originally we were limiting the creation of the intrinsic to when scalable types were used so a check here wasn't necessary. But I will add a limitation to scalable types here as an initial version of this work and we can get it to work for NEON later on.
Done.
✅ With the latest revision this PR passed the C/C++ code formatter. |
Could this remove the shouldExpandPartialReductionIntrinsic and always emit a node, that gets expanded as needed like the rest of DAG?
I request that we defer creating a dedicated ISD node for now, with the precedent already set (see get_active_lane_mask and cttz_elts). In this instance I don't think we'll be waiting long for a dedicated ISD node but my concern is that such a node is unlikely to want to follow the exact semantics of the intrinsic and so I'd rather take the time to see how the functionality plays out before committing ourselves rather than hastily implementing all the functionality that typically comes from adding a new ISD node.
I worried I was asking for quite a lot of work when I added the comment, but I don't think we can just rely on the IR matching the DAG like this. It will need to at least check that the DAG nodes match as expected, and probably add a fallback in case. It would be good if we handled larger other vector sizes too, but that might have to wait until we do add nodes to handle legalization properly.
case Intrinsic::experimental_vector_partial_reduce_add: { | ||
SDLoc DL(N); | ||
|
||
auto NarrowOp = N->getOperand(1); |
This can't just assume that the node is still how we expect (although that will likely almost always be the case). Can you add checks that the nodes are Mul/zext/sext and add some fallback in case they are not?
Thanks Dave, I've now added a fallback that creates a chain of ADDs if the nodes aren't as expected.
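For reference, a rough sketch of what such a fallback can look like, not the exact code from the patch; Acc, Wide, DAG and DL are assumed to already be in scope:

// Fallback sketch: split the wide multiply result into subvectors of the
// accumulator type and add them to the accumulator one by one.
EVT ReducedVT = Acc.getValueType();
unsigned NumSubvectors = Wide.getValueType().getVectorMinNumElements() /
                         ReducedVT.getVectorMinNumElements();
SDValue Result = Acc;
for (unsigned I = 0; I < NumSubvectors; ++I) {
  SDValue Subvec = DAG.getNode(
      ISD::EXTRACT_SUBVECTOR, DL, ReducedVT, Wide,
      DAG.getVectorIdxConstant(I * ReducedVT.getVectorMinNumElements(), DL));
  Result = DAG.getNode(ISD::ADD, DL, ReducedVT, Result, Subvec);
}
return Result;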
Just realised I got a bit of the logic with IsValidDotProduct wrong. I'll fix that up in the next commit.
auto *I = dyn_cast<IntrinsicInst>(CI); | ||
assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc"); |
This could just be a cast without the assert, but it might be better to pass an IntrinsicInst directly.
Yeah that sounds good to me. Done.
InputAType->getElementType()->isIntegerTy(16) && | ||
InputAType->getElementCount() == ExpectedCount8 && | ||
InputAType == InputBType) || | ||
|
Remove extra line (I would have expected the formatter to remove it).
Done.
(RetTy->getScalarType()->isIntegerTy(32) && | ||
InputAType->getElementType()->isIntegerTy(8) && | ||
InputAType->getElementCount() == ExpectedCount16 && | ||
InputAType == InputBType)) { |
An i8->i64 reduction could use an i8->i32 udot + zext.
Great idea, added that now. The accumulator operand is of the wider type so I made it insert a zeroinitializer and add the accumulator after extending.
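Roughly, the decomposition being described; this is a sketch only, assuming the dot node takes the accumulator first and that the wider accumulator Acc has type nxv4i64, neither of which is confirmed by the diff above:

// Dot into a zeroed i32 accumulator, widen the result, then add the
// original wide accumulator.
SDValue Zeros = DAG.getConstant(0, DL, MVT::nxv4i32);
SDValue Dot = DAG.getNode(AArch64ISD::UDOT, DL, MVT::nxv4i32, Zeros, A, B);
SDValue WideDot = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::nxv4i64, Dot);
SDValue Res = DAG.getNode(ISD::ADD, DL, MVT::nxv4i64, WideDot, Acc);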
auto ExtB = MulOp->getOperand(1); | ||
bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND; | ||
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND; | ||
if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt)) |
Check the types too?
Can the types be altered between shouldExpandPartialReductionIntrinsic being called and lowering happening? If not then I'd expect them to be fine here even if the nodes themselves have changed.
@@ -0,0 +1,109 @@ | |||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 | |||
; RUN: llc -mtriple=aarch64-unknwon-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s |
We can usually remove -O3
unknwon->unknown, or I often just use -mtriple=aarch64 nowadays.
Ah thanks!
ret <vscale x 2 x i64> %partial.reduce | ||
} | ||
|
||
attributes #0 = { "target-features"="+sve2" } |
This can be removed as it is in the run line.
Done.
target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" | ||
target triple = "aarch64-none-unknown-elf" |
This doesn't match the run line, and datalayout is often unnecessary in llc tests.
m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))), | ||
m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) { |
It might be better to produce a udot even if there are multiple uses of the inner zext/sext. In the long run we will likely need them to be generated from known-bits too.
Done.
if (Type1ElementSize < Type0ElementSize) | ||
Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec); | ||
else if (Type1ElementSize > Type0ElementSize) | ||
Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec); |
I don't think this should be extending/truncating the type. From what I understand they should be the same size.
Yeah you're right, removed now.
@@ -3533,6 +3533,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) { | |||
return Cost; | |||
} | |||
|
|||
bool AArch64TTIImpl::isPartialReductionSupported( | |||
const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor, |
Not sure if I'm missing something but isPartialReductionSupported isn't used and thus doesn't belong in this PR?
m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>( | ||
m_Value(), m_OneUse(m_Mul(m_ZExtOrSExt(m_Value(InputA)), | ||
m_ZExtOrSExt(m_Value(InputB))))))) { | ||
VectorType *InputAType = dyn_cast<VectorType>(InputA->getType()); | ||
VectorType *InputBType = dyn_cast<VectorType>(InputB->getType()); | ||
if (!InputAType || !InputBType) | ||
return true; |
I don't think this function needs to be this complex. Is it not possible to base the decision purely on the result type (e.g. legal scalable vectors).
I ask because you're putting significant effort into matching DOT instructions, which is not that unreasonable given the PR's title but the true intent of this patch is to enable better code generation for the partial reduction intrinsic, of which DOT is just one possible destination.
For me the complexity can stay in the target specific DAG combine, which will evolve over time, with the only cost being perhaps a duplication of the common lowering code. This duplication is easily solved by moving it into a dedicated SelectionDAG function that can be called by both the builder and the target specific DAG combine.
Thanks Paul, I like the approach of sharing the expansion code in SelectionDAG. Let me know what you think of the new implementation.
@@ -1590,6 +1590,11 @@ class SelectionDAG { | |||
/// the target's desired shift amount type. | |||
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op); | |||
|
|||
/// Expand a partial reduction intrinsic call. | |||
/// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type. | |||
SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, |
The name lacks sufficient context (i.e. the "add" part is very important) but at the same time this nicely abstracts the intrinsic side of things so perhaps "getPartialReduceAdd()" is more representative? On similar lines what do you think to mirroring the operand order of similar getNode() calls? (i.e. DL, VT, operands....)
getPartialReduceAdd() definitely fits the naming scheme of the rest of the functions in this file. Sounds good to me.
VectorType *RetTy = dyn_cast<VectorType>(I->getType()); | ||
if (!RetTy || !RetTy->isScalableTy()) |
The langref says both operands must be vectors so this can be just cast<VectorType> with the !RetTy dropped. That said, I think you'd be better off calling EVT::getEVT() because then the following code can just be: return VT == MVT::nxv4i32 || VT == MVT::nxv2i64 ....
That's a lot cleaner, thanks. Done.
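A minimal sketch of that shape, assuming nxv4i32 and nxv2i64 are the only result types handled at this point:

// Expand generically (return true) unless the result type is one the
// target lowers directly.
EVT VT = EVT::getEVT(I->getType());
return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;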
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( | ||
const IntrinsicInst *I) const { |
If the intent is for this function to support other partial reductions then you should be rejecting intrinsics whose ID is not Intrinsic::experimental_vector_partial_reduce_add.
Done.
SDValue tryLowerPartialReductionToDot(SDNode *N, | ||
const AArch64Subtarget *Subtarget, | ||
SelectionDAG &DAG) { | ||
|
Please add an assert to verify N is what you expect it to be. In general you shouldn't assume knowledge from shouldExpandPartialReductionIntrinsic and should reverify that the input types are supported by the following code.
Done.
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( | ||
<vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult) |
I've not looked how prevalent this is but the intrinsic name is not consistent with the operand types.
Ah that would be a mistake I made when copying the 8 to 64 tests. I've removed them as part of the commit that removes the nxv4i64 handling.
; CHECK-NEXT: udot z2.s, z0.b, z1.b | ||
; CHECK-NEXT: mov z0.d, z2.d | ||
; CHECK-NEXT: ret | ||
entry: | ||
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32> | ||
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32> | ||
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide | ||
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult) | ||
ret <vscale x 4 x i32> %partial.reduce |
Most of the tests pass in zero as the first operand, but the DAG combine doesn't have any special handling for zero and so seems irrelevant to the PR. If this is true then all such tests should be removed with that operand just being passed into the test function.
Good idea, done.
if (IsSExt) | ||
DotIntrinsicId = Intrinsic::aarch64_sve_sdot; | ||
else if (IsZExt) | ||
DotIntrinsicId = Intrinsic::aarch64_sve_udot; | ||
|
Rather than emit an intrinsic, which is cumbersome to create, can you use AArch64ISD::SDOT/UDOT directly?
Done.
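A minimal sketch of emitting the node directly; the operand order (accumulator first, then the two narrow inputs) is an assumption here, not something confirmed by the diff:

// Pick the signed or unsigned dot node and emit it directly instead of
// wrapping the SVE intrinsic in an INTRINSIC_WO_CHAIN node.
unsigned Opcode = IsZExt ? AArch64ISD::UDOT : AArch64ISD::SDOT;
return DAG.getNode(Opcode, DL, NarrowOp.getValueType(), NarrowOp, A, B);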
auto Extended = DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, | ||
DL, NarrowOp.getValueType(), {DotProduct}); |
You shouldn't need the {} wrapping the operands because there's enough getNode functions for the number of operands you have here. The same likely goes for the other places you call getNode.
Done.
Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32, | ||
DAG.getConstant(0, DL, MVT::i32)); |
You should be able to use DAG.getConstant(0, DL, MVT::nxv4i32); here.
Nice, I didn't know that. It's not needed now that the nxv4i64 case has been removed.
@@ -21229,6 +21249,101 @@ static SDValue tryCombineWhileLo(SDNode *N, | |||
return SDValue(N, 0); | |||
} | |||
|
|||
SDValue tryLowerPartialReductionToDot(SDNode *N, |
I think you're trying to do too much in this one PR, which is making it hard to see if the complexity is required. Please can we start by just handling the cases that have a direct mapping to DOT instructions.
Specifically, please remove the nxv4i64 result type handling because this seems to be the primary source of complexity. With that gone I think you'll be able to talk purely about EVTs and do less element count based maths. For example, I think you'll be able to implement the function more akin to:
validate input and get pre-extend operands
if (result_type == nxv4i32 && input_type == nxv16i8)
return getNode(...);
if (result_type == nxv2i64 && input_type == nxv8i16)
return getNode(...);
return SDValue();
This will give us a foundation for future PRs to build on (i.e. to support non-legal types, 2-way dot products, usdot and figure out what ISD node semantics we need to best handle the intrinsic).
Done.
This patch introduces lowering of the partial add reduction intrinsic to a udot or svdot for AArch64.
Mostly superficial comments with the removal of WideType (assuming you agree that it is redundant) being the only one I really care about. Otherwise the patch looks good to me.
if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 && | ||
ExtendedType == MVT::nxv16i8) |
Is WideType required here? Given by definition it must have the same element count as ExtendedType, it should be enough to say if (ReducedType == MVT::nxv4i32 && ExtendedType == MVT::nxv16i8).
If I remove this check, what guarantees that the mul has the proper type?
DAG combines can and should assume the DAG is well formed.
In this instance you can be certain the element type of the mul will match the element type of the partial.add's result type (i.e. ReducedType) and you can be certain the mul will have the same number of elements as its operands, which by extension means the same number of elements as its pre-extended operands (i.e. ExtendedType).
That makes sense, thank you. Done.
if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 && | ||
ExtendedType == MVT::nxv8i16) |
As above, WideType might be redundant.
Done.
This patch introduces lowering of the partial add reduction intrinsic to a udot or svdot for AArch64. This also involves adding a `shouldExpandPartialReductionIntrinsic` target hook, which AArch64 will return false from in the cases that it can be lowered.
This patch introduces lowering of the partial add reduction intrinsic to a udot or svdot for AArch64. This also involves adding a shouldExpandPartialReductionIntrinsic target hook, which AArch64 will return false for so that it can be lowered rather than be expanded.