[RISCV] Generalize reduction tree matching to fp sum reductions #68599

Merged · 1 commit · Oct 9, 2023
21 changes: 15 additions & 6 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11114,7 +11114,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
 }
 
-/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
+/// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP
 /// which corresponds to it.
 static unsigned getVecReduceOpcode(unsigned Opc) {
   switch (Opc) {
@@ -11136,6 +11136,9 @@ static unsigned getVecReduceOpcode(unsigned Opc) {
     return ISD::VECREDUCE_OR;
   case ISD::XOR:
     return ISD::VECREDUCE_XOR;
+  case ISD::FADD:
+    // Note: This is the associative form of the generic reduction opcode.
+    return ISD::VECREDUCE_FADD;
   }
 }
 
@@ -11162,12 +11165,16 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,

   const SDLoc DL(N);
   const EVT VT = N->getValueType(0);
+  const unsigned Opc = N->getOpcode();
 
-  // TODO: Handle floating point here.
-  if (!VT.isInteger())
+  // For FADD, we only handle the case with reassociation allowed. We
+  // could handle strict reduction order, but at the moment, there's no
+  // known reason to, and the complexity isn't worth it.
+  // TODO: Handle fminnum and fmaxnum here
+  if (!VT.isInteger() &&
+      (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation()))
     return SDValue();
 
-  const unsigned Opc = N->getOpcode();
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
@@ -11200,7 +11207,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
     EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
     SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                               DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec);
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11222,7 +11229,9 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
       EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
       SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                                 DAG.getVectorIdxConstant(0, DL));
-      return DAG.getNode(ReduceOpc, DL, VT, Vec);
+      auto Flags = ReduceVec->getFlags();
+      Flags.intersectWith(N->getFlags());
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
     }
   }
 
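In IR terms, the generalized combine rewrites a reassociable chain of scalar fadds over the leading lanes of a vector into a single reduction over an extracted prefix. A rough sketch (hypothetical IR; the combine itself runs on SelectionDAG nodes, and llvm.vector.reduce.fadd called with the reassoc flag is the unordered reduction form, with -0.0 as the neutral start value):

  ; before: scalar chain over the first two lanes
  %e0 = extractelement <16 x float> %v, i32 0
  %e1 = extractelement <16 x float> %v, i32 1
  %sum = fadd reassoc float %e0, %e1

  ; after: reduce an extracted two-lane prefix
  %prefix = shufflevector <16 x float> %v, <16 x float> poison, <2 x i32> <i32 0, i32 1>
  %sum2 = call reassoc float @llvm.vector.reduce.fadd.v2f32(float -0.0, <2 x float> %prefix)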
159 changes: 159 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -764,6 +764,165 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) {
%umin3 = call i32 @llvm.umin.i32(i32 %umin2, i32 %e4)
ret i32 %umin3
}

define float @reduce_fadd_16xf32_prefix2(ptr %p) {
; CHECK-LABEL: reduce_fadd_16xf32_prefix2:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: vmv.s.x v9, zero
; CHECK-NEXT: vfredusum.vs v8, v8, v9
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: ret
%v = load <16 x float>, ptr %p, align 256
%e0 = extractelement <16 x float> %v, i32 0
%e1 = extractelement <16 x float> %v, i32 1
%fadd0 = fadd fast float %e0, %e1
ret float %fadd0
}

define float @reduce_fadd_16xi32_prefix5(ptr %p) {
; CHECK-LABEL: reduce_fadd_16xi32_prefix5:
; CHECK: # %bb.0:
; CHECK-NEXT: lui a1, 524288
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: vmv.s.x v10, a1
; CHECK-NEXT: vsetivli zero, 6, e32, m2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v10, 5
; CHECK-NEXT: vsetivli zero, 7, e32, m2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v10, 6
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vslideup.vi v8, v10, 7
; CHECK-NEXT: vfredusum.vs v8, v8, v10
Contributor:
I wonder if we could improve the legalisation here by using VP nodes (something like https://reviews.llvm.org/D148523 but for reductions) to avoid having to pad out the vector with zeroes. Or it looks like there could also be a combine to replace all these inserts with a splat and single slide.

Collaborator (author):
There's a bunch of possible improvements here. Probably the best is to use a masked reduction.

TBH, the prefix cases aren't showing up in my motivating workloads (spec2017), so I'm not super worried about them. I wanted something correct and not terrible so that the incremental approach was viable, but that's currently about as far as I care.
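As a sketch of the VP-node idea discussed above: a predicated reduction can take the five-element prefix via the explicit vector length instead of padding the upper lanes. Hypothetical IR, assuming the existing llvm.vp.reduce.fadd intrinsic; the v8f32 type and all-true mask are illustrative:

  declare float @llvm.vp.reduce.fadd.v8f32(float, <8 x float>, <8 x i1>, i32)

  ; reduce only the first 5 lanes (evl = 5); lanes 5-7 need no padding
  %r = call reassoc float @llvm.vp.reduce.fadd.v8f32(float -0.0, <8 x float> %v, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, i32 5)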

; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: ret
%v = load <16 x float>, ptr %p, align 256
%e0 = extractelement <16 x float> %v, i32 0
%e1 = extractelement <16 x float> %v, i32 1
%e2 = extractelement <16 x float> %v, i32 2
%e3 = extractelement <16 x float> %v, i32 3
%e4 = extractelement <16 x float> %v, i32 4
%fadd0 = fadd fast float %e0, %e1
%fadd1 = fadd fast float %fadd0, %e2
%fadd2 = fadd fast float %fadd1, %e3
%fadd3 = fadd fast float %fadd2, %e4
ret float %fadd3
}

;; Corner case tests for fadd associativity

; Negative test, not associative. Would need strict opcode.
define float @reduce_fadd_2xf32_non_associative(ptr %p) {
; CHECK-LABEL: reduce_fadd_2xf32_non_associative:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: vfmv.f.s fa5, v8
; CHECK-NEXT: vslidedown.vi v8, v8, 1
; CHECK-NEXT: vfmv.f.s fa4, v8
; CHECK-NEXT: fadd.s fa0, fa5, fa4
; CHECK-NEXT: ret
%v = load <2 x float>, ptr %p, align 256
%e0 = extractelement <2 x float> %v, i32 0
%e1 = extractelement <2 x float> %v, i32 1
%fadd0 = fadd float %e0, %e1
ret float %fadd0
}

; Positive test - minimal set of fast math flags
define float @reduce_fadd_2xf32_reassoc_only(ptr %p) {
; CHECK-LABEL: reduce_fadd_2xf32_reassoc_only:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: lui a0, 524288
; CHECK-NEXT: vmv.s.x v9, a0
; CHECK-NEXT: vfredusum.vs v8, v8, v9
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: ret
%v = load <2 x float>, ptr %p, align 256
%e0 = extractelement <2 x float> %v, i32 0
%e1 = extractelement <2 x float> %v, i32 1
%fadd0 = fadd reassoc float %e0, %e1
ret float %fadd0
}

; Negative test - wrong fast math flag.
define float @reduce_fadd_2xf32_ninf_only(ptr %p) {
; CHECK-LABEL: reduce_fadd_2xf32_ninf_only:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: vfmv.f.s fa5, v8
; CHECK-NEXT: vslidedown.vi v8, v8, 1
; CHECK-NEXT: vfmv.f.s fa4, v8
; CHECK-NEXT: fadd.s fa0, fa5, fa4
; CHECK-NEXT: ret
%v = load <2 x float>, ptr %p, align 256
%e0 = extractelement <2 x float> %v, i32 0
%e1 = extractelement <2 x float> %v, i32 1
%fadd0 = fadd ninf float %e0, %e1
ret float %fadd0
}


; Negative test - last fadd is not associative
define float @reduce_fadd_4xi32_non_associative(ptr %p) {
; CHECK-LABEL: reduce_fadd_4xi32_non_associative:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: vslidedown.vi v9, v8, 3
; CHECK-NEXT: vfmv.f.s fa5, v9
; CHECK-NEXT: lui a0, 524288
; CHECK-NEXT: vmv.s.x v9, a0
; CHECK-NEXT: vslideup.vi v8, v9, 3
; CHECK-NEXT: vfredusum.vs v8, v8, v9
; CHECK-NEXT: vfmv.f.s fa4, v8
; CHECK-NEXT: fadd.s fa0, fa4, fa5
; CHECK-NEXT: ret
%v = load <4 x float>, ptr %p, align 256
%e0 = extractelement <4 x float> %v, i32 0
%e1 = extractelement <4 x float> %v, i32 1
%e2 = extractelement <4 x float> %v, i32 2
%e3 = extractelement <4 x float> %v, i32 3
%fadd0 = fadd fast float %e0, %e1
%fadd1 = fadd fast float %fadd0, %e2
%fadd2 = fadd float %fadd1, %e3
ret float %fadd2
}

; Negative test - first fadd is not associative
; We could form a reduce for elements 2 and 3.
define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
; CHECK-LABEL: reduce_fadd_4xi32_non_associative2:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: vfmv.f.s fa5, v8
; CHECK-NEXT: vslidedown.vi v9, v8, 1
; CHECK-NEXT: vfmv.f.s fa4, v9
; CHECK-NEXT: vslidedown.vi v9, v8, 2
; CHECK-NEXT: vfmv.f.s fa3, v9
; CHECK-NEXT: vslidedown.vi v8, v8, 3
; CHECK-NEXT: vfmv.f.s fa2, v8
; CHECK-NEXT: fadd.s fa5, fa5, fa4
; CHECK-NEXT: fadd.s fa4, fa3, fa2
; CHECK-NEXT: fadd.s fa0, fa5, fa4
; CHECK-NEXT: ret
%v = load <4 x float>, ptr %p, align 256
%e0 = extractelement <4 x float> %v, i32 0
%e1 = extractelement <4 x float> %v, i32 1
%e2 = extractelement <4 x float> %v, i32 2
%e3 = extractelement <4 x float> %v, i32 3
%fadd0 = fadd float %e0, %e1
%fadd1 = fadd fast float %fadd0, %e2
%fadd2 = fadd fast float %fadd1, %e3
ret float %fadd2
}


;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; RV32: {{.*}}
; RV64: {{.*}}