diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d80087336d23..590252e5dbe0c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2701,6 +2701,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SADDLV)
     MAKE_CASE(AArch64ISD::SDOT)
     MAKE_CASE(AArch64ISD::UDOT)
+    MAKE_CASE(AArch64ISD::USDOT)
     MAKE_CASE(AArch64ISD::SMINV)
     MAKE_CASE(AArch64ISD::UMINV)
     MAKE_CASE(AArch64ISD::SMAXV)
@@ -6114,6 +6115,11 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2), Op.getOperand(3));
   }
+  case Intrinsic::aarch64_neon_usdot:
+  case Intrinsic::aarch64_sve_usdot: {
+    return DAG.getNode(AArch64ISD::USDOT, dl, Op.getValueType(),
+                       Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
+  }
   case Intrinsic::get_active_lane_mask: {
     SDValue ID =
         DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
@@ -21824,37 +21830,50 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   auto ExtA = MulOp->getOperand(0);
   auto ExtB = MulOp->getOperand(1);
-  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
-  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+
+  if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
+      !ISD::isExtOpcode(ExtB->getOpcode()))
     return SDValue();
+  bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;

   auto A = ExtA->getOperand(0);
   auto B = ExtB->getOperand(0);
   if (A.getValueType() != B.getValueType())
     return SDValue();

-  unsigned Opcode = 0;
-
-  if (IsSExt)
-    Opcode = AArch64ISD::SDOT;
-  else if (IsZExt)
-    Opcode = AArch64ISD::UDOT;
-
-  assert(Opcode != 0 && "Unexpected dot product case encountered.");
-
   EVT ReducedType = N->getValueType(0);
   EVT MulSrcType = A.getValueType();

   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
-      (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
-      (ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
-      (ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
-    return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+  if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+      !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+      !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
+      !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+    return SDValue();

-  return SDValue();
+  // If the extensions are mixed, we should lower it to a usdot instead
+  unsigned Opcode = 0;
+  if (AIsSigned != BIsSigned) {
+    if (!Subtarget->hasMatMulInt8())
+      return SDValue();
+
+    bool Scalable = N->getValueType(0).isScalableVT();
+    // There's no nxv2i64 version of usdot
+    if (Scalable && ReducedType != MVT::nxv4i32)
+      return SDValue();
+
+    Opcode = AArch64ISD::USDOT;
+    // USDOT expects the signed operand to be last
+    if (!BIsSigned)
+      std::swap(A, B);
+  } else if (AIsSigned)
+    Opcode = AArch64ISD::SDOT;
+  else
+    Opcode = AArch64ISD::UDOT;
+
+  return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }

 static SDValue performIntrinsicCombine(SDNode *N,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index f9d45b02d30e3..e79b41b66d77e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -280,9 +280,10 @@ enum NodeType : unsigned {
   SADDLP,
   UADDLP,

-  // udot/sdot instructions
+  // udot/sdot/usdot instructions
   UDOT,
   SDOT,
+  USDOT,

   // Vector across-lanes min/max
   // Only the lower result lane is defined.
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index ccef85bfaa8af..97f3c39145c61 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -855,6 +855,7 @@ def AArch64frsqrts : SDNode<"AArch64ISD::FRSQRTS", SDTFPBinOp>;

 def AArch64sdot : SDNode<"AArch64ISD::SDOT", SDT_AArch64Dot>;
 def AArch64udot : SDNode<"AArch64ISD::UDOT", SDT_AArch64Dot>;
+def AArch64usdot : SDNode<"AArch64ISD::USDOT", SDT_AArch64Dot>;

 def AArch64saddv : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>;
 def AArch64uaddv : SDNode<"AArch64ISD::UADDV", SDT_AArch64UnaryVec>;
@@ -1417,8 +1418,8 @@ let Predicates = [HasMatMulInt8] in {
 def SMMLA : SIMDThreeSameVectorMatMul<0, 0, "smmla", int_aarch64_neon_smmla>;
 def UMMLA : SIMDThreeSameVectorMatMul<0, 1, "ummla", int_aarch64_neon_ummla>;
 def USMMLA : SIMDThreeSameVectorMatMul<1, 0, "usmmla", int_aarch64_neon_usmmla>;
-defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", int_aarch64_neon_usdot>;
-defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", int_aarch64_neon_usdot>;
+defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", AArch64usdot>;
+defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", AArch64usdot>;

 // sudot lane has a pattern where usdot is expected (there is no sudot).
 // The second operand is used in the dup operation to repeat the indexed
@@ -1430,7 +1431,7 @@ class BaseSIMDSUDOTIndex {
   let Pattern = [(set (AccumType RegType:$dst),
-                      (AccumType (int_aarch64_neon_usdot (AccumType RegType:$Rd),
+                      (AccumType (AArch64usdot (AccumType RegType:$Rd),
                                  (InputType (bitconvert (AccumType
                                     (AArch64duplane32 (v4i32 V128:$Rm),
                                         VectorIndexS:$idx)))),
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 692cd66d38437..c4207dd478594 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3405,7 +3405,7 @@ let Predicates = [HasSVE, HasMatMulInt8] in {
 } // End HasSVE, HasMatMulInt8

 let Predicates = [HasSVEorSME, HasMatMulInt8] in {
-  defm USDOT_ZZZ : sve_int_dot_mixed<"usdot", int_aarch64_sve_usdot>;
+  defm USDOT_ZZZ : sve_int_dot_mixed<"usdot", AArch64usdot>;
   defm USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot", int_aarch64_sve_usdot_lane>;
   defm SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot", int_aarch64_sve_sudot_lane>;
 } // End HasSVEorSME, HasMatMulInt8
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 8035504d5558b..841da1f8ea57c 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1,6 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
+; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM

 define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-DOT-LABEL: udot:
@@ -102,7 +103,115 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
   ret <2 x i32> %partial.reduce
 }

-define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
+define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
+; CHECK-NOI8MM-LABEL: usdot:
+; CHECK-NOI8MM: // %bb.0:
+; CHECK-NOI8MM-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-NOI8MM-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NOI8MM-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOI8MM-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOI8MM-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOI8MM-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: usdot:
+; CHECK-I8MM: // %bb.0:
+; CHECK-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b
+; CHECK-I8MM-NEXT: ret
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = sext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-NOI8MM-LABEL: usdot_narrow:
+; CHECK-NOI8MM: // %bb.0:
+; CHECK-NOI8MM-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NOI8MM-NEXT: sshll v2.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NOI8MM-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOI8MM-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOI8MM-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOI8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOI8MM-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOI8MM-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOI8MM-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NOI8MM-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: usdot_narrow:
+; CHECK-I8MM: // %bb.0:
+; CHECK-I8MM-NEXT: usdot v0.2s, v1.8b, v2.8b
+; CHECK-I8MM-NEXT: ret
+  %u.wide = zext <8 x i8> %u to <8 x i32>
+  %s.wide = sext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
+; CHECK-NOI8MM-LABEL: sudot:
+; CHECK-NOI8MM: // %bb.0:
+; CHECK-NOI8MM-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NOI8MM-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NOI8MM-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOI8MM-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOI8MM-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOI8MM-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: sudot:
+; CHECK-I8MM: // %bb.0:
+; CHECK-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
+; CHECK-I8MM-NEXT: ret
+  %u.wide = sext <16 x i8> %u to <16 x i32>
+  %s.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-NOI8MM-LABEL: sudot_narrow:
+; CHECK-NOI8MM: // %bb.0:
+; CHECK-NOI8MM-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NOI8MM-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NOI8MM-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOI8MM-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOI8MM-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOI8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOI8MM-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOI8MM-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOI8MM-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NOI8MM-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: sudot_narrow:
+; CHECK-I8MM: // %bb.0:
+; CHECK-I8MM-NEXT: usdot v0.2s, v2.8b, v1.8b
+; CHECK-I8MM-NEXT: ret
+  %u.wide = sext <8 x i8> %u to <8 x i32>
+  %s.wide = zext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-LABEL: not_udot:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index b1354ab210f72..00e5ac479d02c 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,8 +1,9 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM

-define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp:
+define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: udot:
 ; CHECK: // %bb.0: // %entry
 ; CHECK-NEXT: udot z0.s, z1.b, z2.b
 ; CHECK-NEXT: ret
@@ -14,8 +15,8 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }

-define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: dotp_wide:
+define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: udot_wide:
 ; CHECK: // %bb.0: // %entry
 ; CHECK-NEXT: udot z0.d, z1.h, z2.h
 ; CHECK-NEXT: ret
@@ -27,8 +28,8 @@ entry:
   ret <vscale x 2 x i64> %partial.reduce
 }

-define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp_sext:
+define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: sdot:
 ; CHECK: // %bb.0: // %entry
 ; CHECK-NEXT: sdot z0.s, z1.b, z2.b
 ; CHECK-NEXT: ret
@@ -40,8 +41,8 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }

-define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: dotp_wide_sext:
+define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: sdot_wide:
 ; CHECK: // %bb.0: // %entry
 ; CHECK-NEXT: sdot z0.d, z1.h, z2.h
 ; CHECK-NEXT: ret
@@ -53,8 +54,80 @@ entry:
   ret <vscale x 2 x i64> %partial.reduce
 }

-define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
-; CHECK-LABEL: not_dotp:
+define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-I8MM-LABEL: usdot:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: usdot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: usdot:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: uunpklo z3.h, z1.b
+; CHECK-NOI8MM-NEXT: sunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NOI8MM-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT: ptrue p0.s
+; CHECK-NOI8MM-NEXT: uunpklo z5.s, z3.h
+; CHECK-NOI8MM-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT: sunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT: uunpklo z7.s, z1.h
+; CHECK-NOI8MM-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NOI8MM-NEXT: sunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOI8MM-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NOI8MM-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOI8MM-NEXT: movprfx z1, z3
+; CHECK-NOI8MM-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOI8MM-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NOI8MM-NEXT: ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-I8MM-LABEL: sudot:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: usdot z0.s, z2.b, z1.b
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sudot:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: sunpklo z3.h, z1.b
+; CHECK-NOI8MM-NEXT: uunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NOI8MM-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT: ptrue p0.s
+; CHECK-NOI8MM-NEXT: sunpklo z5.s, z3.h
+; CHECK-NOI8MM-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT: uunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT: sunpklo z7.s, z1.h
+; CHECK-NOI8MM-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NOI8MM-NEXT: uunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOI8MM-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NOI8MM-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOI8MM-NEXT: movprfx z1, z3
+; CHECK-NOI8MM-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOI8MM-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NOI8MM-NEXT: ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+; CHECK-LABEL: not_udot:
 ; CHECK: // %bb.0: // %entry
 ; CHECK-NEXT: and z1.h, z1.h, #0xff
 ; CHECK-NEXT: and z2.h, z2.h, #0xff
@@ -74,8 +147,8 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }

-define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
-; CHECK-LABEL: not_dotp_wide:
+define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+; CHECK-LABEL: not_udot_wide:
 ; CHECK: // %bb.0: // %entry
 ; CHECK-NEXT: and z1.s, z1.s, #0xffff
 ; CHECK-NEXT: and z2.s, z2.s, #0xffff
@@ -94,3 +167,65 @@ entry:
   %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
   ret <vscale x 2 x i64> %partial.reduce
 }
+
+define <vscale x 2 x i64> @not_usdot(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: not_usdot:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEXT: sunpklo z4.s, z2.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpklo z5.d, z3.s
+; CHECK-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEXT: uunpklo z7.d, z1.s
+; CHECK-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: movprfx z1, z3
+; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_sudot(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: not_sudot:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: sunpklo z5.d, z3.s
+; CHECK-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEXT: sunpklo z7.d, z1.s
+; CHECK-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: movprfx z1, z3
+; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}