Skip to content

Commit a1a3cf2

Browse files
banach-spacepuja2196
authored andcommitted
[mlir][vector] Add a new TD Op for patterns leveraging ShapeCastOp (#110525)
Adds a new Transform Dialect Op that collects patters for dropping unit dims from various Ops: * `transform.apply_patterns.vector.drop_unit_dims_with_shape_cast`. It excludes patterns for vector.transfer Ops - these are collected under: * `apply_patterns.vector.rank_reducing_subview_patterns`, and use ShapeCastOp _and_ SubviewOp to reduce the rank (and to eliminate unit dims). This new TD Ops allows us to test the "ShapeCast folder" pattern in isolation. I've extracted the only test that I could find for that folder from "vector-transforms.mlir" and moved it to a dedicated file: "shape-cast-folder.mlir". I also added a test case with scalable vectors. Changes in VectorTransforms.cpp are not needed (added a comment with a TODO + ordered the patterns alphabetically). I am Including them here to avoid a separate PR.
1 parent 5633631 commit a1a3cf2

File tree

5 files changed

+66
-11
lines changed

5 files changed

+66
-11
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@ def ApplyRankReducingSubviewPatternsOp : Op<Transform_Dialect,
6868
let assemblyFormat = "attr-dict";
6969
}
7070

71+
def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
72+
"apply_patterns.vector.drop_unit_dims_with_shape_cast",
73+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
74+
let description = [{
75+
Apply vector patterns to fold unit dims with vector.shape_cast Ops:
76+
- DropUnitDimFromElementwiseOps
77+
- DropUnitDimsFromScfForOp
78+
- DropUnitDimsFromTransposeOp
79+
80+
Excludes patterns for vector.transfer Ops. This is complemented by
81+
shape_cast folding patterns.
82+
}];
83+
84+
let assemblyFormat = "attr-dict";
85+
}
86+
7187
def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
7288
"apply_patterns.vector.transfer_permutation_patterns",
7389
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
8585
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
8686
}
8787

88+
void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
89+
RewritePatternSet &patterns) {
90+
vector::populateDropUnitDimWithShapeCastPatterns(patterns);
91+
}
92+
8893
void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
8994
RewritePatternSet &patterns) {
9095
vector::populateVectorBitCastLoweringPatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2056,8 +2056,13 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
20562056

20572057
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
20582058
RewritePatternSet &patterns, PatternBenefit benefit) {
2059-
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
2060-
ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
2059+
// TODO: Consider either:
2060+
// * including DropInnerMostUnitDimsTransferRead and
2061+
// DropInnerMostUnitDimsTransferWrite, or
2062+
// * better naming to distinguish this and
2063+
// populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
2064+
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2065+
DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
20612066
patterns.getContext(), benefit);
20622067
}
20632068

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
2+
3+
///----------------------------------------------------------------------------------------
4+
/// [Pattern: ShapeCastOpFolder]
5+
///----------------------------------------------------------------------------------------
6+
7+
// CHECK-LABEL: func @fixed_width
8+
// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
9+
// CHECK-NOT: vector.shape_cast
10+
// CHECK: return %[[A0]] : vector<2x4xf32>
11+
func.func @fixed_width(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
12+
%0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
13+
%1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
14+
return %1 : vector<2x4xf32>
15+
}
16+
17+
// CHECK-LABEL: func @scalable
18+
// CHECK-SAME: %[[A0:.*0]]: vector<2x[4]xf32>
19+
// CHECK-NOT: vector.shape_cast
20+
// CHECK: return %[[A0]] : vector<2x[4]xf32>
21+
func.func @scalable(%arg0 : vector<2x[4]xf32>) -> vector<2x[4]xf32> {
22+
%0 = vector.shape_cast %arg0 : vector<2x[4]xf32> to vector<[8]xf32>
23+
%1 = vector.shape_cast %0 : vector<[8]xf32> to vector<2x[4]xf32>
24+
return %1 : vector<2x[4]xf32>
25+
}
26+
27+
// ============================================================================
28+
// TD sequence
29+
// ============================================================================
30+
module attributes {transform.with_named_sequence} {
31+
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
32+
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
33+
transform.apply_patterns to %func_op {
34+
transform.apply_patterns.vector.drop_unit_dims_with_shape_cast
35+
} : !transform.op<"func.func">
36+
transform.yield
37+
}
38+
}

mlir/test/Dialect/Vector/vector-transforms.mlir

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,6 @@ func.func @vector_transfers(%arg0: index, %arg1: index) {
184184
return
185185
}
186186

187-
// CHECK-LABEL: func @cancelling_shape_cast_ops
188-
// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
189-
// CHECK: return %[[A0]] : vector<2x4xf32>
190-
func.func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
191-
%0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
192-
%1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
193-
return %1 : vector<2x4xf32>
194-
}
195-
196187
// CHECK-LABEL: func @elementwise_unroll
197188
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
198189
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index

0 commit comments

Comments
 (0)