[AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot #107566
Conversation
This PR adds lowering for partial reductions of a mix of sign/zero extended inputs to the usdot intrinsic.
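For context, a mixed sign/zero extended partial reduction typically originates from a loop like the following C++ sketch (illustrative only, not taken from the patch): one input is unsigned bytes, the other signed bytes, and both are widened before the multiply-accumulate.

```cpp
#include <cstdint>

// One operand zero-extends (unsigned bytes) and the other sign-extends
// (signed bytes) before the multiply; the vectorizer expresses the
// accumulation as llvm.experimental.vector.partial.reduce.add over the
// widened product, which is the pattern this combine matches.
int32_t mixed_dot(const uint8_t *U, const int8_t *S, int N) {
  int32_t Sum = 0;
  for (int I = 0; I < N; ++I)
    Sum += int32_t(U[I]) * int32_t(S[I]);
  return Sum;
}
```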
@llvm/pr-subscribers-backend-aarch64

Author: Sam Tebbs (SamTebbs33)

Changes: This PR adds lowering for partial reductions of a mix of sign/zero extended inputs to the usdot intrinsic.

Patch is 22.93 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/107566.diff

3 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d80087336d230..a3b372e677f98a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21824,37 +21824,59 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
- bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
- if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
- return SDValue();
-
auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
if (A.getValueType() != B.getValueType())
return SDValue();
- unsigned Opcode = 0;
-
- if (IsSExt)
- Opcode = AArch64ISD::SDOT;
- else if (IsZExt)
- Opcode = AArch64ISD::UDOT;
-
- assert(Opcode != 0 && "Unexpected dot product case encountered.");
-
EVT ReducedType = N->getValueType(0);
EVT MulSrcType = A.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
- (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
- (ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
- (ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
- return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+ if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+ !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+ !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
+ !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+ return SDValue();
- return SDValue();
+ bool AIsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+ bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
+ if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
+ return SDValue();
+
+ // If the extensions are mixed, we should lower it to a usdot instead
+ if (AIsZExt != BIsZExt) {
+ if (!Subtarget->hasMatMulInt8())
+ return SDValue();
+ bool Scalable = N->getValueType(0).isScalableVT();
+
+ // There's no nxv2i64 version of usdot
+ if (Scalable && ReducedType != MVT::nxv4i32)
+ return SDValue();
+
+ unsigned IntrinsicID =
+ Scalable ? Intrinsic::aarch64_sve_usdot : Intrinsic::aarch64_neon_usdot;
+ // USDOT expects the first operand to be unsigned, so swap the operands if
+ // the first is signed and the second is unsigned
+ if (AIsSExt && BIsZExt)
+ std::swap(A, B);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ReducedType,
+ DAG.getConstant(IntrinsicID, DL, MVT::i64), NarrowOp, A,
+ B);
+ }
+
+ unsigned Opcode = 0;
+ if (AIsSExt)
+ Opcode = AArch64ISD::SDOT;
+ else if (AIsZExt)
+ Opcode = AArch64ISD::UDOT;
+
+ assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+ return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
}
static SDValue performIntrinsicCombine(SDNode *N,
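As a reading aid for the lowering above, the scalar semantics of usdot can be sketched as follows (reference code assuming 128-bit vectors; not part of the patch). Each 32-bit accumulator lane adds the dot product of four unsigned bytes from the first source with four signed bytes from the second, which is why the combine swaps the operands when the signed input comes first.

```cpp
#include <cstdint>

// Reference semantics for a 128-bit usdot: four accumulator lanes, four
// byte products per lane; the first source is unsigned, the second signed.
void usdot_ref(int32_t Acc[4], const uint8_t A[16], const int8_t B[16]) {
  for (int Lane = 0; Lane < 4; ++Lane)
    for (int I = 0; I < 4; ++I)
      Acc[Lane] += int32_t(A[4 * Lane + I]) * int32_t(B[4 * Lane + I]);
}
```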
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 8035504d5558b1..7b6c01f4691175 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1,6 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOIMM8
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-DOT-LABEL: udot:
@@ -18,6 +19,11 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: udot:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: udot v0.4s, v2.16b, v1.16b
+; CHECK-NOIMM8-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = zext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -45,6 +51,11 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: udot_narrow:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: udot v0.2s, v2.8b, v1.8b
+; CHECK-NOIMM8-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -68,6 +79,11 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sdot:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NOIMM8-NEXT: ret
%u.wide = sext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -95,6 +111,11 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sdot_narrow:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: sdot v0.2s, v2.8b, v1.8b
+; CHECK-NOIMM8-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -102,7 +123,175 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
ret <2 x i32> %partial.reduce
}
-define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
+define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
+; CHECK-DOT-LABEL: usdot:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: usdot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: usdot:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: usdot:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NOIMM8-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NOIMM8-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOIMM8-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOIMM8-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NOIMM8-NEXT: ret
+ %u.wide = zext <16 x i8> %u to <16 x i32>
+ %s.wide = sext <16 x i8> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-DOT-LABEL: usdot_narrow:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: usdot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: usdot_narrow:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: sshll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: usdot_narrow:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT: sshll v2.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NOIMM8-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOIMM8-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOIMM8-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOIMM8-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOIMM8-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOIMM8-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NOIMM8-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NOIMM8-NEXT: ret
+ %u.wide = zext <8 x i8> %u to <8 x i32>
+ %s.wide = sext <8 x i8> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+ ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
+; CHECK-DOT-LABEL: sudot:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: usdot v0.4s, v2.16b, v1.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sudot:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sudot:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NOIMM8-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NOIMM8-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOIMM8-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOIMM8-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NOIMM8-NEXT: ret
+ %u.wide = sext <16 x i8> %u to <16 x i32>
+ %s.wide = zext <16 x i8> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-DOT-LABEL: sudot_narrow:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: usdot v0.2s, v2.8b, v1.8b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sudot_narrow:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sudot_narrow:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NOIMM8-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOIMM8-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOIMM8-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOIMM8-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOIMM8-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOIMM8-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NOIMM8-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NOIMM8-NEXT: ret
+ %u.wide = sext <8 x i8> %u to <8 x i32>
+ %s.wide = zext <8 x i8> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+ ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index b1354ab210f727..35d2b8ca30a041 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,8 +1,9 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOIMM8
-define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp:
+define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: udot:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
@@ -14,8 +15,8 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: dotp_wide:
+define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: udot_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
@@ -27,8 +28,8 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp_sext:
+define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: sdot:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
@@ -40,8 +41,8 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: dotp_wide_sext:
+define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: sdot_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
@@ -53,8 +54,80 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
-; CHECK-LABEL: not_dotp:
+define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-DOT-LABEL: usdot:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: usdot z0.s, z1.b, z2.b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: usdot:
+; CHECK-NOIMM8: // %bb.0: // %entry
+; CHECK-NOIMM8-NEXT: uunpklo z3.h, z1.b
+; CHECK-NOIMM8-NEXT: sunpklo z4.h, z2.b
+; CHECK-NOIMM8-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NOIMM8-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NOIMM8-NEXT: ptrue p0.s
+; CHECK-NOIMM8-NEXT: uunpklo z5.s, z3.h
+; CHECK-NOIMM8-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NOIMM8-NEXT: sunpklo z6.s, z4.h
+; CHECK-NOIMM8-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NOIMM8-NEXT: uunpklo z7.s, z1.h
+; CHECK-NOIMM8-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NOIMM8-NEXT: sunpklo z24.s, z2.h
+; CHECK-NOIMM8-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NOIMM8-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOIMM8-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NOIMM8-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOIMM8-NEXT: movprfx z1, z3
+; CHECK-NOIMM8-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOIMM8-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NOIMM8-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-DOT-LABEL: sudot:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: usdot z0.s, z2.b, z1.b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sudot:
+; CHECK-NOIMM8: // %bb.0: // %entry
+; CHECK-NOIMM8-NEXT: sunpklo z3.h, z1.b
+; CHECK-NOIMM8-NEXT: uunpklo z4.h, z2.b
+; CHECK-NOIMM8-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NOIMM8-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NOIMM8-NEXT: ptrue p0.s
+; CHECK-NOIMM8-NEXT: sunpklo z5.s, z3.h
+; CHECK-NOIMM8-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NOIMM8-NEXT: uunpklo z6.s, z4.h
+; CHECK-NOIMM8-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NOIMM8-NEXT: sunpklo z7.s, z1.h
+; CHECK-NOIMM8-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NOIMM8-NEXT: uunpklo z24.s, z2.h
+; CHECK-NOIMM8-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NOIMM8-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOIMM8-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NOIMM8-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOIMM8-NEXT: movprfx z1, z3
+; CHECK-NOIMM8-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOIMM8-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NOIMM8-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+; CHECK-LABEL: not_udot:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.h, z1.h, #0xff
; CHECK-NEXT: and z2.h, z2.h, #0xff
@@ -74,8 +147,8 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
-; CHECK-LABEL: not_dotp_wide:
+define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+; CHECK-LABEL: no...
[truncated]
Looks good to me
bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
Not sure if this check is necessary (though nor was it needed before this patch). The code that emits this intrinsic checks this condition (see https://github.com/llvm/llvm-project/pull/92418/files#diff-da321d454a7246f8ae276bf1db2782bf26b5210b8133cb59e4d7fd45d0905decR2156-R2158), so outside of hand-written IR the case of no extends is never taken.
That, and there don't seem to be any tests that check this condition either.
That said, I'm fine with this staying as a bit of defensive coding, since I can't say for sure whether all future partial-reduction cases will match on extends, but I figured it was worth bringing up.
This function must stand on its own and thus cannot assume anything about its input that is not part of the intrinsic's definition.
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
  return SDValue();

auto A = ExtA->getOperand(0);
You cannot just move these checks because you need to prove `ExtA` and `ExtB` have an operand before calling `getOperand()`.
Done.
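For reference, the ordering being asked for looks roughly like this (a sketch; the exact committed code may differ). `ISD::isExtOpcode` is the existing helper that accepts sign, zero, and any extends:

```cpp
// Sketch: prove both nodes are extends before calling getOperand(0), so the
// combine never reaches into arbitrary nodes.
if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
    !ISD::isExtOpcode(ExtB->getOpcode()))
  return SDValue();

SDValue A = ExtA->getOperand(0);
SDValue B = ExtB->getOperand(0);
if (A.getValueType() != B.getValueType())
  return SDValue();
```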
Scalable ? Intrinsic::aarch64_sve_usdot : Intrinsic::aarch64_neon_usdot;
// USDOT expects the first operand to be unsigned, so swap the operands if
Given there is sufficient reason for `AArch64ISD::SDOT` and `AArch64ISD::UDOT` to exist, I propose the same likely holds for creating `AArch64ISD::USDOT`? Which is perhaps best done under a separate PR.
Yeah I agree this can be done separately.
For the avoidance of doubt, I am suggesting it is a requirement of this PR that `AArch64ISD::USDOT` exist, rather than suggesting some post-PR refactoring.
I've created a USDOT node; please let me know if it's not as expected.
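With a dedicated node, the mixed-extension branch can emit it directly instead of building an INTRINSIC_WO_CHAIN. A sketch of that shape, assuming the `AIsSigned`/`BIsSigned` naming suggested later in this review:

```cpp
// Sketch: once AArch64ISD::USDOT exists, the mixed case needs no intrinsic
// node and works for both the NEON and SVE types accepted earlier.
if (AIsSigned != BIsSigned) {
  if (!Subtarget->hasMatMulInt8())
    return SDValue();
  // USDOT expects the unsigned operand first, so move the signed one last.
  if (AIsSigned)
    std::swap(A, B);
  return DAG.getNode(AArch64ISD::USDOT, DL, ReducedType, NarrowOp, A, B);
}
```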
def : Pat<(v4i32 (AArch64usdot (v4i32 V128:$Rd), (v16i8 V128:$Rm), (v16i8 V128:$Rn))), (USDOTv16i8 $Rd, $Rm, $Rn)>;
def : Pat<(v2i32 (AArch64usdot (v2i32 V64:$Rd), (v8i8 V64:$Rm), (v8i8 V64:$Rn))), (USDOTv8i8 $Rd, $Rm, $Rn)>;
You don't need extra patterns here but instead can pass `AArch64usdot` directly to `defm USDOT...` in place of the existing `int_aarch64_neon_usdot` parameter.
If you look at the way UDOT is handled you'll see the missing piece is a small update to `AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN` to lower the intrinsics to `AArch64ISD::USDOT`. Doing this will mean all future optimisations apply equally to all places where the operation exists.
Thank you, that is cleaner. Done.
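The missing piece described above, sketched as a case-handler fragment (this mirrors how the sdot/udot intrinsics are already lowered; the exact committed code may differ):

```cpp
// Inside AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN: route both the NEON
// and SVE usdot intrinsics to the new node so later combines see one form.
case Intrinsic::aarch64_neon_usdot:
case Intrinsic::aarch64_sve_usdot:
  return DAG.getNode(AArch64ISD::USDOT, dl, Op.getValueType(),
                     Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
```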
@@ -3408,6 +3408,7 @@ let Predicates = [HasSVEorSME, HasMatMulInt8] in {
 defm USDOT_ZZZ : sve_int_dot_mixed<"usdot", int_aarch64_sve_usdot>;
 defm USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot", int_aarch64_sve_usdot_lane>;
 defm SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot", int_aarch64_sve_sudot_lane>;
+def : Pat<(nxv4i32 (AArch64usdot (nxv4i32 ZPR32:$Rd), (nxv16i8 ZPR8:$Rm), (nxv16i8 ZPR8:$Rn))), (USDOT_ZZZ $Rd, $Rm, $Rn)>;
As above, you can pass `AArch64usdot` directly into `defm USDOT_ZZZ :`, replacing the existing `int_aarch64_sve_usdot` parameter.
Done.
A few suggestions, but only the potential bug in `AArch64InstrInfo.td` is holding me back from accepting the PR.
@@ -1,8 +1,9 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
`CHECK-I8MM` and `CHECK-NOI8MM` seem like more consistent names.
Done.
// USDOT expects the first operand to be unsigned, so swap the operands if
// the first is signed and the second is unsigned
if (AIsSExt && BIsZExt)
This can be condensed into:
// USDOT expects the signed operand to be last.
if (BIsZExt)
std::swap(A, B);
Done.
bool AIsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
Now that we care about all combinations, I think you'd be better off with something like:

if (!isExtOpcode(ExtA->getOpcode()) || !isExtOpcode(ExtB->getOpcode()))
  return SDValue();
bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;

Whilst this allows ISD::ANY_EXTEND to pass through, I believe that would make the result undefined and thus emitting any DOT is still a valid option.
Done.
@@ -1417,7 +1418,7 @@ let Predicates = [HasMatMulInt8] in {
 def SMMLA : SIMDThreeSameVectorMatMul<0, 0, "smmla", int_aarch64_neon_smmla>;
 def UMMLA : SIMDThreeSameVectorMatMul<0, 1, "ummla", int_aarch64_neon_ummla>;
 def USMMLA : SIMDThreeSameVectorMatMul<1, 0, "usmmla", int_aarch64_neon_usmmla>;
-defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", int_aarch64_neon_usdot>;
+defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", AArch64usdot>;
 defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", int_aarch64_neon_usdot>;
You'll need to replace all uses of `int_aarch64_neon_usdot`, otherwise their instruction match will fail because the intrinsic no longer gets this far. I'm somewhat surprised you did not see any test failures?
Done. Thanks for spotting that.
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
; RUN: llc -mtriple aarch64 -mattr=+neon,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOIMM8
Would the following reduce the number of CHECK lines?
RUN_1: CHECK,CHECK-DOT,CHECK-I8MM
RUN_2: CHECK,CHECK-NODOT,CHECK-I8MM
RUN_3: CHECK,CHECK-DOT,CHECK-NOI8MM
Much better, thanks. Done.
✅ With the latest revision this PR passed the C/C++ code formatter.
   Opcode = AArch64ISD::SDOT;
-else if (IsZExt)
+else if (!AIsSigned)
This should be just `else`, given that by this point we know `AIsSigned` has to be false.
That's very true; it also lets us get rid of the assertion. Done.
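The tail of the combine then reduces to a sketch like this, with no assertion needed:

```cpp
// A and B are known to share signedness here (mixed extensions were handled
// earlier), so one ternary selects the node.
unsigned Opcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
```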
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
; RUN: llc -mtriple aarch64 -mattr=+neon,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT,CHECK-I8MM
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
This hasn't worked as expected because there are no `CHECK-I8MM` lines. I guess I was expecting too much from the auto-update script?
I think I see what has happened. I had assumed `+i8mm` would enable `usdot` instructions, but the DAG combine requires both `+dotprod` and `+i8mm` for that to happen.
The following seems to work better:
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
What do you think?
That's nice. I re-ordered the run lines a little to reduce the number of total changes.
[AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot (llvm#107566)

This PR adds lowering for partial reductions of a mix of sign/zero extended inputs to the usdot intrinsic.