Skip to content

[RISCV] Form vredsum from explode_vector + scalar (left) reduce #67821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 2, 2023

Conversation

preames
Copy link
Collaborator

@preames preames commented Sep 29, 2023

This change adds two related DAG combines which together will take a left-reduce scalar add tree of an explode_vector, and will incrementally form a vector reduction of the vector prefix. If the entire vector is reduced, the result will be a reduction over the entire vector.

Profitability wise, this relies on vredsum being cheaper than a pair of extracts and scalar add. Given vredsum is linear in LMUL, and the vslidedown required for the extract is also linear in LMUL, this is clearly true at higher index values. At N=2, it's a bit questionable, but I think the vredsum form is probably a better canonical form anyways.

Note that this only matches left reduces. This happens to be the motivating example I have (from spec2017 x264). This approach could be generalized to handle right reduces without much effort, and could be generalized to handle any reduce whose tree starts with adjacent elements if desired. The approach fails for a reduce such as (A+C)+(B+D) because we can't find a root to start the reduce with without scanning the entire associative add expression. We could maybe explore using masked reduces for the root node, but that seems of questionable profitability. (As in, worth questioning - I haven't explored in any detail.)

This is covering up a deficiency in SLP. If SLP encounters the scalar form of reduce_or(A) + reduce_sum(a) where a is some common vectorizeable tree, SLP will sometimes fail to revisit one of the reductions after vectorizing the other. Fixing this in SLP is hard, and there's no good reason not to handle the easy cases in the backend.

Another option here would be to do this in VectorCombine or generic DAG. I chose not to as the profitability of the non-legal typed prefix cases is very target dependent. I think this makes sense as a starting point, even if we move it elsewhere later.

This is currently restructed only to add reduces, but obviously makes sense for any associative reduction operator. Once this is approved, I plan to extend it in this manner. I'm simply staging work in case we decide to go in another direction.

This change adds two related DAG combines which together will take a left-reduce scalar add tree of an explode_vector, and will incrementally form a vector reduction of the vector prefix.  If the entire vector is reduced, the result will be a reduction over the entire vector.

Profitability wise, this relies on vredsum being cheaper than a pair of extracts and scalar add.  Given vredsum is linear in LMUL, and the vslidedown required for the extract is *also* linear in LMUL, this is clearly true at higher index values.  At N=2, it's a bit questionable, but I think the vredsum form is probably a better canonical form anyways.

Note that this only matches left reduces.  This happens to be the motivating example I have (from spec2017 x264).  This approach could be generalized to handle right reduces without much effort, and could be generalized to handle any reduce whose tree starts with adjacent elements if desired.  The approach fails for a reduce such as (A+C)+(B+D) because we can't find a root to start the reduce with without scanning the entire associative add expression.  We could maybe explore using masked reduces for the root node, but that seems of questionable profitability.  (As in, worth questioning - I haven't explored in any detail.)

This is covering up a deficiency in SLP.  If SLP encounters the scalar form of reduce_or(A) + reduce_sum(a) where a is some common vectorizeable tree, SLP will sometimes fail to revisit one of the reductions after vectorizing the other.  Fixing this in SLP is hard, and there's no good reason not to handle the easy cases in the backend.

Another option here would be to do this in VectorCombine or generic DAG.  I chose not to as the profitability of the non-legal typed prefix cases is very target dependent.  I think this makes sense as a starting point, even if we move it elsewhere later.

This is currently restructed only to add reduces, but obviously makes sense for any associative reduction operator.  Once this is approved, I plan to extend it in this manner.  I'm simply staging work in case we decide to go in another direction.
@llvmbot
Copy link
Member

llvmbot commented Sep 29, 2023

@llvm/pr-subscribers-backend-risc-v

Changes

This change adds two related DAG combines which together will take a left-reduce scalar add tree of an explode_vector, and will incrementally form a vector reduction of the vector prefix. If the entire vector is reduced, the result will be a reduction over the entire vector.

Profitability wise, this relies on vredsum being cheaper than a pair of extracts and scalar add. Given vredsum is linear in LMUL, and the vslidedown required for the extract is also linear in LMUL, this is clearly true at higher index values. At N=2, it's a bit questionable, but I think the vredsum form is probably a better canonical form anyways.

