
[mlir][Vector] Move vector.extract canonicalizers for DenseElementsAttr to folders #127995


Merged · 2 commits · Feb 26, 2025

Conversation

@Groverkss (Member) commented Feb 20, 2025

This PR moves the vector.extract canonicalization patterns for DenseElementsAttr (the splat and non-splat cases) to folders. Folders are local, and it's always better to implement a folder than a canonicalization pattern.

This PR is mostly NFC-ish: the functionality largely remains the same, but it now runs as part of a folder. Some tests change as a result, because GreedyPatternRewriter tries to fold by default.

There is also a test change that makes the indices of a vector.extract test dynamic, so that it doesn't fold away after this PR.
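
For illustration, the kind of rewrite being moved (a sketch assembled from the tests touched below; the exact constants are illustrative):

```mlir
// Splat case: extracting from a splat constant folds to the splat value.
%splat = arith.constant dense<5.000000e+00> : vector<1x1xf32>
%a = vector.extract %splat[0, 0] : f32 from vector<1x1xf32>
// folds to: %a = arith.constant 5.000000e+00 : f32

// Non-splat case: a static position selects the linearized element.
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
%b = vector.extract %cst[2] : i32 from vector<4xi32>
// folds to: %b = arith.constant 3 : i32
```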

@llvmbot (Member) commented Feb 20, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes

This PR moves the vector.extract canonicalization patterns for DenseElementsAttr (the splat and non-splat cases) to folders. Folders are local, and it's always better to implement a folder than a canonicalizer.

This PR is marked NFC because the functionality mostly remains the same, but it now runs as part of a folder; some tests change as a result, because GreedyPatternRewriter tries to fold by default.


Full diff: https://github.com/llvm/llvm-project/pull/127995.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+46-76)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+7-19)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir (+2-3)
  • (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+5-9)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..96ac7fe2fa9e2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
   return {};
 }
 
+static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
+                                                   Attribute srcAttr) {
+  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
+  if (!denseAttr) {
+    return {};
+  }
+
+  if (denseAttr.isSplat()) {
+    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
+    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
+      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+    return newAttr;
+  }
+
+  auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType());
+  if (vecTy.isScalable())
+    return {};
+
+  if (extractOp.hasDynamicPosition()) {
+    return {};
+  }
+
+  // Calculate the linearized position of the continuous chunk of elements to
+  // extract.
+  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+  copy(extractOp.getStaticPosition(), completePositions.begin());
+  int64_t elemBeginPosition =
+      linearize(completePositions, computeStrides(vecTy.getShape()));
+  auto denseValuesBegin =
+      denseAttr.value_begin<TypedAttr>() + elemBeginPosition;
+
+  TypedAttr newAttr;
+  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
+    SmallVector<Attribute> elementValues(
+        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
+    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
+  } else {
+    newAttr = *denseValuesBegin;
+  }
+
+  return newAttr;
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2101,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return res;
   if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
     return res;
+  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+    return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2121,80 +2166,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
-// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractOp' operand is not defined by a splat vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
-    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
-      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
-class ExtractOpNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Canonicalization for dynamic position not implemented yet.
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
-    // Return if 'ExtractOp' operand is not defined by a compatible vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
-    if (vecTy.isScalable())
-      return failure();
-
-    // The splat case is handled by `ExtractOpSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // Calculate the linearized position of the continuous chunk of elements to
-    // extract.
-    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
-    copy(extractOp.getStaticPosition(), completePositions.begin());
-    int64_t elemBeginPosition =
-        linearize(completePositions, computeStrides(vecTy.getShape()));
-    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
-
-    TypedAttr newAttr;
-    if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
-      SmallVector<Attribute> elementValues(
-          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
-      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
-    } else {
-      newAttr = *denseValuesBegin;
-    }
-
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
 public:
@@ -2332,8 +2303,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e66fbe968d9b0..cd83e1239fdda 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
 
 // CHECK-DAG:  %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:  %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
-// CHECK-DAG:  %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
 
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
-
-// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
  // -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
 // CHECK-SAME:                                                                 %[[VAL_0:.*]]: tensor<80x16xf32>,
 // CHECK-SAME:                                                                 %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-
-// CHECK-DAG:       %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
-// CHECK-DAG:       %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
-// CHECK-DAG:       %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 
-// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_9]] : tensor<1x4xf32>
 // CHECK:         }
 
@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
 // CHECK-DAG:       %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK-DAG:       %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
-// CHECK:           %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
 // CHECK:           %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
 // CHECK:           %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..8e5ddbfffcdd9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
 // -----
 
 // ALL-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar() {
+func.func @test_vector_extract_scalar(%idx : index) {
   %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
   // ALL-NOT: vector.shuffle
   // ALL:     vector.extract
   // ALL-NOT: vector.shuffle
-  %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+  %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index c5cb09b9aa9f9..b4ebb14b8829e 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
 
 // CHECK-LABEL: func @transfer_write_arith_constant(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
-//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
-//       CHECK:   %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
-//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[cst:.*]] = arith.constant 5.000000e+00 : f32
+//       CHECK:   memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
   %cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
   vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 20e9400ed698d..5be267c1be984 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -242,33 +242,29 @@ func.func @strided_gather(%base : memref<100x3xf32>,
 // CHECK-SAME:                         %[[IDXS:.*]]: vector<4xindex>,
 // CHECK-SAME:                         %[[VAL_4:.*]]: index,
 // CHECK-SAME:                         %[[VAL_5:.*]]: index) -> vector<4xf32> {
+// CHECK:           %[[TRUE:.*]] = arith.constant true
 // CHECK:           %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
-// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
 
 // CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
 // CHECK:           %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
 
-// CHECK:           %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
 // CHECK:           %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
-// CHECK:           scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK:           scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK:             %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
 // CHECK:             %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
 
-// CHECK:           %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
 // CHECK:           %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
-// CHECK:           scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK:           scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK:             %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
 // CHECK:             %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
 
-// CHECK:           %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
 // CHECK:           %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
-// CHECK:           scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK:           scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK:             %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
 // CHECK:             %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
 
-// CHECK:           %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
 // CHECK:           %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
-// CHECK:           scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK:           scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK:             %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
 // CHECK:             %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
 

@Groverkss requested a review from kuhar February 20, 2025 12:03
@kuhar (Member) left a comment:

Folders are local, and it's always better to implement a folder than a canonicalizer.

In the past we tried to be cautious about IR size and bail out of canonicalizers if they produced large constants. I wonder if moving these to folders raises the bar even higher.

@Groverkss (Member, Author) commented Feb 20, 2025

Folders are local, and it's always better to implement a folder than a canonicalizer.

In the past we tried to be cautious of IR size and bail out of canonicalizers if they produced large constants. I wonder if moving to folders raises the bar even higher.

I don't think folders have a higher bar than canonicalizers. A folder is just a canonicalizer that is local. I think what you are raising is a valid point, but it's not related to this patch, as this patch just implements the same canonicalization in a more restricted fashion (local transformations only).

I think for InsertOp folders, we do have a fold limit. For ExtractOp, I don't think we need this limit; I've explained my reasoning in a comment.
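
For reference, a fold limit of this kind amounts to a size guard in the folder. A minimal sketch, with a hypothetical helper and threshold name (the actual InsertOp limit lives in VectorOps.cpp and may differ):

```cpp
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

// Hypothetical guard: refuse to fold when the constant that folding would
// materialize is large. Threshold value and helper name are illustrative.
static bool exceedsFoldSizeLimit(DenseElementsAttr denseAttr) {
  constexpr int64_t kFoldElementLimit = 256; // illustrative threshold
  return denseAttr.getNumElements() > kFoldElementLimit;
}
```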

@Groverkss changed the title from "[mlir][Vector][NFC] Move canonicalizers for DenseElementsAttr to folders" to "[mlir][Vector][NFC] Move vector.extract canonicalizers for DenseElementsAttr to folders" Feb 20, 2025
@banach-space (Contributor) commented:

This PR is marked NFC because the functionality mostly remains the same, but it now runs as part of a folder; some tests change as a result, because GreedyPatternRewriter tries to fold by default.

This is re-wiring things in a non-trivial way. I suggest dropping NFC. Will take a proper look later. Thanks!

@Groverkss changed the title from "[mlir][Vector][NFC] Move vector.extract canonicalizers for DenseElementsAttr to folders" to "[mlir][Vector] Move vector.extract canonicalizers for DenseElementsAttr to folders" Feb 20, 2025
@dcaballe (Contributor) left a comment:

Thanks!

@dcaballe (Contributor) commented:

Sorry, I missed some comments. Removing the approval to give more time for discussion.

@dcaballe self-requested a review February 20, 2025 21:39
@banach-space (Contributor) left a comment:

The code changes LGTM, thanks!

That said, I don’t quite follow this from the summary:

Folders are local, and it's always better to implement a folder than a canonicalizer.

Why? This is a rather strong statement, and I don’t see any clear indication of this in our documentation: MLIR Canonicalization Docs (https://mlir.llvm.org/docs/Canonicalization/)

I have no objections to landing this, but let’s ensure we have a valid reason for this preference. Otherwise, it feels somewhat arbitrary.

@Groverkss (Member, Author) commented Feb 25, 2025

The code changes LGTM, thanks!

That said, I don’t quite follow this from the summary:

Folders are local, and it's always better to implement a folder than a canonicalizer.

Why? This is a rather strong statement, and I don’t see any clear indication of this in our documentation: MLIR Canonicalization Docs

I have no objections to landing this, but let’s ensure we have a valid reason for this preference. Otherwise, it feels somewhat arbitrary.

I don't think it's a very strong statement; rather, it's a natural interpretation of the docs. I'm happy to update the docs, but let me explain my interpretation.

From the docs on canonicalization:

The fold mechanism is an intentionally limited, but powerful mechanism that
allows for applying canonicalizations in many places throughout the compiler.
For example, outside of the canonicalizer pass, fold is used within the
[dialect conversion
infrastructure](https://mlir.llvm.org/docs/DialectConversion/) as a
legalization mechanism, and can be invoked directly anywhere with an OpBuilder
via OpBuilder::createOrFold.

fold has the restriction that no new operations may be created, and only the
root operation may be replaced (but not erased). It allows for updating an
operation in-place, or returning a set of pre-existing values (or attributes)
to replace the operation with. This ensures that the fold method is a truly
“local” transformation, and can be invoked without the need for a pattern
rewriter.

Based on the docs, folders are a restricted way of implementing canonicalizations, where the transformation is "local" (as defined in the docs). The other way is to implement a RewritePattern, which can change the IR in any way a pattern can, as long as the set of patterns converges. Folders have an additional property that rewrite drivers (dialect conversion, the greedy rewrite driver) can exploit to canonicalize on the fly: because folding cannot create new operations, there are no extra operations for the driver to process, so folding adds only a constant cost instead of forcing the driver to revisit newly created ops.

But in the end, both are canonicalizations (as defined by the docs); a folder or a RewritePattern is just an implementation detail. Whether a transformation should be a canonicalization at all is the harder question, and needs justification; whether it should be a folder or a rewrite pattern is the easier one, a matter of implementation. That choice comes down to whether the canonicalization is local. The canonicalization in this PR is local, which is why it's better implemented as a folder.

That said, I think my original statement should be changed to:

"Folders are local, and it's always better to implement a folder than a canonicalizer pattern"

@kuhar (Member) left a comment:

LGTM

@@ -2033,20 +2033,71 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
                                                 ArrayRef<int64_t> staticPos,
                                                 int64_t poisonVal) {
-  if (!llvm::is_contained(staticPos, poisonVal))
+  if (!is_contained(staticPos, poisonVal))
@kuhar (Member) commented on this hunk:

nit: I'd undo this so that we don't depend on ADL.
Unlike the cast functions and container types like SmallVector/StringRef/ArrayRef/etc, we don't have a using declaration in the mlir namespace for functions from STLExtras.h.
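
For context, a sketch of why the unqualified call still compiles - argument-dependent lookup finds llvm::is_contained because ArrayRef lives in namespace llvm (illustrative snippet, not from the patch):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
bool hasPoison(llvm::ArrayRef<int64_t> staticPos, int64_t poisonVal) {
  // Unqualified: resolved via ADL because the argument is an llvm type;
  // this silently breaks if the argument types ever stop being llvm types.
  bool viaAdl = is_contained(staticPos, poisonVal);
  // Explicitly qualified: resolves regardless of argument namespaces.
  bool qualified = llvm::is_contained(staticPos, poisonVal);
  return viaAdl && qualified;
}
} // namespace mlir
```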

@banach-space (Contributor) commented:

Thanks for the reply!

The choice between implementing a canonicalization as a folder or a rewrite pattern comes down to if the canonicalization is local or not. The canonicalization in this PR is local, which is why it's better to implement it as a folder.

IIUC, your interpretation of the docs is something like this (please correct me if I’m misinterpreting your comment):

  • "When you have two options (canonicalization pattern vs. folder), always opt for a folder."

I don’t see any such indication in the docs. This is also why I requested that you remove "NFC" from the description - thanks for doing that!

I'm happy to update the docs.

I think that would be really helpful! It would save us this discussion and also provide clear guidance for our future selves and other contributors.

I’m happy for this to be merged (thanks for working on it and for addressing my comments!), but I’d appreciate it if you could follow up with a docs update - assuming I correctly interpreted your interpretation. 😅

Thanks, Kunwar!

@kuhar (Member) commented Feb 25, 2025

I'm not sure if I agree with this generalization either. For example, we'd expect folders to be invoked much more frequently in the compilation pipeline (and often without an opt-out, as in dialect conversion), so I'd be much more wary of spending compilation time on folds that are either costly or very unlikely to apply. Downstream compilers can decide how often to run canonicalization patterns based on their compilation-time budget.

@Groverkss (Member, Author) commented:

I'm not sure if I agree with this generalization either. For example, we'd expect folders to be invoked much more frequently in the compilation pipeline (and often without an opt-out, as in dialect conversion), so I'd be much more wary of spending compilation time on folds that are either costly or very unlikely to apply. Downstream compilers can decide how often to run canonicalization patterns based on their compilation-time budget.

Costly transformations shouldn't be canonicalizations at all (from the docs):

Patterns with expensive running time (i.e. have O(n) complexity) or complicated
cost models don’t belong to canonicalization: since the algorithm is executed
iteratively until fixed-point we want patterns that execute quickly (in
particular their matching phase).

Regarding "very unlikely to apply", I'm not sure without benchmarking a transformation it's really easy to say if the cost of folding is big. If the cost of folding is infact significant even after not having expensive transformations running as folders, then folding should just be disabled in the rewrite driver and added selectively as patterns (i.e. call the folder as a pattern). I've never heard of folding being a significant cause of compile-time regressions.

@kuhar (Member) commented Feb 25, 2025

I've never heard of folding being a significant cause of compile-time regressions.

IIRC @jpienaar reported folding as a significant contributor to total compilation time here: #104649 (comment). One solution is to run it less frequently, the other is to make it cheaper, and we should probably aim for both.

@joker-eph (Collaborator) commented:

"When you have two options (canonicalization pattern vs. folder), always opt for a folder."

This is how we approached it in the past indeed.

@Groverkss (Member, Author) commented:

Thanks for the reply!

IIUC, your interpretation of the docs is something like this (please correct me if I’m misinterpreting your comment):

  • "When you have two options (canonicalization pattern vs. folder), always opt for a folder."

[…] I’d appreciate it if you could follow up with a docs update - assuming I correctly interpreted your interpretation. 😅

Yes, your interpretation of what I was trying to say is correct. I will send a PR to update the documentation; let's continue the discussion about whether there is consensus on always opting for a folder in that PR.

For this PR, I'm going to land it unless someone has strong objections (and I'm assuming there aren't, since the objections seem to be about the general statement, which we can discuss in the docs PR). This PR is one of the final pieces needed to finally remove vector.extractelement/vector.insertelement from the Vector dialect :D

@kuhar (Member) commented Feb 25, 2025

(which I'm assuming there aren't because the objections seem to be on the general statement, which we can discuss in the docs PR)

No objections from me; I only said I was not 100% convinced by that last statement - I think there's room for a little more nuance.

@Groverkss merged commit 98542a3 into llvm:main Feb 26, 2025
11 checks passed