[mlir][Vector] Move vector.extract canonicalizers for DenseElementsAttr to folders #127995
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes: This PR moves the vector.extract canonicalizers for DenseElementsAttr (the splat and non-splat cases) to folders. Folders are local, and it's always better to implement a folder than a canonicalizer. This PR is marked NFC because the functionality mostly remains the same but now runs as part of a folder, which is why some tests change: GreedyPatternRewriter tries to fold by default.

Full diff: https://github.com/llvm/llvm-project/pull/127995.diff

5 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..96ac7fe2fa9e2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
return {};
}
+static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
+ Attribute srcAttr) {
+ auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
+ if (!denseAttr) {
+ return {};
+ }
+
+ if (denseAttr.isSplat()) {
+ Attribute newAttr = denseAttr.getSplatValue<Attribute>();
+ if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
+ newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+ return newAttr;
+ }
+
+ auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType());
+ if (vecTy.isScalable())
+ return {};
+
+ if (extractOp.hasDynamicPosition()) {
+ return {};
+ }
+
+ // Calculate the linearized position of the continuous chunk of elements to
+ // extract.
+ llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+ copy(extractOp.getStaticPosition(), completePositions.begin());
+ int64_t elemBeginPosition =
+ linearize(completePositions, computeStrides(vecTy.getShape()));
+ auto denseValuesBegin =
+ denseAttr.value_begin<TypedAttr>() + elemBeginPosition;
+
+ TypedAttr newAttr;
+ if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
+ SmallVector<Attribute> elementValues(
+ denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
+ newAttr = DenseElementsAttr::get(resVecTy, elementValues);
+ } else {
+ newAttr = *denseValuesBegin;
+ }
+
+ return newAttr;
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2101,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return res;
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
return res;
+ if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+ return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2121,80 +2166,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
}
};
-// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- // Return if 'ExtractOp' operand is not defined by a splat vector
- // ConstantOp.
- Value sourceVector = extractOp.getVector();
- Attribute vectorCst;
- if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
- return failure();
- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
- if (!splat)
- return failure();
- TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
- if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
- newAttr = DenseElementsAttr::get(vecDstType, newAttr);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
- return success();
- }
-};
-
-// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
-class ExtractOpNonSplatConstantFolder final
- : public OpRewritePattern<ExtractOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- // TODO: Canonicalization for dynamic position not implemented yet.
- if (extractOp.hasDynamicPosition())
- return failure();
-
- // Return if 'ExtractOp' operand is not defined by a compatible vector
- // ConstantOp.
- Value sourceVector = extractOp.getVector();
- Attribute vectorCst;
- if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
- return failure();
-
- auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
- if (vecTy.isScalable())
- return failure();
-
- // The splat case is handled by `ExtractOpSplatConstantFolder`.
- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
- if (!dense || dense.isSplat())
- return failure();
-
- // Calculate the linearized position of the continuous chunk of elements to
- // extract.
- llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
- copy(extractOp.getStaticPosition(), completePositions.begin());
- int64_t elemBeginPosition =
- linearize(completePositions, computeStrides(vecTy.getShape()));
- auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
-
- TypedAttr newAttr;
- if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
- SmallVector<Attribute> elementValues(
- denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
- newAttr = DenseElementsAttr::get(resVecTy, elementValues);
- } else {
- newAttr = *denseValuesBegin;
- }
-
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
- return success();
- }
-};
-
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
@@ -2332,8 +2303,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e66fbe968d9b0..cd83e1239fdda 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
-
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
// -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
-// CHECK-DAG: %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_9]] : tensor<1x4xf32>
// CHECK: }
@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK-DAG: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
-// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..8e5ddbfffcdd9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
// -----
// ALL-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar() {
+func.func @test_vector_extract_scalar(%idx : index) {
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
// ALL-NOT: vector.shuffle
// ALL: vector.extract
// ALL-NOT: vector.shuffle
- %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+ %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
return
}
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index c5cb09b9aa9f9..b4ebb14b8829e 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
// CHECK-LABEL: func @transfer_write_arith_constant(
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
-// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
-// CHECK: %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
-// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[cst:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK: memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
%cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 20e9400ed698d..5be267c1be984 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -242,33 +242,29 @@ func.func @strided_gather(%base : memref<100x3xf32>,
// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
// CHECK-SAME: %[[VAL_4:.*]]: index,
// CHECK-SAME: %[[VAL_5:.*]]: index) -> vector<4xf32> {
+// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
-// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
-// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
-// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
// CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
-// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
Folders are local, and it's always better to implement a folder than a canonicalizer.
In the past we tried to be cautious of IR size and bail out of canonicalizers if they produced large constants. I wonder if moving to folders raises the bar even higher.
I don't think folders have a higher bar than canonicalizers. A folder is just a canonicalizer that is local. I think the point you are raising is valid, but it's not related to this patch, which just implements the same canonicalization in a more restricted fashion (local transformations only). For InsertOp folders, we do have a fold limit; for ExtractOp, I don't think we need that limit. I've explained my reasoning in a comment.
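As a rough illustration of such a fold limit, here is a minimal sketch of a size-guarded dense-constant fold. The helper name, threshold name, and threshold value are assumptions for illustration; the actual InsertOp limit in VectorOps.cpp may differ.

#include <cstdint>
#include "mlir/IR/BuiltinAttributes.h"

// Sketch only: bail out of the fold instead of materializing a large
// non-splat constant. kFoldSizeThreshold is an assumed value.
static mlir::Attribute foldDenseSrcWithLimit(mlir::Attribute srcAttr) {
  constexpr int64_t kFoldSizeThreshold = 256; // assumed cap
  auto dense = llvm::dyn_cast_if_present<mlir::DenseElementsAttr>(srcAttr);
  if (!dense)
    return {};
  // Splats stay cheap at any element count; only give up when a large
  // non-splat constant would have to be created.
  if (!dense.isSplat() && dense.getNumElements() > kFoldSizeThreshold)
    return {};
  return {}; // The actual folding logic is elided in this sketch.
}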
This is re-wiring things in a non-trivial way. I suggest dropping NFC. Will take a proper look later. Thanks!
Thanks!
Sorry, I missed some comments. Removing the approval to give more time for discussion.
The code changes LGTM, thanks!
That said, I don’t quite follow this from the summary:
Folders are local, and it's always better to implement a folder than a canonicalizer.
Why? This is a rather strong statement, and I don’t see any clear indication of this in our documentation: MLIR Canonicalization Docs
I have no objections to landing this, but let’s ensure we have a valid reason for this preference. Otherwise, it feels somewhat arbitrary.
I don't think it's a very strong statement, but a natural interpretation of the docs. I'm happy to update the docs, but let me explain my interpretation of the docs on canonicalization.

Based on the docs, folders are a restricted way of implementing canonicalizations, where the transformation is "local" (as defined in the docs). The other way is to implement a RewritePattern, which can change the IR in any way a pattern can, as long as it converges. Folders have an additional property: rewrite drivers (dialect conversion, the greedy rewrite driver) can use them to canonicalize on the fly, because folding cannot create new operations. Since there are no extra operations for the driver to process, folding adds only a constant cost, instead of the driver having to run on newly created ops. But in the end, both are canonicalizations (as defined by the docs); a folder or a RewritePattern is just an implementation detail. Whether a transformation should be a canonicalization at all is the harder question that needs justification; whether it should be a folder or a rewrite pattern is the easier question, an implementation detail. The choice between implementing a canonicalization as a folder or as a rewrite pattern comes down to whether the canonicalization is local or not. The canonicalization in this PR is local, which is why it's better to implement it as a folder.

That said, I think my original statement should be changed to: "Folders are local, and it's always better to implement a folder than a canonicalizer pattern"
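To make the folder/pattern distinction concrete, here is a minimal sketch of both forms for a hypothetical MyOp; the op, its accessors, and the isIdentity helper are illustrative assumptions, not code from this patch.

#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Folder: local by construction. It may only return one of the op's
// existing values or an Attribute (which the driver materializes as a
// constant); it can never create or erase other operations.
OpFoldResult MyOp::fold(FoldAdaptor adaptor) {
  // Hypothetical rule: my_op(x, identity) -> x.
  if (isIdentity(adaptor.getRhs()))
    return getLhs();
  return {}; // An empty result means "no fold".
}

// Canonicalization pattern: free to restructure the IR, including creating
// new operations, as long as the overall set of rewrites converges.
struct SimplifyMyOp : public OpRewritePattern<MyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MyOp op,
                                PatternRewriter &rewriter) const override {
    // A pattern may call rewriter.create<...>() or replaceOp() here, which
    // a folder cannot; the greedy driver must then revisit the new ops.
    return failure();
  }
};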
LGTM
@@ -2033,20 +2033,71 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
                                                 ArrayRef<int64_t> staticPos,
                                                 int64_t poisonVal) {
-  if (!llvm::is_contained(staticPos, poisonVal))
+  if (!is_contained(staticPos, poisonVal))
nit: I'd undo this so that we don't depend on ADL.
Unlike the cast functions and container types like SmallVector/StringRef/ArrayRef/etc., we don't have a using declaration in the mlir namespace for functions from STLExtras.h.
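To illustrate the ADL concern with a self-contained sketch (the function name is made up for this example): the unqualified call only resolves because ArrayRef is declared in namespace llvm, so name lookup silently depends on the argument types.

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

static bool hasPoison(llvm::ArrayRef<int64_t> staticPos, int64_t poisonVal) {
  // Unqualified: found through argument-dependent lookup because ArrayRef
  // lives in namespace llvm. If staticPos ever became a std:: container,
  // this line would stop compiling.
  bool viaAdl = is_contained(staticPos, poisonVal);
  // Qualified: resolves identically regardless of the argument types.
  bool qualified = llvm::is_contained(staticPos, poisonVal);
  return viaAdl && qualified;
}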
Thanks for the reply!
IIUC, your interpretation of the docs is something like this (please correct me if I’m misinterpreting your comment):
I don’t see any such indication in the docs. This is also why I requested that you remove "NFC" from the description - thanks for doing that!
I think that would be really helpful! It would save us this discussion and also provide clear guidance for our future selves and other contributors. I’m happy for this to be merged (thanks for working on it and for addressing my comments!), but I’d appreciate it if you could follow up with a docs update - assuming I correctly interpreted your interpretation. 😅 Thanks, Kunwar!
I'm not sure I agree with this generalization either. For example, we'd expect folders to be invoked much more frequently in the compilation pipeline (often without an opt-out, as in dialect conversion), so I'd be much more wary of spending compilation time on folds that are either costly or very unlikely to apply. Downstream compilers can decide how often to run canonicalization patterns based on their compilation-time budget.
Costly transformations shouldn't be canonicalizations at all, per the docs.

Regarding "very unlikely to apply": I'm not sure it's easy to say whether the cost of folding is significant without benchmarking a transformation. If the cost of folding is in fact significant even after keeping expensive transformations out of folders, then folding should just be disabled in the rewrite driver and added selectively as patterns (i.e., calling the folder as a pattern). I've never heard of folding being a significant cause of compile-time regressions.
IIRC @jpienaar reported folding as a significant contributor to total compilation time here: #104649 (comment). One solution is to run it less frequently, the other is to make it cheaper, and we should probably aim for both.
This is how we approached it in the past, indeed.
Yes, your interpretation of what I was trying to say is correct. I will send a PR to update the documentation; let's continue the discussion about whether there is consensus on always opting for a folder in that PR. As for this PR, I'm going to land it unless someone has strong objections (which I'm assuming there aren't, since the objections seem to be about the general statement, which we can discuss in the docs PR). This PR is one of the final pieces needed to finally remove vector.extractelement/vector.insertelement from the Vector dialect :D
No objections from me. I only said I was not 100% convinced by that last statement; I think there's room for a little more nuance.
This PR moves the vector.extract canonicalizers for DenseElementsAttr (the splat and non-splat cases) to folders. Folders are local, and it's always better to implement a folder than a canonicalization pattern.

This PR is mostly NFC-ish: the functionality largely remains the same but now runs as part of a folder, which is why some tests change; GreedyPatternRewriter tries to fold by default.

There is also a test change that makes the indices of a vector.extract test dynamic, so that the extract doesn't fold away after this PR.