Note that this only matches left reduces. This happens to be the motivating example I have (from spec2017 x264). This approach could be generalized to handle right reduces without much effort, and could be generalized to handle any reduce whose tree starts with adjacent elements if desired. The approach fails for a reduce such as (A+C)+(B+D) because we can't find a root to start the reduce with without scanning the entire associative add expression. We could maybe explore using masked reduces for the root node, but that seems of questionable profitability. (As in, worth questioning - I haven't explored in any detail.)

This is covering up a deficiency in SLP. If SLP encounters the scalar form of reduce_or(A) + reduce_sum(a) where a is some common vectorizeable tree, SLP will sometimes fail to revisit one of the reductions after vectorizing the other. Fixing this in SLP is hard, and there's no good reason not to handle the easy cases in the backend.

Another option here would be to do this in VectorCombine or generic DAG. I chose not to as the profitability of the non-legal typed prefix cases is very target dependent. I think this makes sense as a starting point, even if we move it elsewhere later.

This is currently restructed only to add reduces, but obviously makes sense for any associative reduction operator. Once this is approved, I plan to extend it in this manner. I'm simply staging work in case we decide to go in another direction.


Patch is 47.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67821.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+82)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll (+159-952)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5e3975df1c4425d..d51d72edb64cdaa 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11122,6 +11122,85 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
 }
 
