
[AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot #107566


Merged: 12 commits, Sep 19, 2024
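
In brief: a partial reduction of a multiply whose operands are extended with different signedness (one zero-extended, one sign-extended) can now be lowered to the usdot instruction when FEAT_I8MM (+i8mm) is available, rather than falling back to the mul/extend expansion. A minimal NEON example, taken from the tests added below, with the expected +i8mm lowering sketched as a comment:

define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
  %u.wide = zext <16 x i8> %u to <16 x i32>
  %s.wide = sext <16 x i8> %s to <16 x i32>
  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
  ret <4 x i32> %partial.reduce
}
; With -mattr=+neon,+dotprod,+i8mm this lowers to: usdot v0.4s, v1.16b, v2.16b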
55 changes: 37 additions & 18 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2701,6 +2701,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::SADDLV)
MAKE_CASE(AArch64ISD::SDOT)
MAKE_CASE(AArch64ISD::UDOT)
MAKE_CASE(AArch64ISD::USDOT)
MAKE_CASE(AArch64ISD::SMINV)
MAKE_CASE(AArch64ISD::UMINV)
MAKE_CASE(AArch64ISD::SMAXV)
@@ -6114,6 +6115,11 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2), Op.getOperand(3));
}
case Intrinsic::aarch64_neon_usdot:
case Intrinsic::aarch64_sve_usdot: {
return DAG.getNode(AArch64ISD::USDOT, dl, Op.getValueType(),
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
}
case Intrinsic::get_active_lane_mask: {
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
@@ -21824,37 +21830,50 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,

auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))

if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
!ISD::isExtOpcode(ExtB->getOpcode()))
return SDValue();
bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;

auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
if (A.getValueType() != B.getValueType())
return SDValue();

unsigned Opcode = 0;

if (IsSExt)
Opcode = AArch64ISD::SDOT;
else if (IsZExt)
Opcode = AArch64ISD::UDOT;

assert(Opcode != 0 && "Unexpected dot product case encountered.");

EVT ReducedType = N->getValueType(0);
EVT MulSrcType = A.getValueType();

Contributor:
Not sure if this check is necessary (though it wasn't needed before this patch either). The code that emits this intrinsic checks this condition (see https://github.com/llvm/llvm-project/pull/92418/files#diff-da321d454a7246f8ae276bf1db2782bf26b5210b8133cb59e4d7fd45d0905decR2156-R2158), so outside of hand-written IR the no-extends case is never taken.
That, and there don't seem to be any tests that check this condition either.

That said, I'm fine with this staying as a bit of defensive coding, as I can't say for sure whether all future partial reduction cases will match on extends, but I figured it was worth bringing up.

Collaborator:
This function must stand on its own and thus cannot assume anything about its input that is not part of the intrinsic's definition.
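
To illustrate the point being discussed (this is not part of the patch): a hypothetical hand-written IR input in which the multiply operands are not extends at all. The isExtOpcode check above is what makes tryLowerPartialReductionToDot return SDValue() for such input, leaving the generic expansion to handle it.

; Hypothetical hand-written IR, for illustration only.
define <4 x i32> @no_extends(<4 x i32> %acc, <16 x i32> %a, <16 x i32> %b) {
  %mult = mul <16 x i32> %a, %b
  %partial.reduce = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
  ret <4 x i32> %partial.reduce
}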

// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
return SDValue();

return SDValue();
// If the extensions are mixed, we should lower it to a usdot instead
unsigned Opcode = 0;
if (AIsSigned != BIsSigned) {
if (!Subtarget->hasMatMulInt8())
return SDValue();

bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
if (Scalable && ReducedType != MVT::nxv4i32)
return SDValue();

Opcode = AArch64ISD::USDOT;
// USDOT expects the signed operand to be last
if (!BIsSigned)
std::swap(A, B);
} else if (AIsSigned)
Opcode = AArch64ISD::SDOT;
else
Collaborator:
This should be just else, given that by this point we know AIsSigned has to be false.

Collaborator (author) @SamTebbs33, Sep 19, 2024:
That's very true; it also lets us get rid of the assertion. Done.

Opcode = AArch64ISD::UDOT;

return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
}
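
As an aside (not part of the diff): a sketch of the scalable case the guards above reject. There is no nxv2i64 form of usdot, so a mixed-extend reduction into nxv2i64 keeps the generic lowering even when +i8mm is available. The function name and intrinsic suffix below are illustrative, following the convention of the existing tests:

; Hypothetical SVE input; with mixed extends this is not turned into a dot product.
define <vscale x 2 x i64> @usdot_i64(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %u, <vscale x 8 x i16> %s) {
  %u.wide = zext <vscale x 8 x i16> %u to <vscale x 8 x i64>
  %s.wide = sext <vscale x 8 x i16> %s to <vscale x 8 x i64>
  %mult = mul nuw nsw <vscale x 8 x i64> %s.wide, %u.wide
  %partial.reduce = call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
  ret <vscale x 2 x i64> %partial.reduce
}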

static SDValue performIntrinsicCombine(SDNode *N,
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -280,9 +280,10 @@ enum NodeType : unsigned {
SADDLP,
UADDLP,

// udot/sdot instructions
// udot/sdot/usdot instructions
UDOT,
SDOT,
USDOT,

// Vector across-lanes min/max
// Only the lower result lane is defined.
7 changes: 4 additions & 3 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -855,6 +855,7 @@ def AArch64frsqrts : SDNode<"AArch64ISD::FRSQRTS", SDTFPBinOp>;

def AArch64sdot : SDNode<"AArch64ISD::SDOT", SDT_AArch64Dot>;
def AArch64udot : SDNode<"AArch64ISD::UDOT", SDT_AArch64Dot>;
def AArch64usdot : SDNode<"AArch64ISD::USDOT", SDT_AArch64Dot>;

def AArch64saddv : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>;
def AArch64uaddv : SDNode<"AArch64ISD::UADDV", SDT_AArch64UnaryVec>;
@@ -1417,8 +1418,8 @@ let Predicates = [HasMatMulInt8] in {
def SMMLA : SIMDThreeSameVectorMatMul<0, 0, "smmla", int_aarch64_neon_smmla>;
def UMMLA : SIMDThreeSameVectorMatMul<0, 1, "ummla", int_aarch64_neon_ummla>;
def USMMLA : SIMDThreeSameVectorMatMul<1, 0, "usmmla", int_aarch64_neon_usmmla>;
defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", int_aarch64_neon_usdot>;
defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", int_aarch64_neon_usdot>;
defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", AArch64usdot>;
defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", AArch64usdot>;

// sudot lane has a pattern where usdot is expected (there is no sudot).
// The second operand is used in the dup operation to repeat the indexed
@@ -1430,7 +1431,7 @@ class BaseSIMDSUDOTIndex<bit Q, string dst_kind, string lhs_kind,
lhs_kind, rhs_kind, RegType, AccumType,
InputType, null_frag> {
let Pattern = [(set (AccumType RegType:$dst),
(AccumType (int_aarch64_neon_usdot (AccumType RegType:$Rd),
(AccumType (AArch64usdot (AccumType RegType:$Rd),
(InputType (bitconvert (AccumType
(AArch64duplane32 (v4i32 V128:$Rm),
VectorIndexS:$idx)))),
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3405,7 +3405,7 @@ let Predicates = [HasSVE, HasMatMulInt8] in {
} // End HasSVE, HasMatMulInt8

let Predicates = [HasSVEorSME, HasMatMulInt8] in {
defm USDOT_ZZZ : sve_int_dot_mixed<"usdot", int_aarch64_sve_usdot>;
defm USDOT_ZZZ : sve_int_dot_mixed<"usdot", AArch64usdot>;
defm USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot", int_aarch64_sve_usdot_lane>;
defm SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot", int_aarch64_sve_sudot_lane>;
} // End HasSVEorSME, HasMatMulInt8
115 changes: 112 additions & 3 deletions llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1,6 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM

define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-DOT-LABEL: udot:
@@ -102,7 +103,115 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
ret <2 x i32> %partial.reduce
}

define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NOI8MM-LABEL: usdot:
; CHECK-NOI8MM: // %bb.0:
; CHECK-NOI8MM-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-NOI8MM-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-NOI8MM-NEXT: sshll v4.8h, v2.8b, #0
; CHECK-NOI8MM-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NOI8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
; CHECK-NOI8MM-NEXT: smull v5.4s, v2.4h, v1.4h
; CHECK-NOI8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
; CHECK-NOI8MM-NEXT: smlal2 v5.4s, v4.8h, v3.8h
; CHECK-NOI8MM-NEXT: add v0.4s, v5.4s, v0.4s
; CHECK-NOI8MM-NEXT: ret
;
; CHECK-I8MM-LABEL: usdot:
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b
; CHECK-I8MM-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
ret <4 x i32> %partial.reduce
}

define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-NOI8MM-LABEL: usdot_narrow:
; CHECK-NOI8MM: // %bb.0:
; CHECK-NOI8MM-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NOI8MM-NEXT: sshll v2.8h, v2.8b, #0
; CHECK-NOI8MM-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NOI8MM-NEXT: smull v3.4s, v2.4h, v1.4h
; CHECK-NOI8MM-NEXT: smull2 v4.4s, v2.8h, v1.8h
; CHECK-NOI8MM-NEXT: ext v5.16b, v1.16b, v1.16b, #8
; CHECK-NOI8MM-NEXT: ext v6.16b, v2.16b, v2.16b, #8
; CHECK-NOI8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
; CHECK-NOI8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NOI8MM-NEXT: ext v1.16b, v4.16b, v4.16b, #8
; CHECK-NOI8MM-NEXT: smlal v3.4s, v6.4h, v5.4h
; CHECK-NOI8MM-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NOI8MM-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NOI8MM-NEXT: ret
;
; CHECK-I8MM-LABEL: usdot_narrow:
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.2s, v1.8b, v2.8b
; CHECK-I8MM-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
ret <2 x i32> %partial.reduce
}

define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
; CHECK-NOI8MM-LABEL: sudot:
; CHECK-NOI8MM: // %bb.0:
; CHECK-NOI8MM-NEXT: sshll v3.8h, v1.8b, #0
; CHECK-NOI8MM-NEXT: sshll2 v1.8h, v1.16b, #0
; CHECK-NOI8MM-NEXT: ushll v4.8h, v2.8b, #0
; CHECK-NOI8MM-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NOI8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
; CHECK-NOI8MM-NEXT: smull v5.4s, v2.4h, v1.4h
; CHECK-NOI8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
; CHECK-NOI8MM-NEXT: smlal2 v5.4s, v4.8h, v3.8h
; CHECK-NOI8MM-NEXT: add v0.4s, v5.4s, v0.4s
; CHECK-NOI8MM-NEXT: ret
;
; CHECK-I8MM-LABEL: sudot:
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
; CHECK-I8MM-NEXT: ret
%u.wide = sext <16 x i8> %u to <16 x i32>
%s.wide = zext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
ret <4 x i32> %partial.reduce
}

define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-NOI8MM-LABEL: sudot_narrow:
; CHECK-NOI8MM: // %bb.0:
; CHECK-NOI8MM-NEXT: sshll v1.8h, v1.8b, #0
; CHECK-NOI8MM-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NOI8MM-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NOI8MM-NEXT: smull v3.4s, v2.4h, v1.4h
; CHECK-NOI8MM-NEXT: smull2 v4.4s, v2.8h, v1.8h
; CHECK-NOI8MM-NEXT: ext v5.16b, v1.16b, v1.16b, #8
; CHECK-NOI8MM-NEXT: ext v6.16b, v2.16b, v2.16b, #8
; CHECK-NOI8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
; CHECK-NOI8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NOI8MM-NEXT: ext v1.16b, v4.16b, v4.16b, #8
; CHECK-NOI8MM-NEXT: smlal v3.4s, v6.4h, v5.4h
; CHECK-NOI8MM-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NOI8MM-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NOI8MM-NEXT: ret
;
; CHECK-I8MM-LABEL: sudot_narrow:
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.2s, v2.8b, v1.8b
; CHECK-I8MM-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
ret <2 x i32> %partial.reduce
}

define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b