Skip to content

[mlir][MemRef] Add a pattern to simplify `extract_strided_metadata(ca… #68291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 5, 2023

Conversation

qcolombet
Copy link
Collaborator

…st)`

expand-strided-metadata was missing a pattern to get rid of memref.cast.
The pattern is straight foward:
Produce a new extract_strided_metadata with the source of the cast and fold the static information (sizes, strides, offset) along the way.

…st)`

`expand-strided-metadata` was missing a pattern to get rid of
`memref.cast`.
The pattern is straight foward:
Produce a new `extract_strided_metadata` with the source of the cast and fold
the static information (sizes, strides, offset) along the way.
@qcolombet
Copy link
Collaborator Author

This fixes iree-org/iree#15076

@llvmbot
Copy link
Member

llvmbot commented Oct 5, 2023

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Changes

…st)`

expand-strided-metadata was missing a pattern to get rid of memref.cast.
The pattern is straight foward:
Produce a new extract_strided_metadata with the source of the cast and fold the static information (sizes, strides, offset) along the way.


Full diff: https://github.com/llvm/llvm-project/pull/68291.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+88)
  • (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+124)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 672ef3eb4cd50fa..4f3fa6a5ed245f8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -870,6 +870,92 @@ class ExtractStridedMetadataOpReinterpretCastFolder
   }
 };
 
+/// Replace `base, offset, sizes, strides =
+///              extract_strided_metadata(
+///                 cast(src) to dstTy)`
+/// With
+/// ```
+/// base, ... = extract_strided_metadata(src)
+/// offset = !dstTy.srcOffset.isDynamic()?
+///            dstTy.srcOffset :
+///            extract_strided_metadata(src).offset
+/// sizes = for each srcSize in dstTy.srcSizes:
+///           !srcSize.isDynamic()
+///             ? srcSize
+//              : extract_strided_metadata(src).sizes[i]
+/// strides = for each srcStride in dstTy.srcStrides:
+///             !srcStrides.isDynamic()
+///               ? srcStrides
+///               : extract_strided_metadata(src).strides[i]
+/// ```
+///
+/// In other words, consume the `cast` and apply its effects
+/// on the offset, sizes, and strides or compute them directly from `src`.
+class ExtractStridedMetadataOpCastFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  PatternRewriter &rewriter) const override {
+    Value source = extractStridedMetadataOp.getSource();
+    auto castOp = source.getDefiningOp<memref::CastOp>();
+    if (!castOp)
+      return failure();
+
+    Location loc = extractStridedMetadataOp.getLoc();
+    // Check if the source is suitable for extract_strided_metadata.
+    SmallVector<Type> inferredReturnTypes;
+    if (failed(extractStridedMetadataOp.inferReturnTypes(
+            rewriter.getContext(), loc, {castOp.getSource()},
+            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
+            inferredReturnTypes)))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "cast source's type is incompatible");
+
+    auto memrefType = cast<MemRefType>(source.getType());
+    unsigned rank = memrefType.getRank();
+    SmallVector<OpFoldResult> results;
+    results.resize_for_overwrite(rank * 2 + 2);
+
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc,
+                                                          castOp.getSource());
+
+    // Register the base_buffer.
+    results[0] = newExtractStridedMetadata.getBaseBuffer();
+
+    auto getConstantOrValue = [&rewriter](int64_t constant,
+                                          OpFoldResult ofr) -> OpFoldResult {
+      return !ShapedType::isDynamic(constant)
+                 ? OpFoldResult(rewriter.getIndexAttr(constant))
+                 : ofr;
+    };
+
+    auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
+    assert(sourceStrides.size() == rank && "unexpected number of strides");
+
+    // Register the new offset.
+    results[1] =
+        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
+
+    const unsigned sizeStartIdx = 2;
+    const unsigned strideStartIdx = sizeStartIdx + rank;
+    ArrayRef<int64_t> sourceSizes = memrefType.getShape();
+
+    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
+    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
+    for (unsigned i = 0; i < rank; ++i) {
+      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
+      results[strideStartIdx + i] =
+          getConstantOrValue(sourceStrides[i], strides[i]);
+    }
+    rewriter.replaceOp(extractStridedMetadataOp,
+                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
+    return success();
+  }
+};
+
 /// Replace `base, offset =
 ///            extract_strided_metadata(extract_strided_metadata(src)#0)`
 /// With
@@ -911,6 +997,7 @@ void memref::populateExpandStridedMetadataPatterns(
                ExtractStridedMetadataOpGetGlobalFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
+               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
@@ -923,6 +1010,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
+               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index a6303aa2d971106..4efb38abcd7679c 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1369,3 +1369,127 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
   return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
       memref<i32>, index, index, index, index, index
 }
+
+// -----
+
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is compatible with what
+// `extract_strided_metadata`s accept.
+//
+// When we apply the transformation the resulting offset, sizes and strides
+// should come straight from the inputs of the cast.
+// Additionally the folder on extract_strided_metadata should propagate the
+// static information.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+//  CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[?, ?], offset: ?>>)
+//
+//       CHECK: %[[C3:.*]] = arith.constant 3 : index
+//       CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[DYN_STRIDES]]#1
+func.func @extract_strided_metadata_of_cast(
+  %arg : memref<3x?xi32, strided<[?, ?], offset:?>>)
+  -> (memref<i32>, index,
+      index, index,
+      index, index) {
+
+  %cast =
+    memref.cast %arg :
+      memref<3x?xi32, strided<[?, ?], offset: ?>> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is compatible with what
+// `extract_strided_metadata`s accept.
+//
+// Same as extract_strided_metadata_of_cast but with constant sizes and strides
+// in the destination type.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
+//  CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+//
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
+//   CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
+//       CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
+func.func @extract_strided_metadata_of_cast_w_csts(
+  %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
+  -> (memref<i32>, index,
+      index, index,
+      index, index) {
+
+  %cast =
+    memref.cast %arg :
+      memref<?x?xi32, strided<[?, ?], offset: ?>> to
+      memref<4x?xi32, strided<[?, 18], offset: 25>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}
+// -----
+
+// Check that we don't simplify extract_strided_metadata of
+// cast when the source of the cast is unranked.
+// Unranked memrefs cannot feed into extract_strided_metadata operations.
+// Note: Technically we could still fold the sizes and strides.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
+//  CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
+//
+//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
+//
+//       CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
+func.func @extract_strided_metadata_of_cast_unranked(
+  %arg : memref<*xi32>)
+  -> (memref<i32>, index,
+      index, index,
+      index, index) {
+
+  %cast =
+    memref.cast %arg :
+      memref<*xi32> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}

@Groverkss
Copy link
Member

Groverkss commented Oct 5, 2023

Can you please add a test that casts constant strides to dynamic strides?

@qcolombet
Copy link
Collaborator Author

@Groverkss

Can you please add a test that casts constant strides to dynamic strides?

Sure, I've modified @extract_strided_metadata_of_cast to cover that.
Let me know if you prefer to keep @extract_strided_metadata_of_cast how it was and have a different test. I don't think that's necessary but I don't have a strong opinion here.

@Groverkss Groverkss self-requested a review October 5, 2023 10:25
@Groverkss
Copy link
Member

Let me know if you prefer to keep @extract_strided_metadata_of_cast

No, this is great! Thank you :)

LGTM

@qcolombet qcolombet merged commit 932dc9d into llvm:main Oct 5, 2023
@qcolombet qcolombet deleted the add_extract_strided_md_cast branch October 5, 2023 12:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants