Skip to content

Commit f0505c3

Browse files
authored
[RISCV] Form vredsum from explode_vector + scalar (left) reduce (#67821)
This change adds two related DAG combines which together will take a left-reduce scalar add tree of an explode_vector, and will incrementally form a vector reduction of the vector prefix. If the entire vector is reduced, the result will be a reduction over the entire vector. Profitability wise, this relies on vredsum being cheaper than a pair of extracts and scalar add. Given vredsum is linear in LMUL, and the vslidedown required for the extract is *also* linear in LMUL, this is clearly true at higher index values. At N=2, it's a bit questionable, but I think the vredsum form is probably a better canonical form anyways. Note that this only matches left reduces. This happens to be the motivating example I have (from spec2017 x264). This approach could be generalized to handle right reduces without much effort, and could be generalized to handle any reduce whose tree starts with adjacent elements if desired. The approach fails for a reduce such as (A+C)+(B+D) because we can't find a root to start the reduce with without scanning the entire associative add expression. We could maybe explore using masked reduces for the root node, but that seems of questionable profitability. (As in, worth questioning - I haven't explored in any detail.) This is covering up a deficiency in SLP. If SLP encounters the scalar form of reduce_or(A) + reduce_sum(a) where a is some common vectorizeable tree, SLP will sometimes fail to revisit one of the reductions after vectorizing the other. Fixing this in SLP is hard, and there's no good reason not to handle the easy cases in the backend. Another option here would be to do this in VectorCombine or generic DAG. I chose not to as the profitability of the non-legal typed prefix cases is very target dependent. I think this makes sense as a starting point, even if we move it elsewhere later. This is currently restructed only to add reduces, but obviously makes sense for any associative reduction operator. Once this is approved, I plan to extend it in this manner. I'm simply staging work in case we decide to go in another direction.
1 parent 2375d84 commit f0505c3

File tree

3 files changed

+506
-1228
lines changed

3 files changed

+506
-1228
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

+82
Original file line numberDiff line numberDiff line change
@@ -11122,6 +11122,85 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
1112211122
}
1112311123
}
1112411124

11125+
/// Perform two related transforms whose purpose is to incrementally recognize
11126+
/// an explode_vector followed by scalar reduction as a vector reduction node.
11127+
/// This exists to recover from a deficiency in SLP which can't handle
11128+
/// forests with multiple roots sharing common nodes. In some cases, one
11129+
/// of the trees will be vectorized, and the other will remain (unprofitably)
11130+
/// scalarized.
11131+
static SDValue
11132+
combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
11133+
const RISCVSubtarget &Subtarget) {
11134+
11135+
// This transforms need to run before all integer types have been legalized
11136+
// to i64 (so that the vector element type matches the add type), and while
11137+
// it's safe to introduce odd sized vector types.
11138+
if (DAG.NewNodesMustHaveLegalTypes)
11139+
return SDValue();
11140+
11141+
const SDLoc DL(N);
11142+
const EVT VT = N->getValueType(0);
11143+
[[maybe_unused]] const unsigned Opc = N->getOpcode();
11144+
assert(Opc == ISD::ADD && "extend this to other reduction types");
11145+
const SDValue LHS = N->getOperand(0);
11146+
const SDValue RHS = N->getOperand(1);
11147+
11148+
if (!LHS.hasOneUse() || !RHS.hasOneUse())
11149+
return SDValue();
11150+
11151+
if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
11152+
!isa<ConstantSDNode>(RHS.getOperand(1)))
11153+
return SDValue();
11154+
11155+
SDValue SrcVec = RHS.getOperand(0);
11156+
EVT SrcVecVT = SrcVec.getValueType();
11157+
assert(SrcVecVT.getVectorElementType() == VT);
11158+
if (SrcVecVT.isScalableVector())
11159+
return SDValue();
11160+
11161+
if (SrcVecVT.getScalarSizeInBits() > Subtarget.getELen())
11162+
return SDValue();
11163+
11164+
// match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
11165+
// reduce_op (extract_subvector [2 x VT] from V). This will form the
11166+
// root of our reduction tree. TODO: We could extend this to any two
11167+
// adjacent constant indices if desired.
11168+
if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
11169+
LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
11170+
isOneConstant(RHS.getOperand(1))) {
11171+
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
11172+
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11173+
DAG.getVectorIdxConstant(0, DL));
11174+
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
11175+
}
11176+
11177+
// Match (binop (reduce (extract_subvector V, 0),
11178+
// (extract_vector_elt V, sizeof(SubVec))))
11179+
// into a reduction of one more element from the original vector V.
11180+
if (LHS.getOpcode() != ISD::VECREDUCE_ADD)
11181+
return SDValue();
11182+
11183+
SDValue ReduceVec = LHS.getOperand(0);
11184+
if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
11185+
ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
11186+
isNullConstant(ReduceVec.getOperand(1))) {
11187+
uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
11188+
if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
11189+
// For illegal types (e.g. 3xi32), most will be combined again into a
11190+
// wider (hopefully legal) type. If this is a terminal state, we are
11191+
// relying on type legalization here to produce something reasonable
11192+
// and this lowering quality could probably be improved. (TODO)
11193+
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
11194+
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11195+
DAG.getVectorIdxConstant(0, DL));
11196+
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
11197+
}
11198+
}
11199+
11200+
return SDValue();
11201+
}
11202+
11203+
1112511204
// Try to fold (<bop> x, (reduction.<bop> vec, start))
1112611205
static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG,
1112711206
const RISCVSubtarget &Subtarget) {
@@ -11449,6 +11528,9 @@ static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
1144911528
return V;
1145011529
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
1145111530
return V;
11531+
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
11532+
return V;
11533+
1145211534
// fold (add (select lhs, rhs, cc, 0, y), x) ->
1145311535
// (select lhs, rhs, cc, x, (add x, y))
1145411536
return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);

0 commit comments

Comments
 (0)