[RISCV] Generalize reduction tree matching to fp sum reductions #68599
Conversation
This builds on the transform introduced in f0505c3, and generalized to all integer operations in 45a334d. This change adds support for floating point summation. A couple of notes:
* I chose to leave fmaxnum and fminnum unhandled for the moment; they have a slightly different set of legality rules.
* We could form strictly sequenced FADD reductions for FADDs without fast-math flags. Since the ordered reductions are more expensive, I left thinking about this as a future exercise.
* This can't yet match the full vector reduce + start value idiom. That will be an upcoming set of changes.
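For readers skimming the diff, here is a minimal IR-level sketch of the idiom this combine targets and the form it gets rewritten into. This is an illustration, not code from the patch: the combine itself runs on SelectionDAG nodes, and the function names here are invented. Note that the IR reduction intrinsic takes an explicit start value; -0.0 is the neutral element for fadd, which is also why the assembly checks below materialize 0x80000000 (lui a0, 524288).

; Before: a left-leaning chain of fadds over the leading lanes of a vector.
define float @sum_prefix2(<4 x float> %v) {
  %e0 = extractelement <4 x float> %v, i32 0
  %e1 = extractelement <4 x float> %v, i32 1
  %sum = fadd reassoc float %e0, %e1
  ret float %sum
}

; After (conceptually): an unordered reduction over the two-lane prefix.
define float @sum_prefix2_reduced(<4 x float> %v) {
  %prefix = shufflevector <4 x float> %v, <4 x float> poison, <2 x i32> <i32 0, i32 1>
  %sum = call reassoc float @llvm.vector.reduce.fadd.v2f32(float -0.0, <2 x float> %prefix)
  ret float %sum
}

declare float @llvm.vector.reduce.fadd.v2f32(float, <2 x float>)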
@llvm/pr-subscribers-backend-risc-v
Full diff: https://github.com/llvm/llvm-project/pull/68599.diff (2 files affected)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index bd4150c87eabbd0..94c1c52a28462b7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11114,7 +11114,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
}
}
-/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
+/// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP
/// which corresponds to it.
static unsigned getVecReduceOpcode(unsigned Opc) {
switch (Opc) {
@@ -11136,6 +11136,9 @@ static unsigned getVecReduceOpcode(unsigned Opc) {
return ISD::VECREDUCE_OR;
case ISD::XOR:
return ISD::VECREDUCE_XOR;
+ case ISD::FADD:
+ // Note: This is the associative form of the generic reduction opcode.
+ return ISD::VECREDUCE_FADD;
}
}
@@ -11162,12 +11165,16 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
const SDLoc DL(N);
const EVT VT = N->getValueType(0);
+ const unsigned Opc = N->getOpcode();
- // TODO: Handle floating point here.
- if (!VT.isInteger())
+ // For FADD, we only handle the case with reassociation allowed. We
+ // could handle strict reduction order, but at the moment, there's no
+ // known reason to, and the complexity isn't worth it.
+ // TODO: Handle fminnum and fmaxnum here
+ if (!VT.isInteger() &&
+ (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation()))
return SDValue();
- const unsigned Opc = N->getOpcode();
const unsigned ReduceOpc = getVecReduceOpcode(Opc);
assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
"Inconsistent mappings");
@@ -11200,7 +11207,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
DAG.getVectorIdxConstant(0, DL));
- return DAG.getNode(ReduceOpc, DL, VT, Vec);
+ return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
}
// Match (binop (reduce (extract_subvector V, 0),
@@ -11222,7 +11229,9 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
DAG.getVectorIdxConstant(0, DL));
- return DAG.getNode(ReduceOpc, DL, VT, Vec);
+ auto Flags = ReduceVec->getFlags();
+ Flags.intersectWith(N->getFlags());
+ return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
}
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index dd9a1118ab821d4..76df097a7697162 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -764,6 +764,165 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) {
%umin3 = call i32 @llvm.umin.i32(i32 %umin2, i32 %e4)
ret i32 %umin3
}
+
+define float @reduce_fadd_16xf32_prefix2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xf32_prefix2:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vle32.v v8, (a0)
+; CHECK-NEXT: vmv.s.x v9, zero
+; CHECK-NEXT: vfredusum.vs v8, v8, v9
+; CHECK-NEXT: vfmv.f.s fa0, v8
+; CHECK-NEXT: ret
+ %v = load <16 x float>, ptr %p, align 256
+ %e0 = extractelement <16 x float> %v, i32 0
+ %e1 = extractelement <16 x float> %v, i32 1
+ %fadd0 = fadd fast float %e0, %e1
+ ret float %fadd0
+}
+
+define float @reduce_fadd_16xf32_prefix5(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xf32_prefix5:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lui a1, 524288
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vle32.v v8, (a0)
+; CHECK-NEXT: vmv.s.x v10, a1
+; CHECK-NEXT: vsetivli zero, 6, e32, m2, tu, ma
+; CHECK-NEXT: vslideup.vi v8, v10, 5
+; CHECK-NEXT: vsetivli zero, 7, e32, m2, tu, ma
+; CHECK-NEXT: vslideup.vi v8, v10, 6
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vslideup.vi v8, v10, 7
+; CHECK-NEXT: vfredusum.vs v8, v8, v10
+; CHECK-NEXT: vfmv.f.s fa0, v8
+; CHECK-NEXT: ret
+ %v = load <16 x float>, ptr %p, align 256
+ %e0 = extractelement <16 x float> %v, i32 0
+ %e1 = extractelement <16 x float> %v, i32 1
+ %e2 = extractelement <16 x float> %v, i32 2
+ %e3 = extractelement <16 x float> %v, i32 3
+ %e4 = extractelement <16 x float> %v, i32 4
+ %fadd0 = fadd fast float %e0, %e1
+ %fadd1 = fadd fast float %fadd0, %e2
+ %fadd2 = fadd fast float %fadd1, %e3
+ %fadd3 = fadd fast float %fadd2, %e4
+ ret float %fadd3
+}
+
+;; Corner case tests for fadd associativity
+
+; Negative test, not associative. Would need strict opcode.
+define float @reduce_fadd_2xf32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_non_associative:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vle32.v v8, (a0)
+; CHECK-NEXT: vfmv.f.s fa5, v8
+; CHECK-NEXT: vslidedown.vi v8, v8, 1
+; CHECK-NEXT: vfmv.f.s fa4, v8
+; CHECK-NEXT: fadd.s fa0, fa5, fa4
+; CHECK-NEXT: ret
+ %v = load <2 x float>, ptr %p, align 256
+ %e0 = extractelement <2 x float> %v, i32 0
+ %e1 = extractelement <2 x float> %v, i32 1
+ %fadd0 = fadd float %e0, %e1
+ ret float %fadd0
+}
+
+; Positive test - minimal set of fast math flags
+define float @reduce_fadd_2xf32_reassoc_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_reassoc_only:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vle32.v v8, (a0)
+; CHECK-NEXT: lui a0, 524288
+; CHECK-NEXT: vmv.s.x v9, a0
+; CHECK-NEXT: vfredusum.vs v8, v8, v9
+; CHECK-NEXT: vfmv.f.s fa0, v8
+; CHECK-NEXT: ret
+ %v = load <2 x float>, ptr %p, align 256
+ %e0 = extractelement <2 x float> %v, i32 0
+ %e1 = extractelement <2 x float> %v, i32 1
+ %fadd0 = fadd reassoc float %e0, %e1
+ ret float %fadd0
+}
+
+; Negative test - wrong fast math flag.
+define float @reduce_fadd_2xf32_ninf_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_ninf_only:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vle32.v v8, (a0)
+; CHECK-NEXT: vfmv.f.s fa5, v8
+; CHECK-NEXT: vslidedown.vi v8, v8, 1
+; CHECK-NEXT: vfmv.f.s fa4, v8
+; CHECK-NEXT: fadd.s fa0, fa5, fa4
+; CHECK-NEXT: ret
+ %v = load <2 x float>, ptr %p, align 256
+ %e0 = extractelement <2 x float> %v, i32 0
+ %e1 = extractelement <2 x float> %v, i32 1
+ %fadd0 = fadd ninf float %e0, %e1
+ ret float %fadd0
+}
+
+
+; Negative test - last fadd is not associative
+define float @reduce_fadd_4xf32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xf32_non_associative:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT: vle32.v v8, (a0)
+; CHECK-NEXT: vslidedown.vi v9, v8, 3
+; CHECK-NEXT: vfmv.f.s fa5, v9
+; CHECK-NEXT: lui a0, 524288
+; CHECK-NEXT: vmv.s.x v9, a0
+; CHECK-NEXT: vslideup.vi v8, v9, 3
+; CHECK-NEXT: vfredusum.vs v8, v8, v9
+; CHECK-NEXT: vfmv.f.s fa4, v8
+; CHECK-NEXT: fadd.s fa0, fa4, fa5
+; CHECK-NEXT: ret
+ %v = load <4 x float>, ptr %p, align 256
+ %e0 = extractelement <4 x float> %v, i32 0
+ %e1 = extractelement <4 x float> %v, i32 1
+ %e2 = extractelement <4 x float> %v, i32 2
+ %e3 = extractelement <4 x float> %v, i32 3
+ %fadd0 = fadd fast float %e0, %e1
+ %fadd1 = fadd fast float %fadd0, %e2
+ %fadd2 = fadd float %fadd1, %e3
+ ret float %fadd2
+}
+
+; Negative test - first fadd is not associative
+; We could form a reduce for elements 2 and 3.
+define float @reduce_fadd_4xf32_non_associative2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xf32_non_associative2:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT: vle32.v v8, (a0)
+; CHECK-NEXT: vfmv.f.s fa5, v8
+; CHECK-NEXT: vslidedown.vi v9, v8, 1
+; CHECK-NEXT: vfmv.f.s fa4, v9
+; CHECK-NEXT: vslidedown.vi v9, v8, 2
+; CHECK-NEXT: vfmv.f.s fa3, v9
+; CHECK-NEXT: vslidedown.vi v8, v8, 3
+; CHECK-NEXT: vfmv.f.s fa2, v8
+; CHECK-NEXT: fadd.s fa5, fa5, fa4
+; CHECK-NEXT: fadd.s fa4, fa3, fa2
+; CHECK-NEXT: fadd.s fa0, fa5, fa4
+; CHECK-NEXT: ret
+ %v = load <4 x float>, ptr %p, align 256
+ %e0 = extractelement <4 x float> %v, i32 0
+ %e1 = extractelement <4 x float> %v, i32 1
+ %e2 = extractelement <4 x float> %v, i32 2
+ %e3 = extractelement <4 x float> %v, i32 3
+ %fadd0 = fadd float %e0, %e1
+ %fadd1 = fadd fast float %fadd0, %e2
+ %fadd2 = fadd fast float %fadd1, %e3
+ ret float %fadd2
+}
+
+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; RV32: {{.*}}
; RV64: {{.*}}
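One detail in the diff worth calling out: when an existing reduce node is widened to fold in one more fadd, the patch intersects the fast-math flags of the two nodes (Flags.intersectWith(N->getFlags())) rather than reusing either set. Here is an IR-level analogy of why that matters; this is illustrative only, since the combine operates on DAG nodes, and these functions are invented for the example:

; The inner reduce promises reassoc+nnan; the trailing fadd only reassoc.
define float @sum_prefix3(<4 x float> %v) {
  %p2 = shufflevector <4 x float> %v, <4 x float> poison, <2 x i32> <i32 0, i32 1>
  %partial = call reassoc nnan float @llvm.vector.reduce.fadd.v2f32(float -0.0, <2 x float> %p2)
  %e2 = extractelement <4 x float> %v, i32 2
  %sum = fadd reassoc float %partial, %e2
  ret float %sum
}

; Widened form: only the common subset of flags (reassoc) may survive.
; Tagging the whole sum nnan would be unsound, since the final fadd never
; promised it.
define float @sum_prefix3_widened(<4 x float> %v) {
  %p3 = shufflevector <4 x float> %v, <4 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
  %sum = call reassoc float @llvm.vector.reduce.fadd.v3f32(float -0.0, <3 x float> %p3)
  ret float %sum
}

declare float @llvm.vector.reduce.fadd.v2f32(float, <2 x float>)
declare float @llvm.vector.reduce.fadd.v3f32(float, <3 x float>)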
LGTM
; CHECK-NEXT: vslideup.vi v8, v10, 6
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vslideup.vi v8, v10, 7
; CHECK-NEXT: vfredusum.vs v8, v8, v10
I wonder if we could improve the legalisation here by using VP nodes (something like https://reviews.llvm.org/D148523 but for reductions) to avoid having to pad out the vector with zeroes. Or it looks like there could also be a combine to replace all these inserts with a splat and single slide.
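For reference, a sketch of what the VP route could look like at the IR level, assuming the existing llvm.vp.reduce.fadd intrinsic (the function name and vector shapes here are invented for illustration). The explicit vector length (EVL) operand expresses the five-lane prefix directly, so the tail lanes never need to be padded with the neutral element:

; Hypothetical alternative for the prefix5 case: let EVL carve out the
; prefix instead of slideup-padding the tail lanes with -0.0.
define float @reduce_fadd_prefix5_vp(ptr %p) {
  %v = load <8 x float>, ptr %p
  ; Splat of i1 true: all lanes enabled; participation is governed by EVL.
  %m0 = insertelement <8 x i1> poison, i1 true, i32 0
  %mask = shufflevector <8 x i1> %m0, <8 x i1> poison, <8 x i32> zeroinitializer
  %sum = call reassoc float @llvm.vp.reduce.fadd.v8f32(float -0.0, <8 x float> %v, <8 x i1> %mask, i32 5)
  ret float %sum
}

declare float @llvm.vp.reduce.fadd.v8f32(float, <8 x float>, <8 x i1>, i32)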
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a bunch of possible improvements here. Probably the best is to use a masked reduction.
TBH, the prefix cases aren't showing up in my motivating workloads (spec2017), so I'm not super worried about them. I wanted something correct and not terrible so that the incremental approach was viable, but that's currently about as far as I care.