[RISCV] Generalize reduction tree matching to fp sum reductions #68599


Merged
merged 1 commit into llvm:main from preames:pr-riscv-fadd-reduction-tree on Oct 9, 2023

Conversation

@preames (Collaborator) commented Oct 9, 2023

This builds on the transform introduced in f0505c3 and generalized to all integer operations in 45a334d. This change adds support for floating point summation.

A couple of notes:

  • I chose to leave fmaxnum and fminnum unhandled for the moment. They have a slightly different set of legality rules.
  • We could form strictly sequenced FADD reductions for FADDs without fast math flags. Since ordered reductions are more expensive, I've left that as a future exercise.
  • This can't yet match the full vector reduce + start value idiom. That will be an upcoming set of changes.

@llvmbot (Member) commented Oct 9, 2023

@llvm/pr-subscribers-backend-risc-v

Changes


Full diff: https://github.com/llvm/llvm-project/pull/68599.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+15-6)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll (+159)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index bd4150c87eabbd0..94c1c52a28462b7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11114,7 +11114,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
 }
 
-/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
+/// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP
 /// which corresponds to it.
 static unsigned getVecReduceOpcode(unsigned Opc) {
   switch (Opc) {
@@ -11136,6 +11136,9 @@ static unsigned getVecReduceOpcode(unsigned Opc) {
     return ISD::VECREDUCE_OR;
   case ISD::XOR:
     return ISD::VECREDUCE_XOR;
+  case ISD::FADD:
+    // Note: This is the associative form of the generic reduction opcode.
+    return ISD::VECREDUCE_FADD;
   }
 }
 
@@ -11162,12 +11165,16 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
 
   const SDLoc DL(N);
   const EVT VT = N->getValueType(0);
+  const unsigned Opc = N->getOpcode();
 
-  // TODO: Handle floating point here.
-  if (!VT.isInteger())
+  // For FADD, we only handle the case with reassociation allowed.  We
+  // could handle strict reduction order, but at the moment, there's no
+  // known reason to, and the complexity isn't worth it.
+  // TODO: Handle fminnum and fmaxnum here
+  if (!VT.isInteger() &&
+      (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation()))
     return SDValue();
 
-  const unsigned Opc = N->getOpcode();
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
@@ -11200,7 +11207,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
     EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
     SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                               DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec);
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11222,7 +11229,9 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
       EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
       SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                                 DAG.getVectorIdxConstant(0, DL));
-      return DAG.getNode(ReduceOpc, DL, VT, Vec);
+      auto Flags = ReduceVec->getFlags();
+      Flags.intersectWith(N->getFlags());
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
     }
   }
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index dd9a1118ab821d4..76df097a7697162 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -764,6 +764,165 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) {
   %umin3 = call i32 @llvm.umin.i32(i32 %umin2, i32 %e4)
   ret i32 %umin3
 }
+
+define float @reduce_fadd_16xf32_prefix2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xf32_prefix2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %fadd0 = fadd fast float %e0, %e1
+  ret float %fadd0
+}
+
+define float @reduce_fadd_16xi32_prefix5(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xi32_prefix5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, 524288
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v10, a1
+; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 5
+; CHECK-NEXT:    vsetivli zero, 7, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 6
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 7
+; CHECK-NEXT:    vfredusum.vs v8, v8, v10
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %e2 = extractelement <16 x float> %v, i32 2
+  %e3 = extractelement <16 x float> %v, i32 3
+  %e4 = extractelement <16 x float> %v, i32 4
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  %fadd3 = fadd fast float %fadd2, %e4
+  ret float %fadd3
+}
+
+;; Corner case tests for fadd associativity
+
+; Negative test, not associative.  Would need strict opcode.
+define float @reduce_fadd_2xf32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd float %e0, %e1
+  ret float %fadd0
+}
+
+; Positive test - minimal set of fast math flags
+define float @reduce_fadd_2xf32_reassoc_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_reassoc_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd reassoc float %e0, %e1
+  ret float %fadd0
+}
+
+; Negative test - wrong fast math flag.
+define float @reduce_fadd_2xf32_ninf_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_ninf_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd ninf float %e0, %e1
+  ret float %fadd0
+}
+
+
+; Negative test - last fadd is not associative
+define float @reduce_fadd_4xi32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vslidedown.vi v9, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa5, v9
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vslideup.vi v8, v9, 3
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa4, fa5
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd float %fadd1, %e3
+  ret float %fadd2
+}
+
+; Negative test - first fadd is not associative
+; We could form a reduce for elements 2 and 3.
+define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v9, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v9
+; CHECK-NEXT:    vslidedown.vi v9, v8, 2
+; CHECK-NEXT:    vfmv.f.s fa3, v9
+; CHECK-NEXT:    vslidedown.vi v8, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa2, v8
+; CHECK-NEXT:    fadd.s fa5, fa5, fa4
+; CHECK-NEXT:    fadd.s fa4, fa3, fa2
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  ret float %fadd2
+}
+
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; RV32: {{.*}}
 ; RV64: {{.*}}

@lukel97 (Contributor) left a comment
LGTM

; CHECK-NEXT: vslideup.vi v8, v10, 6
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vslideup.vi v8, v10, 7
; CHECK-NEXT: vfredusum.vs v8, v8, v10
@lukel97 (Contributor) commented on the lines above:
I wonder if we could improve the legalisation here by using VP nodes (something like https://reviews.llvm.org/D148523 but for reductions) to avoid having to pad out the vector with zeroes. Or it looks like there could also be a combine to replace all these inserts with a splat and single slide.

@preames (Collaborator, Author) replied:
There's a bunch of possible improvements here. Probably the best is to use a masked reduction.

TBH, the prefix cases aren't showing up in my motivating workloads (spec2017), so I'm not super worried about them. I wanted something correct and not terrible so that the incremental approach was viable, but that's about as far as my current interest goes.

@preames preames merged commit 08b20d8 into llvm:main Oct 9, 2023
@preames preames deleted the pr-riscv-fadd-reduction-tree branch October 9, 2023 18:33