
Commit 03529b9

[mlir][linalg] Add support for vectorizing dynamic elementwise named ops (#71454)
We can already vectorize these ops when they are written in linalg.generic form; we just need to relax the precondition so the named forms are vectorized as well.
1 parent: d2361b2
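
For context, here is a minimal sketch (not part of this commit) of the kind of input the relaxed precondition now accepts: a dynamically shaped named elementwise op, using linalg.elemwise_unary as an example, vectorized with explicit vector sizes through the transform dialect in the same way the new linalg.transpose test below drives vectorization. The function name, the choice of linalg.elemwise_unary, and the vector sizes [8, 16] are illustrative assumptions.

// Hypothetical example (not from the commit): a named elementwise op on
// dynamic tensors that previously failed the dynamic-shape masking precondition.
func.func @exp_dyn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.elemwise_unary { fun = #linalg.unary_fn<exp> }
         ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// Drive masked vectorization with explicit (assumed) vector sizes, mirroring
// the transform sequence used by the new test below.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
    transform.yield
  }
}

With the relaxed check, this should lower to vector.create_mask plus vector.mask-wrapped transfer ops, analogous to the CHECK lines in the transpose test added in this commit.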

2 files changed: +33 -3 lines changed


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

+5 -3
@@ -1465,9 +1465,11 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
-  // TODO: Masking only supports dynamic generic ops for now.
-  if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp,
-           linalg::ContractionOpInterface>(op.getOperation()))
+  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
+  // linalg.copy ops and ops that implement ContractionOpInterface for now.
+  if (!isElementwise(op) &&
+      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
+          op.getOperation()))
     return failure();
 
   LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");

mlir/test/Dialect/Linalg/vectorization.mlir

+28
@@ -368,6 +368,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: func @test_masked_vectorize_linalg_transpose
+func.func @test_masked_vectorize_linalg_transpose(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[D0:.*]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[D1:.*]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[MASK0:.*]] = vector.create_mask %[[D0]], %[[D1]]
+  // CHECK: %[[LOAD:.*]] = vector.mask %[[MASK0]] { vector.transfer_read %arg0{{.+}} }
+  // CHECK-SAME: vector<2x4xi1> -> vector<2x4xf32>
+  // CHECK: %[[MASK1:.*]] = vector.create_mask %[[D1]], %[[D0]]
+  // CHECK: %[[WRITE:.*]] = vector.mask %[[MASK1]] { vector.transfer_write %[[LOAD]], %arg1{{.+}} permutation_map = #[[MAP]]{{.+}} }
+  // CHECK-SAME: vector<4x2xi1> -> tensor<?x?xf32>
+  // CHECK: return %[[WRITE]]
+  %0 = linalg.transpose ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.transpose"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [2, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @test_masked_vectorize_linalg_copy
 func.func @test_masked_vectorize_linalg_copy(%A : memref<?x?xf32>, %B : memref<?x?xf32>) {
   // CHECK: %[[c0:.*]] = arith.constant 0 : index