+/// Perform two related transforms whose purpose is to incrementally recognize
+/// an explode_vector followed by scalar reduction as a vector reduction node.
+/// This exists to recover from a deficiency in SLP which can't handle
+/// forrests with multiple roots sharing common nodes.  In some cases, one
+/// of the trees will be vectorized, and the other will remain (unprofitably)
+/// scalarized.
+static SDValue combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
+                                                 const RISCVSubtarget &Subtarget) {
+
+  // This transforms need to run before all integer types have been legalized
+  // to i64 (so that the vector element type matches the add type), and while
+  // it's safe to introduce odd sized vector types.
+  if (DAG.NewNodesMustHaveLegalTypes)
+    return SDValue();
+
+  const SDLoc DL(N);
+  const EVT VT = N->getValueType(0);
+  const unsigned Opc = N->getOpcode();
+  assert(Opc == ISD::ADD && "extend this to other reduction types");
+  const SDValue LHS = N->getOperand(0);
+  const SDValue RHS = N->getOperand(1);
+
+  if (!LHS.hasOneUse() || !RHS.hasOneUse())
+    return SDValue();
+
+  if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+      !isa<ConstantSDNode>(RHS.getOperand(1)))
+    return SDValue();
+
+  SDValue SrcVec = RHS.getOperand(0);
+  EVT SrcVecVT = SrcVec.getValueType();
+  assert(SrcVecVT.getVectorElementType() == VT);
+  if (SrcVecVT.isScalableVector())
+    return SDValue();
+
+  if (SrcVecVT.getScalarSizeInBits() > Subtarget.getELen())
+    return SDValue();
+
+  // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
+  // reduce_op (extract_subvector [2 x VT] from V).  This will form the
+  // root of our reduction tree. TODO: We could extend this to any two
+  // adjacent constant indices if desired.
+  if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+      LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
+      isOneConstant(RHS.getOperand(1))) {
+    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
+    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                              DAG.getVectorIdxConstant(0, DL));
+    return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
+  }
+
+  // Match (binop reduce (extract_subvector V, 0),
+  //                      extract_vector_elt V, sizeof(SubVec))
+  // into a reduction of one more element from the original vector V.
+  if (LHS.getOpcode() != ISD::VECREDUCE_ADD)
+    return SDValue();
+
+  SDValue ReduceVec = LHS.getOperand(0);
+  if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+      ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
+      isNullConstant(ReduceVec.getOperand(1)) &&
+      isa<ConstantSDNode>(RHS.getOperand(1))) {
+    uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getZExtValue();
+    if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
+      // For illegal types (e.g. 3xi32), most will be combined again into a
+      // wider (hopefully legal) type.  If this is a terminal state, we are
+      // relying on type legalization here to poduce something reasonable
+      // and this lowering quality could probably be improved. (TODO)
+      EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx+1);
+      SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                                DAG.getVectorIdxConstant(0, DL));
+      return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
+    }
+  }
+
+  return SDValue();
+}
+
+
 // Try to fold (<bop> x, (reduction.<bop> vec, start))
 static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG,
                                     const RISCVSubtarget &Subtarget) {
@@ -11449,6 +11528,9 @@ static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
     return V;
   if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
     return V;
+  if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
+    return V;
+
   // fold (add (select lhs, rhs, cc, 0, y), x) ->
   //      (select lhs, rhs, cc, x, (add x, y))
   return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index 5ef6b291e309517..173b70def03d4c5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -1,25 +1,15 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v,+m -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+m -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+m -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+v,+m -verify-machineinstrs < %s | FileCheck %s
 
 define i32 @reduce_sum_2xi32(<2 x i32> %v) {
-; RV32-LABEL: reduce_sum_2xi32:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
-; RV32-NEXT:    vmv.x.s a0, v8
-; RV32-NEXT:    vslidedown.vi v8, v8, 1
-; RV32-NEXT:    vmv.x.s a1, v8
-; RV32-NEXT:    add a0, a0, a1
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: reduce_sum_2xi32:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
-; RV64-NEXT:    vmv.x.s a0, v8
-; RV64-NEXT:    vslidedown.vi v8, v8, 1
-; RV64-NEXT:    vmv.x.s a1, v8
-; RV64-NEXT:    addw a0, a0, a1
-; RV64-NEXT:    ret
+; CHECK-LABEL: reduce_sum_2xi32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
   %e0 = extractelement <2 x i32> %v, i32 0
   %e1 = extractelement <2 x i32> %v, i32 1
   %add0 = add i32 %e0, %e1
@@ -27,35 +17,13 @@ define i32 @reduce_sum_2xi32(<2 x i32> %v) {
 }
 
 define i32 @reduce_sum_4xi32(<4 x i32> %v) {
-; RV32-LABEL: reduce_sum_4xi32:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV32-NEXT:    vmv.x.s a0, v8
-; RV32-NEXT:    vslidedown.vi v9, v8, 1
-; RV32-NEXT:    vmv.x.s a1, v9
-; RV32-NEXT:    vslidedown.vi v9, v8, 2
-; RV32-NEXT:    vmv.x.s a2, v9
-; RV32-NEXT:    vslidedown.vi v8, v8, 3
-; RV32-NEXT:    vmv.x.s a3, v8
-; RV32-NEXT:    add a0, a0, a1
-; RV32-NEXT:    add a2, a2, a3
-; RV32-NEXT:    add a0, a0, a2
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: reduce_sum_4xi32:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV64-NEXT:    vmv.x.s a0, v8
-; RV64-NEXT:    vslidedown.vi v9, v8, 1
-; RV64-NEXT:    vmv.x.s a1, v9
-; RV64-NEXT:    vslidedown.vi v9, v8, 2
-; RV64-NEXT:    vmv.x.s a2, v9
-; RV64-NEXT:    vslidedown.vi v8, v8, 3
-; RV64-NEXT:    vmv.x.s a3, v8
-; RV64-NEXT:    add a0, a0, a1
-; RV64-NEXT:    add a2, a2, a3
-; RV64-NEXT:    addw a0, a0, a2
-; RV64-NEXT:    ret
+; CHECK-LABEL: reduce_sum_4xi32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
   %e0 = extractelement <4 x i32> %v, i32 0
   %e1 = extractelement <4 x i32> %v, i32 1
   %e2 = extractelement <4 x i32> %v, i32 2
@@ -68,61 +36,13 @@ define i32 @reduce_sum_4xi32(<4 x i32> %v) {
 
 
 define i32 @reduce_sum_8xi32(<8 x i32> %v) {
-; RV32-LABEL: reduce_sum_8xi32:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV32-NEXT:    vmv.x.s a0, v8
-; RV32-NEXT:    vslidedown.vi v10, v8, 1
-; RV32-NEXT:    vmv.x.s a1, v10
-; RV32-NEXT:    vslidedown.vi v10, v8, 2
-; RV32-NEXT:    vmv.x.s a2, v10
-; RV32-NEXT:    vslidedown.vi v10, v8, 3
-; RV32-NEXT:    vmv.x.s a3, v10
-; RV32-NEXT:    vsetivli zero, 1, e32, m2, ta, ma
-; RV32-NEXT:    vslidedown.vi v10, v8, 4
-; RV32-NEXT:    vmv.x.s a4, v10
-; RV32-NEXT:    vslidedown.vi v10, v8, 5
-; RV32-NEXT:    vmv.x.s a5, v10
-; RV32-NEXT:    vslidedown.vi v10, v8, 6
-; RV32-NEXT:    vmv.x.s a6, v10
-; RV32-NEXT:    vslidedown.vi v8, v8, 7
-; RV32-NEXT:    vmv.x.s a7, v8
-; RV32-NEXT:    add a0, a0, a1
-; RV32-NEXT:    add a2, a2, a3
-; RV32-NEXT:    add a0, a0, a2
-; RV32-NEXT:    add a4, a4, a5
-; RV32-NEXT:    add a4, a4, a6
-; RV32-NEXT:    add a0, a0, a4
-; RV32-NEXT:    add a0, a0, a7
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: reduce_sum_8xi32:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV64-NEXT:    vmv.x.s a0, v8
-; RV64-NEXT:    vslidedown.vi v10, v8, 1
-; RV64-NEXT:    vmv.x.s a1, v10
-; RV64-NEXT:    vslidedown.vi v10, v8, 2
-; RV64-NEXT:    vmv.x.s a2, v10
-; RV64-NEXT:    vslidedown.vi v10, v8, 3
-; RV64-NEXT:    vmv.x.s a3, v10
-; RV64-NEXT:    vsetivli zero, 1, e32, m2, ta, ma
-; RV64-NEXT:    vslidedown.vi v10, v8, 4
-; RV64-NEXT:    vmv.x.s a4, v10
-; RV64-NEXT:    vslidedown.vi v10, v8, 5
-; RV64-NEXT:    vmv.x.s a5, v10
-; RV64-NEXT:    vslidedown.vi v10, v8, 6
-; RV64-NEXT:    vmv.x.s a6, v10
-; RV64-NEXT:    vslidedown.vi v8, v8, 7
-; RV64-NEXT:    vmv.x.s a7, v8
-; RV64-NEXT:    add a0, a0, a1
-; RV64-NEXT:    add a2, a2, a3
-; RV64-NEXT:    add a0, a0, a2
-; RV64-NEXT:    add a4, a4, a5
-; RV64-NEXT:    add a4, a4, a6
-; RV64-NEXT:    add a0, a0, a4
-; RV64-NEXT:    addw a0, a0, a7
-; RV64-NEXT:    ret
+; CHECK-LABEL: reduce_sum_8xi32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.s.x v10, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v10
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
   %e0 = extractelement <8 x i32> %v, i32 0
   %e1 = extractelement <8 x i32> %v, i32 1
   %e2 = extractelement <8 x i32> %v, i32 2
@@ -142,131 +62,13 @@ define i32 @reduce_sum_8xi32(<8 x i32> %v) {
 }
 
 define i32 @reduce_sum_16xi32(<16 x i32> %v) {
-; RV32-LABEL: reduce_sum_16xi32:
-; RV32:       # %bb.0:
-; RV32-NEXT:    addi sp, sp, -128
-; RV32-NEXT:    .cfi_def_cfa_offset 128
-; RV32-NEXT:    sw ra, 124(sp) # 4-byte Folded Spill
-; RV32-NEXT:    sw s0, 120(sp) # 4-byte Folded Spill
-; RV32-NEXT:    sw s2, 116(sp) # 4-byte Folded Spill
-; RV32-NEXT:    .cfi_offset ra, -4
-; RV32-NEXT:    .cfi_offset s0, -8
-; RV32-NEXT:    .cfi_offset s2, -12
-; RV32-NEXT:    addi s0, sp, 128
-; RV32-NEXT:    .cfi_def_cfa s0, 0
-; RV32-NEXT:    andi sp, sp, -64
-; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV32-NEXT:    vmv.x.s a0, v8
-; RV32-NEXT:    vslidedown.vi v12, v8, 1
-; RV32-NEXT:    vmv.x.s a1, v12
-; RV32-NEXT:    vslidedown.vi v12, v8, 2
-; RV32-NEXT:    vmv.x.s a2, v12
-; RV32-NEXT:    vslidedown.vi v12, v8, 3
-; RV32-NEXT:    vmv.x.s a3, v12
-; RV32-NEXT:    vsetivli zero, 1, e32, m2, ta, ma
-; RV32-NEXT:    vslidedown.vi v12, v8, 4
-; RV32-NEXT:    vmv.x.s a4, v12
-; RV32-NEXT:    vslidedown.vi v12, v8, 5
-; RV32-NEXT:    vmv.x.s a5, v12
-; RV32-NEXT:    vslidedown.vi v12, v8, 6
-; RV32-NEXT:    vmv.x.s a6, v12
-; RV32-NEXT:    vslidedown.vi v12, v8, 7
-; RV32-NEXT:    vmv.x.s a7, v12
-; RV32-NEXT:    mv t0, sp
-; RV32-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; RV32-NEXT:    vse32.v v8, (t0)
-; RV32-NEXT:    lw t0, 32(sp)
-; RV32-NEXT:    lw t1, 36(sp)
-; RV32-NEXT:    lw t2, 40(sp)
-; RV32-NEXT:    lw t3, 44(sp)
-; RV32-NEXT:    lw t4, 48(sp)
-; RV32-NEXT:    lw t5, 52(sp)
-; RV32-NEXT:    lw t6, 56(sp)
-; RV32-NEXT:    lw s2, 60(sp)
-; RV32-NEXT:    add a0, a0, a1
-; RV32-NEXT:    add a2, a2, a3
-; RV32-NEXT:    add a0, a0, a2
-; RV32-NEXT:    add a4, a4, a5
-; RV32-NEXT:    add a4, a4, a6
-; RV32-NEXT:    add a0, a0, a4
-; RV32-NEXT:    add a7, a7, t0
-; RV32-NEXT:    add a0, a0, a7
-; RV32-NEXT:    add t1, t1, t2
-; RV32-NEXT:    add t1, t1, t3
-; RV32-NEXT:    add a0, a0, t1
-; RV32-NEXT:    add t4, t4, t5
-; RV32-NEXT:    add t4, t4, t6
-; RV32-NEXT:    add t4, t4, s2
-; RV32-NEXT:    add a0, a0, t4
-; RV32-NEXT:    addi sp, s0, -128
-; RV32-NEXT:    lw ra, 124(sp) # 4-byte Folded Reload
-; RV32-NEXT:    lw s0, 120(sp) # 4-byte Folded Reload
-; RV32-NEXT:    lw s2, 116(sp) # 4-byte Folded Reload
-; RV32-NEXT:    addi sp, sp, 128
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: reduce_sum_16xi32:
-; RV64:       # %bb.0:
-; RV64-NEXT:    addi sp, sp, -128
-; RV64-NEXT:    .cfi_def_cfa_offset 128
-; RV64-NEXT:    sd ra, 120(sp) # 8-byte Folded Spill
-; RV64-NEXT:    sd s0, 112(sp) # 8-byte Folded Spill
-; RV64-NEXT:    sd s2, 104(sp) # 8-byte Folded Spill
-; RV64-NEXT:    .cfi_offset ra, -8
-; RV64-NEXT:    .cfi_offset s0, -16
-; RV64-NEXT:    .cfi_offset s2, -24
-; RV64-NEXT:    addi s0, sp, 128
-; RV64-NEXT:    .cfi_def_cfa s0, 0
-; RV64-NEXT:    andi sp, sp, -64
-; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV64-NEXT:    vmv.x.s a0, v8
-; RV64-NEXT:    vslidedown.vi v12, v8, 1
-; RV64-NEXT:    vmv.x.s a1, v12
-; RV64-NEXT:    vslidedown.vi v12, v8, 2
-; RV64-NEXT:    vmv.x.s a2, v12
-; RV64-NEXT:    vslidedown.vi v12, v8, 3
-; RV64-NEXT:    vmv.x.s a3, v12
-; RV64-NEXT:    vsetivli zero, 1, e32, m2, ta, ma
-; RV64-NEXT:    vslidedown.vi v12, v8, 4
-; RV64-NEXT:    vmv.x.s a4, v12
-; RV64-NEXT:    vslidedown.vi v12, v8, 5
-; RV64-NEXT:    vmv.x.s a5, v12
-; RV64-NEXT:    vslidedown.vi v12, v8, 6
-; RV64-NEXT:    vmv.x.s a6, v12
-; RV64-NEXT:    vslidedown.vi v12, v8, 7
-; RV64-NEXT:    vmv.x.s a7, v12
-; RV64-NEXT:    mv t0, sp
-; RV64-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; RV64-NEXT:    vse32.v v8, (t0)
-; RV64-NEXT:    lw t0, 32(sp)
-; RV64-NEXT:    lw t1, 36(sp)
-; RV64-NEXT:    lw t2, 40(sp)
-; RV64-NEXT:    lw t3, 44(sp)
-; RV64-NEXT:    lw t4, 48(sp)
-; RV64-NEXT:    lw t5, 52(sp)
-; RV64-NEXT:    lw t6, 56(sp)
-; RV64-NEXT:    lw s2, 60(sp)
-; RV64-NEXT:    add a0, a0, a1
-; RV64-NEXT:    add a2, a2, a3
-; RV64-NEXT:    add a0, a0, a2
-; RV64-NEXT:    add a4, a4, a5
-; RV64-NEXT:    add a4, a4, a6
-; RV64-NEXT:    add a0, a0, a4
-; RV64-NEXT:    add a7, a7, t0
-; RV64-NEXT:    add a0, a0, a7
-; RV64-NEXT:    add t1, t1, t2
-; RV64-NEXT:    add t1, t1, t3
-; RV64-NEXT:    add a0, a0, t1
-; RV64-NEXT:    add t4, t4, t5
-; RV64-NEXT:    add t4, t4, t6
-; RV64-NEXT:    add t4, t4, s2
-; RV64-NEXT:    addw a0, a0, t4
-; RV64-NEXT:    addi sp, s0, -128
-; RV64-NEXT:    ld ra, 120(sp) # 8-byte Folded Reload
-; RV64-NEXT:    ld s0, 112(sp) # 8-byte Folded Reload
-; RV64-NEXT:    ld s2, 104(sp) # 8-byte Folded Reload
-; RV64-NEXT:    addi sp, sp, 128
-; RV64-NEXT:    ret
+; CHECK-LABEL: reduce_sum_16xi32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; CHECK-NEXT:    vmv.s.x v12, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v12
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
   %e0 = extractelement <16 x i32> %v, i32 0
   %e1 = extractelement <16 x i32> %v, i32 1
   %e2 = extractelement <16 x i32> %v, i32 2
@@ -302,27 +104,14 @@ define i32 @reduce_sum_16xi32(<16 x i32> %v) {
 }
 
 define i32 @reduce_sum_16xi32_prefix2(ptr %p) {
-; RV32-LABEL: reduce_sum_16xi32_prefix2:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; RV32-NEXT:    vle32.v v8, (a0)
-; RV32-NEXT:    vmv.x.s a0, v8
-; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV32-NEXT:    vslidedown.vi v8, v8, 1
-; RV32-NEXT:    vmv.x.s a1, v8
-; RV32-NEXT:    add a0, a0, a1
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: reduce_sum_16xi32_prefix2:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; RV64-NEXT:    vle32.v v8, (a0)
-; RV64-NEXT:    vmv.x.s a0, v8
-; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV64-NEXT:    vslidedown.vi v8, v8, 1
-; RV64-NEXT:    vmv.x.s a1, v8
-; RV64-NEXT:    addw a0, a0, a1
-; RV64-NEXT:    ret
+; CHECK-LABEL: reduce_sum_16xi32_prefix2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
   %v = load <16 x i32>, ptr %p, align 256
   %e0 = extractelement <16 x i32> %v, i32 0
   %e1 = extractelement <16 x i32> %v, i32 1
@@ -331,33 +120,15 @@ define i32 @reduce_sum_16xi32_prefix2(ptr %p) {
 }
 
 define i32 @reduce_sum_16xi32_prefix3(ptr %p) {
-; RV32-LABEL: reduce_sum_16xi32_prefix3:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; RV32-NEXT:    vle32.v v8, (a0)
-; RV32-NEXT:    vmv.x.s a0, v8
-; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV32-NEXT:    vslidedown.vi v9, v8, 1
-; RV32-NEXT:    vmv.x.s a1, v9
-; RV32-NEXT:    vslidedown.vi v8, v8, 2
-; RV32-NEXT:    vmv.x.s a2, v8
-; RV32-NEXT:    add a0, a0, a1
-; RV32-NEXT:    add a0, a0, a2
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: reduce_sum_16xi32_prefix3:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; RV64-NEXT:    vle32.v v8, (a0)
-; RV64-NEXT:    vmv.x.s a0, v8
-; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV64-NEXT:    vslidedown.vi v9, v8, 1
-; RV64-NEXT:    vmv.x.s a1, v9
-; RV64-NEXT:    vslidedown.vi v8, v8, 2
-; RV64-NEXT:    vmv.x.s a2, v8
-; RV64-NEXT:    add a0, a0, a1
-; RV64-NEXT:    addw a0, a0, a2
-; RV64-NEXT:    ret
+; CHECK-LABEL: reduce_sum_16xi32_prefix3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vslideup.vi v8, v9, 3
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
   %v = load <16 x i32>, ptr %p, align 256
   %e0 = extractelement <16 x i32> %v, i32 0
   %e1 = extractelement <16 x i32> %v, i32 1
@@ -368,39 +139,14 @@ define i32 @reduce_sum_16xi32_prefix3(ptr %p) {
 }
 
 define i32 @reduce_sum_16xi32_prefix4(ptr %p) {
-; RV32-LABEL: reduce_sum_16xi32_prefix4:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; RV32-NEXT:    vle32.v v8, (a0)
-; RV32-NEXT:    vmv.x.s a0, v8
-; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV32-NEXT:    vslidedown.vi v9, v8, 1
-; RV32-NEXT:    vmv.x.s a1, v9
-; RV32-NEXT:    vslidedown.vi v9, v8, 2
-; RV32-NEXT:    vmv.x.s a2, v9
-; RV32-NEXT:    vslidedown.vi v8, v8, 3
-; RV32-NEXT:    vmv.x.s a3, v8
-; RV32-NEXT:    add a0, a0, a1
-; RV32-NEXT:    add a2, a2, a3
-; RV32-NEXT:    add a0, a0, a2
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: reduce_sum_16xi32_prefix4:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; RV64-NEXT:    vle32.v v8, (a0)
-; RV64-NEXT:    vmv.x.s a0, v8
-; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV64-NEXT:    vslidedown.vi v9, v8, 1
-; RV64-NEXT:    vmv.x.s a1, v9
-; RV64-NEXT:    vslidedown.vi v9, v8, 2
-; RV64-NEXT:    vmv.x.s a2, v9
-; RV64-NEXT:    vslidedown.vi v8, v8, 3
-; RV64-NEXT:    vmv.x.s a3, v8
-; RV64-NEXT:    add a0, a0, a1
-; RV64-NEXT:    add a2, a2, a3
-; RV64-NEXT:    addw a0, a0, a2
-; RV64-NEXT:    ret
+; CHECK-LABEL: reduce_sum_16xi32_prefix4:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
   %v = load <16 x i32>, ptr %p, align 256
   %e0 = extractelement <16 x i32> %v, i32 0
   %e1 = extractelement <16 x i32> %v, i32 1
@@ -413,47 +159,21 @@ define i32 @reduce_sum_16xi32_prefix4(ptr %p) {
 }
 
 define i32 @reduce_sum_16xi32_prefix5(ptr %p) {
-; RV32-LABEL: reduce_sum_16xi32_prefix5:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; RV32-NEXT:    vle32.v v8, (a0)
-; RV32-NEXT:    vmv.x.s a0, v8
-; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; RV32-NEXT:    vslidedown.vi v10, v8, 1
-; RV32-NEXT:    vmv.x.s a1, v10
-; RV32-NEXT:    vslidedown.vi v10, v8, 2
-; RV32-NEXT:    vmv.x.s a2, v10
-; RV32-NEXT:    vslidedown.vi v10, v8, 3
-; RV32-NEXT:    vmv.x.s a3, v10
-; RV32-NEXT:    vsetivli zero, 1, e32, m2, ta, ma
-; RV32-NEXT:    vslidedown.vi v8, v8, 4
-; RV32-NEXT:    vmv.x.s a4, v8
-; RV32-NEXT:    add a0, a0, a1
-; RV32-NEXT:    add a2, a2, a3
-; RV32-NEXT:    ad...
[truncated]

; RV64-NEXT: add a0, a0, a2
; RV64-NEXT: addw a0, a0, a4
; RV64-NEXT: ret
; CHECK-LABEL: reduce_sum_16xi32_prefix5:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to flag - We could definitely do better on the lowering for illegally typed prefix vectors. I think this makes sense to land as is because the current result is better than the scalar tree. We could explore using either a select with splat(zero) here, a masked reduce, or a vl toggle for the reduce.

; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: vmv.s.x v9, zero
; CHECK-NEXT: vslideup.vi v8, v9, 3
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting an opportunity here. Iif we instead slide up the source vector by one, we could use a vslide1up here with a zero register source, and use the same register as our scalar source on the vredsum. Note sure if this is worth bothering to actually implement or not.

Alternatively, we could use the same lowerings as mentioned below for the prefix5 case.

@github-actions
Copy link

github-actions bot commented Sep 29, 2023

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff c718336c4cb14e27aa88054cc30ee0c2fc6505d1 38a60f057354dc3ab0385114335caf4ba70b9f32 -- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index fb7ed79c7888..9d6b3b7a77ac 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11200,7 +11200,6 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
-
 // Try to fold (<bop> x, (reduction.<bop> vec, start))
 static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG,
                                     const RISCVSubtarget &Subtarget) {

/// Perform two related transforms whose purpose is to incrementally recognize
/// an explode_vector followed by scalar reduction as a vector reduction node.
/// This exists to recover from a deficiency in SLP which can't handle
/// forrests with multiple roots sharing common nodes. In some cases, one
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forrests -> forests


const SDLoc DL(N);
const EVT VT = N->getValueType(0);
const unsigned Opc = N->getOpcode();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused var in release builds I think?

if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
isNullConstant(ReduceVec.getOperand(1)) &&
isa<ConstantSDNode>(RHS.getOperand(1))) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RHS.getOperand(1) is known to be a ConstantSDNode from line 11151 right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this was addressed. I think this isa<ConstantSDNode>(RHS.getOperand(1)) can be removed since it was already checked earlier.

return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
}

// Match (binop reduce (extract_subvector V, 0),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put parentheses around reduce (extract_subvector V, 0)

ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
isNullConstant(ReduceVec.getOperand(1)) &&
isa<ConstantSDNode>(RHS.getOperand(1))) {
uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getZExtValue();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure you can guarantee Idx fits in 64 bits here. It would definitely be UB if it didn't fit, but the getZExtValue() would assert in that case.

if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
// For illegal types (e.g. 3xi32), most will be combined again into a
// wider (hopefully legal) type. If this is a terminal state, we are
// relying on type legalization here to poduce something reasonable
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poduce -> produce

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@preames preames merged commit f0505c3 into llvm:main Oct 2, 2023
@preames preames deleted the pr-riscv-reduction-tree-via-dag branch October 2, 2023 00:42
preames added a commit to preames/llvm-project that referenced this pull request Oct 2, 2023
This builds on the transform introduced in llvm#67821, and generalizes it for all integer reduction types.

A couple of notes:
* This will only form smax/smin/umax/umin reductions when zbb is enabled.  Otherwise, we lower the min/max expressions early.  I don't care about this case, and don't plan to address this further.
* This excludes floating point.  Floating point introduces concerns about associativity.  I may or may not do a follow up patch for that case.
* The explodevector test change is mildly undesirable from a clarity perspective.  If anyone sees a good way to rewrite that to stablize the test, please suggest.
preames added a commit that referenced this pull request Oct 3, 2023
…68014)

This builds on the transform introduced in
#67821, and generalizes it for
all integer reduction types.

A couple of notes:
* This will only form smax/smin/umax/umin reductions when zbb is
enabled. Otherwise, we lower the min/max expressions early. I don't care
about this case, and don't plan to address this further.
* This excludes floating point. Floating point introduces concerns about
associativity. I may or may not do a follow up patch for that case.
* The explodevector test change is mildly undesirable from a clarity
perspective. If anyone sees a good way to rewrite that to stablize the
test, please suggest.
preames added a commit that referenced this pull request Oct 4, 2023
…68014) (reapply)

This was reverted in 824251c exposed by this change in a previous patch.  Fixed in 199cbec.  Original commit message follows.

This builds on the transform introduced in
#67821, and generalizes it for
all integer reduction types.

A couple of notes:
* This will only form smax/smin/umax/umin reductions when zbb is
enabled. Otherwise, we lower the min/max expressions early. I don't care
about this case, and don't plan to address this further.
* This excludes floating point. Floating point introduces concerns about
associativity. I may or may not do a follow up patch for that case.
* The explodevector test change is mildly undesirable from a clarity
perspective. If anyone sees a good way to rewrite that to stablize the
test, please suggest.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants