-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][MemRef] Add a pattern to simplify `extract_strided_metadata(ca… #68291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…st)` `expand-strided-metadata` was missing a pattern to get rid of `memref.cast`. The pattern is straight foward: Produce a new `extract_strided_metadata` with the source of the cast and fold the static information (sizes, strides, offset) along the way.
This fixes iree-org/iree#15076 |
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Changes…st)`
Full diff: https://github.com/llvm/llvm-project/pull/68291.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 672ef3eb4cd50fa..4f3fa6a5ed245f8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -870,6 +870,92 @@ class ExtractStridedMetadataOpReinterpretCastFolder
}
};
+/// Replace `base, offset, sizes, strides =
+/// extract_strided_metadata(
+/// cast(src) to dstTy)`
+/// With
+/// ```
+/// base, ... = extract_strided_metadata(src)
+/// offset = !dstTy.srcOffset.isDynamic()?
+/// dstTy.srcOffset :
+/// extract_strided_metadata(src).offset
+/// sizes = for each srcSize in dstTy.srcSizes:
+/// !srcSize.isDynamic()
+/// ? srcSize
+// : extract_strided_metadata(src).sizes[i]
+/// strides = for each srcStride in dstTy.srcStrides:
+/// !srcStrides.isDynamic()
+/// ? srcStrides
+/// : extract_strided_metadata(src).strides[i]
+/// ```
+///
+/// In other words, consume the `cast` and apply its effects
+/// on the offset, sizes, and strides or compute them directly from `src`.
+class ExtractStridedMetadataOpCastFolder
+ : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+ PatternRewriter &rewriter) const override {
+ Value source = extractStridedMetadataOp.getSource();
+ auto castOp = source.getDefiningOp<memref::CastOp>();
+ if (!castOp)
+ return failure();
+
+ Location loc = extractStridedMetadataOp.getLoc();
+ // Check if the source is suitable for extract_strided_metadata.
+ SmallVector<Type> inferredReturnTypes;
+ if (failed(extractStridedMetadataOp.inferReturnTypes(
+ rewriter.getContext(), loc, {castOp.getSource()},
+ /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
+ inferredReturnTypes)))
+ return rewriter.notifyMatchFailure(castOp,
+ "cast source's type is incompatible");
+
+ auto memrefType = cast<MemRefType>(source.getType());
+ unsigned rank = memrefType.getRank();
+ SmallVector<OpFoldResult> results;
+ results.resize_for_overwrite(rank * 2 + 2);
+
+ auto newExtractStridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc,
+ castOp.getSource());
+
+ // Register the base_buffer.
+ results[0] = newExtractStridedMetadata.getBaseBuffer();
+
+ auto getConstantOrValue = [&rewriter](int64_t constant,
+ OpFoldResult ofr) -> OpFoldResult {
+ return !ShapedType::isDynamic(constant)
+ ? OpFoldResult(rewriter.getIndexAttr(constant))
+ : ofr;
+ };
+
+ auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
+ assert(sourceStrides.size() == rank && "unexpected number of strides");
+
+ // Register the new offset.
+ results[1] =
+ getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
+
+ const unsigned sizeStartIdx = 2;
+ const unsigned strideStartIdx = sizeStartIdx + rank;
+ ArrayRef<int64_t> sourceSizes = memrefType.getShape();
+
+ SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
+ SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
+ for (unsigned i = 0; i < rank; ++i) {
+ results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
+ results[strideStartIdx + i] =
+ getConstantOrValue(sourceStrides[i], strides[i]);
+ }
+ rewriter.replaceOp(extractStridedMetadataOp,
+ getValueOrCreateConstantIndexOp(rewriter, loc, results));
+ return success();
+ }
+};
+
/// Replace `base, offset =
/// extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
@@ -911,6 +997,7 @@ void memref::populateExpandStridedMetadataPatterns(
ExtractStridedMetadataOpGetGlobalFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
+ ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
@@ -923,6 +1010,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
+ ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index a6303aa2d971106..4efb38abcd7679c 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1369,3 +1369,127 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
memref<i32>, index, index, index, index, index
}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is compatible with what
+// `extract_strided_metadata`s accept.
+//
+// When we apply the transformation the resulting offset, sizes and strides
+// should come straight from the inputs of the cast.
+// Additionally the folder on extract_strided_metadata should propagate the
+// static information.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[?, ?], offset: ?>>)
+//
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[DYN_STRIDES]]#1
+func.func @extract_strided_metadata_of_cast(
+ %arg : memref<3x?xi32, strided<[?, ?], offset:?>>)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ %cast =
+ memref.cast %arg :
+ memref<3x?xi32, strided<[?, ?], offset: ?>> to
+ memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is compatible with what
+// `extract_strided_metadata`s accept.
+//
+// Same as extract_strided_metadata_of_cast but with constant sizes and strides
+// in the destination type.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
+// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+//
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
+// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
+// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
+func.func @extract_strided_metadata_of_cast_w_csts(
+ %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ %cast =
+ memref.cast %arg :
+ memref<?x?xi32, strided<[?, ?], offset: ?>> to
+ memref<4x?xi32, strided<[?, 18], offset: 25>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
+// -----
+
+// Check that we don't simplify extract_strided_metadata of
+// cast when the source of the cast is unranked.
+// Unranked memrefs cannot feed into extract_strided_metadata operations.
+// Note: Technically we could still fold the sizes and strides.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
+// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
+//
+// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
+//
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
+func.func @extract_strided_metadata_of_cast_unranked(
+ %arg : memref<*xi32>)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ %cast =
+ memref.cast %arg :
+ memref<*xi32> to
+ memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
|
Can you please add a test that casts constant strides to dynamic strides? |
Sure, I've modified |
No, this is great! Thank you :) LGTM |
…st)`
expand-strided-metadata
was missing a pattern to get rid ofmemref.cast
.The pattern is straight foward:
Produce a new
extract_strided_metadata
with the source of the cast and fold the static information (sizes, strides, offset) along the way.