[mlir][vector] Add support for distributing masked writes #71482

Merged
qedawkins merged 3 commits into llvm:main on Nov 7, 2023

Conversation

qedawkins
Contributor

Because the mask applies to the un-permuted write vector, we can simply distribute the mask identically to the vector, if present.

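As a simplified sketch condensed from the test added in this PR, with 32 lanes a masked 1-D write such as

vector.warp_execute_on_lane_0(%laneid)[32] -> () {
  %mask = "mask_def"() : () -> (vector<4096xi1>)
  %v = "some_def"() : () -> (vector<4096xf32>)
  vector.transfer_write %v, %dest[%c0], %mask : vector<4096xf32>, memref<4096xf32>
  vector.yield
}

is rewritten so that the vector and its mask are yielded out of the warp op and distributed identically, each lane writing vector<128xf32> under a vector<128xi1> mask (4096 / 32 = 128; %idx below is a placeholder for the lane-dependent offset produced by the reindexing step):

%w:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<128xf32>, vector<128xi1>) {
  %mask = "mask_def"() : () -> (vector<4096xi1>)
  %v = "some_def"() : () -> (vector<4096xf32>)
  vector.yield %v, %mask : vector<4096xf32>, vector<4096xi1>
}
vector.transfer_write %w#0, %dest[%idx], %w#1 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>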
@llvmbot
Member

llvmbot commented Nov 7, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Quinn Dawkins (qedawkins)

Changes

Because the mask applies to the un-permuted write vector, we can simply distribute the mask identically to the vector, if present.


Full diff: https://github.com/llvm/llvm-project/pull/71482.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+29-9)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+28)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 78015e3deeb967e..bbc28e64bbfd8ac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -406,19 +406,29 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
 static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                             WarpExecuteOnLane0Op warpOp,
                                             vector::TransferWriteOp writeOp,
-                                            VectorType targetType) {
+                                            VectorType targetType,
+                                            VectorType maybeMaskType) {
   assert(writeOp->getParentOp() == warpOp &&
          "write must be nested immediately under warp");
   OpBuilder::InsertionGuard g(rewriter);
   SmallVector<size_t> newRetIndices;
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
-      TypeRange{targetType}, newRetIndices);
+  WarpExecuteOnLane0Op newWarpOp;
+  if (maybeMaskType) {
+    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
+        TypeRange{targetType, maybeMaskType}, newRetIndices);
+  } else {
+    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, ValueRange{{writeOp.getVector()}},
+        TypeRange{targetType}, newRetIndices);
+  }
   rewriter.setInsertionPointAfter(newWarpOp);
   auto newWriteOp =
       cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
   rewriter.eraseOp(writeOp);
   newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+  if (maybeMaskType)
+    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
   return newWriteOp;
 }
 
@@ -489,10 +499,18 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (!targetType)
       return failure();
 
+    // 2.5 Compute the distributed type for the new mask.
+    VectorType maskType;
+    if (writeOp.getMask()) {
+      maskType =
+          getDistributedType(writeOp.getMask().getType().cast<VectorType>(),
+                             map, warpOp.getWarpSize());
+    }
+
     // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
     // the rest.
     vector::TransferWriteOp newWriteOp =
-        cloneWriteOp(rewriter, warpOp, writeOp, targetType);
+        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
 
     // 4. Reindex the write using the distribution map.
     auto newWarpOp =
@@ -561,10 +579,6 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
-    // Ops with mask not supported yet.
-    if (writeOp.getMask())
-      return failure();
-
     auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
     if (!warpOp)
       return failure();
@@ -575,8 +589,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
       if (!isMemoryEffectFree(nextOp))
         return failure();
 
+    Value maybeMask = writeOp.getMask();
     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
           return writeOp.getVector() == value ||
+                 (maybeMask && maybeMask == value) ||
                  warpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
@@ -584,6 +600,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
       return success();
 
+    // Masked writes not supported for extraction.
+    if (writeOp.getMask())
+      return failure();
+
     if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
       return success();
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 5ec02ce002ffbd6..f050bcd246e5ef7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1253,3 +1253,31 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
 //  CHECK-PROP-SAME: (%{{.+}}: index, %[[SRC:.+]]: memref<4096xf32>, %[[INDEX:.+]]: index)
 //       CHECK-PROP:   %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %cst {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
 //       CHECK-PROP:   return %[[READ]] : vector<1xf32>
+
+// -----
+
+func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+    %mask = "mask_def_0"() : () -> (vector<4096xi1>)
+    %mask2 = "mask_def_1"() : () -> (vector<32xi1>)
+    %0 = "some_def_0"() : () -> (vector<4096xf32>)
+    %1 = "some_def_1"() : () -> (vector<32xf32>)
+    vector.transfer_write %0, %dest[%c0], %mask : vector<4096xf32>, memref<4096xf32>
+    vector.transfer_write %1, %dest[%c0], %mask2 : vector<32xf32>, memref<4096xf32>
+    vector.yield
+  }
+  return
+}
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_masked_write(
+//       CHECK-DIST-AND-PROP:   %[[W:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xi1>, vector<128xf32>, vector<128xi1>) {
+//       CHECK-DIST-AND-PROP:     %[[M0:.*]] = "mask_def_0"
+//       CHECK-DIST-AND-PROP:     %[[M1:.*]] = "mask_def_1"
+//       CHECK-DIST-AND-PROP:     %[[V0:.*]] = "some_def_0"
+//       CHECK-DIST-AND-PROP:     %[[V1:.*]] = "some_def_1"
+//       CHECK-DIST-AND-PROP:     vector.yield %[[V1]], %[[M1]], %[[V0]], %[[M0]]
+//  CHECK-DIST-AND-PROP-SAME:       vector<32xf32>, vector<32xi1>, vector<4096xf32>, vector<4096xi1>
+//       CHECK-DIST-AND-PROP:   }
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>

@antiagainst antiagainst requested a review from dcaballe November 7, 2023 05:21
@antiagainst
Member

Hmm, I'm not sure this is actually correct. In my mental model of masking, it is applied close to the memory, immediately before/after the write/read, because that's where the masking really happens. So for a write, the mask should match the memory shape, not exactly the vector shape, when we have non-identity permutation maps. I think we need to invert the map and compose it with the vector shape to get the mask shape, and distribute that. @dcaballe WDYT?
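For illustration, consider a hypothetical non-identity-map write (not from this PR). Per inferTransferOpMaskType, the mask type is the inverse-permuted vector shape, i.e. the memory-side shape:

// The map transposes the write: vector dim 0 (size 2) runs along memref
// dim 1, and vector dim 1 (size 4) along memref dim 0. The mask is therefore
// vector<4x2xi1> (memory order), not vector<2x4xi1> (vector order).
vector.transfer_write %v, %mem[%c0, %c0], %mask
    {in_bounds = [true, true],
     permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
    : vector<2x4xf32>, memref<4x2xf32>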

@qedawkins
Contributor Author

> Hmm, I'm not sure this is actually correct. In my mental model of masking, it is applied close to the memory, immediately before/after the write/read, because that's where the masking really happens. So for a write, the mask should match the memory shape, not exactly the vector shape, when we have non-identity permutation maps. I think we need to invert the map and compose it with the vector shape to get the mask shape, and distribute that. @dcaballe WDYT?

Sorry, I reread the docs for transfer_write and you're right. Moving to a draft for now.

@qedawkins qedawkins marked this pull request as draft November 7, 2023 05:28
@antiagainst
Member

Actually, given that we are only handling 1-D distribution in transfer_write at the moment, and that we compress unused dims in inferTransferOpMaskType, doesn't that effectively guarantee the mask has the same shape as the vector type? I'm a bit confused now; better to get @dcaballe to take another look.

We can enable support step-by-step. I think the current impl is fine if we additionally check that the transfer write has an identity map.

@antiagainst
Member

The doc in transfer_write w.r.t. masking is confusing -- this part seems obsolete and needs to be deleted. I'll send a pull request to fix it.

@antiagainst
Member

Uploaded #71490 for it.

@qedawkins
Contributor Author

> Actually, given that we are only handling 1-D distribution in transfer_write at the moment, and that we compress unused dims in inferTransferOpMaskType, doesn't that effectively guarantee the mask has the same shape as the vector type? I'm a bit confused now; better to get @dcaballe to take another look.
>
> We can enable support step-by-step. I think the current impl is fine if we additionally check that the transfer write has an identity map.

Yeah, after further thought, n-D masked writes will be quite tricky to handle properly because the warp op does not capture the distribution mapping needed by the transfer_write. We would have to do something like: apply the transfer map, distribute, and invert the map to get back to the original mask shape. Ideally those transposes would cancel, but step-by-step sounds good to me. I'll guard this a bit better here, but a discussion about how to think about these masks would be great.
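To sketch why this is awkward (nothing below is produced by this PR; the warp size of 2 is chosen only to keep the shapes small): distributing the leading dimension of the transposed write above would force the memory-shaped mask to distribute along its trailing dimension, so "distribute identically" no longer applies:

// The vector distributes 2x4 -> 1x4 per lane, but the memory-side mask must
// distribute 4x2 -> 4x1, because the write's permutation map swaps the dims.
%w:2 = vector.warp_execute_on_lane_0(%laneid)[2] -> (vector<1x4xf32>, vector<4x1xi1>) {
  ...
  vector.yield %v, %mask : vector<2x4xf32>, vector<4x2xi1>
}
vector.transfer_write %w#0, %mem[%c0, %laneid], %w#1
    {in_bounds = [true, true],
     permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
    : vector<1x4xf32>, memref<4x2xf32>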

@qedawkins qedawkins marked this pull request as ready for review November 7, 2023 14:56
@qedawkins qedawkins requested a review from antiagainst November 7, 2023 14:56
@dcaballe
Contributor

dcaballe commented Nov 7, 2023

Basically this:

> it's applied close to the memory,

We mask the memory access in whatever shape happens on memory.

@qedawkins qedawkins changed the title [MLIR][Vector] Add support for distributing masked writes [mlir][vector] Add support for distributing masked writes Nov 7, 2023
@qedawkins qedawkins merged commit 25ec1fa into llvm:main Nov 7, 2023