[mlir][vector] Add support for distributing masked writes #71482
Conversation
Because the mask applies to the un-permuted write vector, we can simply distribute the mask identically to the vector, if present.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Quinn Dawkins (qedawkins)

Changes: Because the mask applies to the un-permuted write vector, we can simply distribute the mask identically to the vector, if present.

Full diff: https://github.com/llvm/llvm-project/pull/71482.diff

2 Files Affected:
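For a sense of the intended effect, here is a rough before/after sketch. The shapes, the function names, and the laneid * 2 reindexing offset are illustrative assumptions for this example, not taken from the patch:

// Before: a masked 1-D write nested in the warp op (hypothetical shapes).
func.func @before(%laneid: index, %dest: memref<64xf32>) {
  %c0 = arith.constant 0 : index
  vector.warp_execute_on_lane_0(%laneid)[32] -> () {
    %v = "some_def"() : () -> (vector<64xf32>)
    %m = "mask_def"() : () -> (vector<64xi1>)
    vector.transfer_write %v, %dest[%c0], %m : vector<64xf32>, memref<64xf32>
    vector.yield
  }
  return
}

// After: the mask is yielded alongside the vector and distributed the same
// way (64 elements over 32 lanes -> 2 per lane); the write now sits outside
// the warp op at a lane-dependent offset.
func.func @after(%laneid: index, %dest: memref<64xf32>) {
  %w:2 = vector.warp_execute_on_lane_0(%laneid)[32]
      -> (vector<2xf32>, vector<2xi1>) {
    %v = "some_def"() : () -> (vector<64xf32>)
    %m = "mask_def"() : () -> (vector<64xi1>)
    vector.yield %v, %m : vector<64xf32>, vector<64xi1>
  }
  %off = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%laneid]
  vector.transfer_write %w#0, %dest[%off], %w#1 {in_bounds = [true]}
      : vector<2xf32>, memref<64xf32>
  return
}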
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 78015e3deeb967e..bbc28e64bbfd8ac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -406,19 +406,29 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
WarpExecuteOnLane0Op warpOp,
vector::TransferWriteOp writeOp,
- VectorType targetType) {
+ VectorType targetType,
+ VectorType maybeMaskType) {
assert(writeOp->getParentOp() == warpOp &&
"write must be nested immediately under warp");
OpBuilder::InsertionGuard g(rewriter);
SmallVector<size_t> newRetIndices;
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, ValueRange{{writeOp.getVector()}},
- TypeRange{targetType}, newRetIndices);
+ WarpExecuteOnLane0Op newWarpOp;
+ if (maybeMaskType) {
+ newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
+ TypeRange{targetType, maybeMaskType}, newRetIndices);
+ } else {
+ newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, ValueRange{{writeOp.getVector()}},
+ TypeRange{targetType}, newRetIndices);
+ }
rewriter.setInsertionPointAfter(newWarpOp);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
rewriter.eraseOp(writeOp);
newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+ if (maybeMaskType)
+ newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
return newWriteOp;
}
@@ -489,10 +499,18 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (!targetType)
return failure();
+ // 2.5 Compute the distributed type for the new mask;
+ VectorType maskType;
+ if (writeOp.getMask()) {
+ maskType =
+ getDistributedType(writeOp.getMask().getType().cast<VectorType>(),
+ map, warpOp.getWarpSize());
+ }
+
// 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
// the rest.
vector::TransferWriteOp newWriteOp =
- cloneWriteOp(rewriter, warpOp, writeOp, targetType);
+ cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
// 4. Reindex the write using the distribution map.
auto newWarpOp =
@@ -561,10 +579,6 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
- // Ops with mask not supported yet.
- if (writeOp.getMask())
- return failure();
-
auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
if (!warpOp)
return failure();
@@ -575,8 +589,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (!isMemoryEffectFree(nextOp))
return failure();
+ Value maybeMask = writeOp.getMask();
if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
return writeOp.getVector() == value ||
+ (maybeMask && maybeMask == value) ||
warpOp.isDefinedOutsideOfRegion(value);
}))
return failure();
@@ -584,6 +600,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
return success();
+ // Masked writes not supported for extraction.
+ if (writeOp.getMask())
+ return failure();
+
if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
return success();
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 5ec02ce002ffbd6..f050bcd246e5ef7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1253,3 +1253,31 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
// CHECK-PROP-SAME: (%{{.+}}: index, %[[SRC:.+]]: memref<4096xf32>, %[[INDEX:.+]]: index)
// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %cst {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
// CHECK-PROP: return %[[READ]] : vector<1xf32>
+
+// -----
+
+func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+ %mask = "mask_def_0"() : () -> (vector<4096xi1>)
+ %mask2 = "mask_def_1"() : () -> (vector<32xi1>)
+ %0 = "some_def_0"() : () -> (vector<4096xf32>)
+ %1 = "some_def_1"() : () -> (vector<32xf32>)
+ vector.transfer_write %0, %dest[%c0], %mask : vector<4096xf32>, memref<4096xf32>
+ vector.transfer_write %1, %dest[%c0], %mask2 : vector<32xf32>, memref<4096xf32>
+ vector.yield
+ }
+ return
+}
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_masked_write(
+// CHECK-DIST-AND-PROP: %[[W:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xi1>, vector<128xf32>, vector<128xi1>) {
+// CHECK-DIST-AND-PROP: %[[M0:.*]] = "mask_def_0"
+// CHECK-DIST-AND-PROP: %[[M1:.*]] = "mask_def_1"
+// CHECK-DIST-AND-PROP: %[[V0:.*]] = "some_def_0"
+// CHECK-DIST-AND-PROP: %[[V1:.*]] = "some_def_1"
+// CHECK-DIST-AND-PROP: vector.yield %[[V1]], %[[M1]], %[[V0]], %[[M0]]
+// CHECK-DIST-AND-PROP-SAME: vector<32xf32>, vector<32xi1>, vector<4096xf32>, vector<4096xi1>
+// CHECK-DIST-AND-PROP: }
+// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
+// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>
Hmm, I'm not sure this is actually correct. In my mental model of masking, it's applied close to the memory, immediately before/after the write/read, because that's where masking really happens. So for a write, the masking should match the memory shape, not exactly the vector shape, when we have non-identity permutation maps. I think we need to invert the map and compose it with the vector shape to get the mask shape, and distribute that. @dcaballe WDYT?
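As a concrete illustration of the shape question under the memory-side reading argued here (shapes are hypothetical, not from the patch): with a non-identity permutation map, the mask takes the shape of the memory access, i.e. the inverse map composed with the vector shape, not the shape of the written vector:

// Hypothetical example: the written vector is vector<2x4xf32>, but the
// write traverses memory transposed, so the mask guarding the memory
// access would be vector<4x2xi1> rather than vector<2x4xi1>.
func.func @masked_permuted_write(%dest: memref<4x2xf32>,
                                 %v: vector<2x4xf32>, %m: vector<4x2xi1>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %v, %dest[%c0, %c0], %m
      {in_bounds = [true, true],
       permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
      : vector<2x4xf32>, memref<4x2xf32>
  return
}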
Sorry, I reread the docs for transfer_write and you're right. Moving to a draft for now.
Actually, given that we are only handling 1-D distribution in transfer_write at the moment, and the fact that we are compressing unused dims, we can enable support step-by-step. The current impl is fine if we additionally check that the transfer_write has an identity map, I think.
The doc in transfer_write w.r.t. masking is confusing; this part seems obsolete and needs to be deleted. I'll send a pull request to fix it.
Uploaded #71490 for it.
Yeah, after further thought, n-D masked writes will be quite tricky to handle properly because the warp op does not properly capture the distribution mapping needed by the transfer_write. We would have to do something like: apply the transfer map, distribute, and reverse the map to get back to the original mask shape. Ideally those transposes would cancel, but step-by-step sounds good to me. I'll guard a bit better here, but a discussion about how to think about these masks would be great.
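A rough sketch of the bookkeeping that scheme implies, with made-up shapes (a write through a (d0, d1) -> (d1, d0) map, so the mask arrives in the memory-side shape; the unregistered "distribute_like_vector" op is only a placeholder for the distribution step, and nothing here is implemented by this patch):

// Hypothetical n-D scheme: map, distribute, un-map.
func.func @nd_mask_shape_walk(%mask: vector<8x2xi1>) -> vector<1x2xi1> {
  // 1. Apply the transfer map: bring the mask into the vector's layout.
  %m_vec = vector.transpose %mask, [1, 0] : vector<8x2xi1> to vector<2x8xi1>
  // 2. Distribute it identically to the written vector (placeholder op).
  %m_dist = "distribute_like_vector"(%m_vec)
      : (vector<2x8xi1>) -> vector<2x1xi1>
  // 3. Reverse the map to recover the memory-side mask for the new write;
  //    ideally this transpose cancels with the first one.
  %m_mem = vector.transpose %m_dist, [1, 0] : vector<2x1xi1> to vector<1x2xi1>
  return %m_mem : vector<1x2xi1>
}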
Basically this:
We mask the memory access in whatever shape it happens to have in memory.