[mlir][vector] Add leading unit dim folding patterns for masked transfers #71466

Merged: 1 commit, Nov 7, 2023
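In essence, this change extends the existing leading-unit-dim folding patterns to masked vector.transfer_read / vector.transfer_write and adds one for vector.constant_mask. A rough before/after sketch of the masked read case, mirroring the tests added below (%mem, %mask, and %pad are placeholder names, not identifiers from the patch):

// Before: both the result vector and the mask carry a leading unit dim.
%v = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>

// After: the mask is sliced along with the vector, and the narrowed result is
// broadcast back to the original shape.
%m = vector.extract %mask[0] : vector<4xi1> from vector<1x4xi1>
%r = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %pad, %m {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
%v = vector.broadcast %r : vector<4xf16> to vector<1x4xf16>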
69 changes: 59 additions & 10 deletions mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -208,9 +210,6 @@ struct CastAwayTransferReadLeadingOneDim
if (read.getTransferRank() == 0)
return failure();

if (read.getMask())
return failure();

auto shapedType = cast<ShapedType>(read.getSource().getType());
if (shapedType.getElementType() != read.getVectorType().getElementType())
return failure();
@@ -233,10 +232,18 @@
inBoundsAttr = rewriter.getArrayAttr(
read.getInBoundsAttr().getValue().take_back(newType.getRank()));

Value mask = Value();
if (read.getMask()) {
// The mask shape must always match the shape of the vector being read, so we
// can safely use the same extraction indices.
int64_t dropDim = oldType.getRank() - newType.getRank();
mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
splatZero(dropDim));
}

auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), newType, read.getSource(), read.getIndices(),
AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
inBoundsAttr);
AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);

return success();
@@ -256,9 +263,6 @@ struct CastAwayTransferWriteLeadingOneDim
if (write.getTransferRank() == 0)
return failure();

if (write.getMask())
return failure();

auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
if (shapedType.getElementType() != write.getVectorType().getElementType())
return failure();
@@ -283,10 +287,21 @@

auto newVector = rewriter.create<vector::ExtractOp>(
write.getLoc(), write.getVector(), splatZero(dropDim));

if (write.getMask()) {
// The mask shape must always match the shape of the written vector, so we
// can safely use the same extraction indices.
auto newMask = rewriter.create<vector::ExtractOp>(
write.getLoc(), write.getMask(), splatZero(dropDim));
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), newMask, inBoundsAttr);
return success();
}

rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), inBoundsAttr);

return success();
}
};
@@ -467,14 +482,48 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
}
};

// Drops leading 1 dimensions from vector.constant_mask and inserts a
// vector.broadcast back to the original shape.
struct CastAwayConstantMaskLeadingOneDim
: public OpRewritePattern<vector::ConstantMaskOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
PatternRewriter &rewriter) const override {
VectorType oldType = mask.getType();
VectorType newType = trimLeadingOneDims(oldType);

if (newType == oldType)
return failure();

int64_t dropDim = oldType.getRank() - newType.getRank();
SmallVector<int64_t> dimSizes;
for (auto attr : mask.getMaskDimSizes())
dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());

// If any of the dropped unit dims has a size of `0`, the entire mask is a
// zero mask, else the unit dim has no effect on the mask.
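// For example (mirroring the new constant_mask test below), the sizes
// [1, 1, 6, 1, 1] on vector<1x1x8x2x1xi1> drop two leading unit dims: the
// flattened leading size is 1 * 1 * 6 = 6, yielding [6, 1, 1] on
// vector<8x2x1xi1>. A 0 in any dropped dim makes the product 0, i.e. an
// all-false mask.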
int64_t flatLeadingSize =
std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
static_cast<int64_t>(1), std::multiplies<int64_t>());
SmallVector<int64_t> newDimSizes({flatLeadingSize});
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

auto newMask = rewriter.create<vector::ConstantMaskOp>(
mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
return success();
}
};

} // namespace

void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
.add<CastAwayExtractStridedSliceLeadingOneDim,
CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
CastAwayTransferReadLeadingOneDim,
CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
populateShapeCastFoldingPatterns(patterns, benefit);
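The new pattern is registered behind the existing populateCastAwayVectorLeadingOneDimPatterns entry point, so callers pick it up without any API change. A minimal sketch of applying these patterns with the greedy driver (the wrapper function runCastAwayLeadingOneDim is hypothetical; the populate and driver calls are existing MLIR APIs):

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: collect the cast-away patterns and apply them greedily.
static void runCastAwayLeadingOneDim(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
  // Best-effort application; the convergence result is ignored in this sketch.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}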
35 changes: 35 additions & 0 deletions mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -209,6 +209,20 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>)
return %0: vector<1x4xf16>
}

// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
%f0 = arith.constant 0. : f16
// CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
// CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
// CHECK: return %[[CAST]]
return %0: vector<1x4xf16>
}

// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
%c0 = arith.constant 0 : index
@@ -229,6 +243,18 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>
return
}

// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
// CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>

vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
return
}

// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
%c0 = arith.constant 0 : index
@@ -410,3 +436,12 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
return %0: vector<1x1x8x1x[8]xi1>
}

// CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
// CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
// CHECK: return %[[BCAST]] : vector<1x1x8x2x1xi1>
func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
%0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
return %0: vector<1x1x8x2x1xi1>
}