Skip to content

Commit 25ec1fa

Browse files
authored
[mlir][vector] Add support for distributing masked writes (#71482)
General distribution of masked writes requires materializing the permutation on the vector of the write in IR to ensure the vector lines up with the mask. For now just support cases with trivial permutation maps.
1 parent 4832c88 commit 25ec1fa

File tree

2 files changed

+64
-9
lines changed

2 files changed

+64
-9
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -406,19 +406,29 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
406406
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
407407
WarpExecuteOnLane0Op warpOp,
408408
vector::TransferWriteOp writeOp,
409-
VectorType targetType) {
409+
VectorType targetType,
410+
VectorType maybeMaskType) {
410411
assert(writeOp->getParentOp() == warpOp &&
411412
"write must be nested immediately under warp");
412413
OpBuilder::InsertionGuard g(rewriter);
413414
SmallVector<size_t> newRetIndices;
414-
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
415-
rewriter, warpOp, ValueRange{{writeOp.getVector()}},
416-
TypeRange{targetType}, newRetIndices);
415+
WarpExecuteOnLane0Op newWarpOp;
416+
if (maybeMaskType) {
417+
newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
418+
rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
419+
TypeRange{targetType, maybeMaskType}, newRetIndices);
420+
} else {
421+
newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
422+
rewriter, warpOp, ValueRange{{writeOp.getVector()}},
423+
TypeRange{targetType}, newRetIndices);
424+
}
417425
rewriter.setInsertionPointAfter(newWarpOp);
418426
auto newWriteOp =
419427
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
420428
rewriter.eraseOp(writeOp);
421429
newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
430+
if (maybeMaskType)
431+
newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
422432
return newWriteOp;
423433
}
424434

@@ -489,10 +499,25 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
489499
if (!targetType)
490500
return failure();
491501

502+
// 2.5 Compute the distributed type for the new mask.
503+
VectorType maskType;
504+
if (writeOp.getMask()) {
505+
// TODO: Distribution of masked writes with non-trivial permutation maps
506+
// requires the distribution of the mask to elementwise match the
507+
// distribution of the permuted written vector. Currently the details
508+
// of which lane is responsible for which element are captured strictly
509+
// by shape information on the warp op, and thus requires materializing
510+
// the permutation in IR.
511+
if (!writeOp.getPermutationMap().isMinorIdentity())
512+
return failure();
513+
maskType =
514+
getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
515+
}
516+
492517
// 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
493518
// the rest.
494519
vector::TransferWriteOp newWriteOp =
495-
cloneWriteOp(rewriter, warpOp, writeOp, targetType);
520+
cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
496521

497522
// 4. Reindex the write using the distribution map.
498523
auto newWarpOp =
@@ -561,10 +586,6 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
561586

562587
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
563588
PatternRewriter &rewriter) const override {
564-
// Ops with mask not supported yet.
565-
if (writeOp.getMask())
566-
return failure();
567-
568589
auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
569590
if (!warpOp)
570591
return failure();
@@ -575,15 +596,21 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
575596
if (!isMemoryEffectFree(nextOp))
576597
return failure();
577598

599+
Value maybeMask = writeOp.getMask();
578600
if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
579601
return writeOp.getVector() == value ||
602+
(maybeMask && maybeMask == value) ||
580603
warpOp.isDefinedOutsideOfRegion(value);
581604
}))
582605
return failure();
583606

584607
if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
585608
return success();
586609

610+
// Masked writes not supported for extraction.
611+
if (writeOp.getMask())
612+
return failure();
613+
587614
if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
588615
return success();
589616

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,3 +1253,31 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
12531253
// CHECK-PROP-SAME: (%{{.+}}: index, %[[SRC:.+]]: memref<4096xf32>, %[[INDEX:.+]]: index)
12541254
// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %cst {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
12551255
// CHECK-PROP: return %[[READ]] : vector<1xf32>
1256+
1257+
// -----
1258+
1259+
func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
1260+
%c0 = arith.constant 0 : index
1261+
vector.warp_execute_on_lane_0(%laneid)[32] -> () {
1262+
%mask = "mask_def_0"() : () -> (vector<4096xi1>)
1263+
%mask2 = "mask_def_1"() : () -> (vector<32xi1>)
1264+
%0 = "some_def_0"() : () -> (vector<4096xf32>)
1265+
%1 = "some_def_1"() : () -> (vector<32xf32>)
1266+
vector.transfer_write %0, %dest[%c0], %mask : vector<4096xf32>, memref<4096xf32>
1267+
vector.transfer_write %1, %dest[%c0], %mask2 : vector<32xf32>, memref<4096xf32>
1268+
vector.yield
1269+
}
1270+
return
1271+
}
1272+
1273+
// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_masked_write(
1274+
// CHECK-DIST-AND-PROP: %[[W:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xi1>, vector<128xf32>, vector<128xi1>) {
1275+
// CHECK-DIST-AND-PROP: %[[M0:.*]] = "mask_def_0"
1276+
// CHECK-DIST-AND-PROP: %[[M1:.*]] = "mask_def_1"
1277+
// CHECK-DIST-AND-PROP: %[[V0:.*]] = "some_def_0"
1278+
// CHECK-DIST-AND-PROP: %[[V1:.*]] = "some_def_1"
1279+
// CHECK-DIST-AND-PROP: vector.yield %[[V1]], %[[M1]], %[[V0]], %[[M0]]
1280+
// CHECK-DIST-AND-PROP-SAME: vector<32xf32>, vector<32xi1>, vector<4096xf32>, vector<4096xi1>
1281+
// CHECK-DIST-AND-PROP: }
1282+
// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
1283+
// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>

0 commit comments

Comments
 (0)