@@ -406,19 +406,29 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
406
406
// Clones `writeOp` to just after a freshly built warp op and returns the clone.
//
// NOTE(review): this block is a patch hunk. Lines prefixed `-` are the
// pre-patch code, lines prefixed `+` the post-patch code, and the bare numbers
// are old/new file line numbers.
//
// Post-patch behavior, as shown below:
//  * `warpOp`'s region is handed to moveRegionToNewWarpOpAndAppendReturns,
//    which presumably builds a new WarpExecuteOnLane0Op that additionally
//    yields the written vector (result type `targetType`) -- confirm against
//    the helper's definition.
//  * When `maybeMaskType` is non-null, the write's mask is passed as a second
//    value with result type `maybeMaskType`.
//  * The write is cloned after the new warp op, the original write is erased,
//    and the clone's vector (and mask, if any) operands are re-pointed to the
//    new warp op's results.
//
// Precondition (asserted): `writeOp` is nested immediately under `warpOp`.
static vector::TransferWriteOp cloneWriteOp (RewriterBase &rewriter,
407
407
WarpExecuteOnLane0Op warpOp,
408
408
vector::TransferWriteOp writeOp,
409
- VectorType targetType) {
409
+ VectorType targetType,
// A null `maybeMaskType` selects the unmasked path (see the `if` below).
410
+ VectorType maybeMaskType) {
410
411
assert (writeOp->getParentOp () == warpOp &&
411
412
" write must be nested immediately under warp" );
// Insertion guard: the insertion point changed below via
// setInsertionPointAfter is restored when `g` goes out of scope.
412
413
OpBuilder::InsertionGuard g (rewriter);
// Passed to the helper below; afterwards entries [0] and [1] are used as
// result indices on the new warp op (vector first, then mask).
413
414
SmallVector<size_t > newRetIndices;
414
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
415
- rewriter, warpOp, ValueRange{{writeOp.getVector ()}},
416
- TypeRange{targetType}, newRetIndices);
// Post-patch: pass the mask along with the vector only for masked writes.
415
+ WarpExecuteOnLane0Op newWarpOp;
416
+ if (maybeMaskType) {
417
+ newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
418
+ rewriter, warpOp, ValueRange{writeOp.getVector (), writeOp.getMask ()},
419
+ TypeRange{targetType, maybeMaskType}, newRetIndices);
420
+ } else {
421
+ newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
422
+ rewriter, warpOp, ValueRange{{writeOp.getVector ()}},
423
+ TypeRange{targetType}, newRetIndices);
424
+ }
// Clone the write after the new warp op, then erase the original write.
417
425
rewriter.setInsertionPointAfter (newWarpOp);
418
426
auto newWriteOp =
419
427
cast<vector::TransferWriteOp>(rewriter.clone (*writeOp.getOperation ()));
420
428
rewriter.eraseOp (writeOp);
// Re-point the clone's operands to the results of the new warp op.
421
429
newWriteOp.getVectorMutable ().assign (newWarpOp.getResult (newRetIndices[0 ]));
430
+ if (maybeMaskType)
431
+ newWriteOp.getMaskMutable ().assign (newWarpOp.getResult (newRetIndices[1 ]));
422
432
return newWriteOp;
423
433
}
424
434
@@ -489,10 +499,25 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
489
499
if (!targetType)
490
500
return failure ();
491
501
502
+ // 2.5 Compute the distributed type for the new mask.
503
+ VectorType maskType;
504
+ if (writeOp.getMask ()) {
505
+ // TODO: Distribution of masked writes with non-trivial permutation maps
506
+ // requires the distribution of the mask to elementwise match the
507
+ // distribution of the permuted written vector. Currently the details
508
+ // of which lane is responsible for which element is captured strictly
509
+ // by shape information on the warp op, and thus requires materializing
510
+ // the permutation in IR.
511
+ if (!writeOp.getPermutationMap ().isMinorIdentity ())
512
+ return failure ();
513
+ maskType =
514
+ getDistributedType (writeOp.getMaskType (), map, warpOp.getWarpSize ());
515
+ }
516
+
492
517
// 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
493
518
// the rest.
494
519
vector::TransferWriteOp newWriteOp =
495
- cloneWriteOp (rewriter, warpOp, writeOp, targetType);
520
+ cloneWriteOp (rewriter, warpOp, writeOp, targetType, maskType );
496
521
497
522
// 4. Reindex the write using the distribution map.
498
523
auto newWarpOp =
@@ -561,10 +586,6 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
561
586
562
587
LogicalResult matchAndRewrite (vector::TransferWriteOp writeOp,
563
588
PatternRewriter &rewriter) const override {
564
- // Ops with mask not supported yet.
565
- if (writeOp.getMask ())
566
- return failure ();
567
-
568
589
auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp ());
569
590
if (!warpOp)
570
591
return failure ();
@@ -575,15 +596,21 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
575
596
if (!isMemoryEffectFree (nextOp))
576
597
return failure ();
577
598
599
+ Value maybeMask = writeOp.getMask ();
578
600
if (!llvm::all_of (writeOp->getOperands (), [&](Value value) {
579
601
return writeOp.getVector () == value ||
602
+ (maybeMask && maybeMask == value) ||
580
603
warpOp.isDefinedOutsideOfRegion (value);
581
604
}))
582
605
return failure ();
583
606
584
607
if (succeeded (tryDistributeOp (rewriter, writeOp, warpOp)))
585
608
return success ();
586
609
610
+ // Masked writes not supported for extraction.
611
+ if (writeOp.getMask ())
612
+ return failure ();
613
+
587
614
if (succeeded (tryExtractOp (rewriter, writeOp, warpOp)))
588
615
return success ();
589
616
0 commit comments