From 9c30fbfcb6c51557b4574e3249932aa07f4a6143 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Sun, 5 Nov 2023 10:58:28 -0500 Subject: [PATCH] [mlir][vector] Add leading unit dim folding patterns for masked transfers This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. The unit dims are only relevant for masks created by `create_mask` and `constant_mask` if the mask size for the unit dim is not one (i.e. it is zero), in which case all subsequent sizes must also be zero. From the perspective of the vector transfers, however, these unit dims can just be dropped directly. --- .../Transforms/VectorDropLeadUnitDim.cpp | 69 ++++++++++++++++--- .../vector-dropleadunitdim-transforms.mlir | 35 ++++++++++ 2 files changed, 94 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 6bbb293fa2a6b..75f32b23e57b0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -208,9 +210,6 @@ struct CastAwayTransferReadLeadingOneDim if (read.getTransferRank() == 0) return failure(); - if (read.getMask()) - return failure(); - auto shapedType = cast(read.getSource().getType()); if (shapedType.getElementType() != read.getVectorType().getElementType()) return failure(); @@ -233,10 +232,18 @@ struct CastAwayTransferReadLeadingOneDim inBoundsAttr = rewriter.getArrayAttr( read.getInBoundsAttr().getValue().take_back(newType.getRank())); + Value mask = Value(); + if (read.getMask()) { + // The mask shape must always match the shape of the read vector, so we + // can safely use the same extraction indices. 
+ int64_t dropDim = oldType.getRank() - newType.getRank(); + mask = rewriter.create(read.getLoc(), read.getMask(), + splatZero(dropDim)); + } + auto newRead = rewriter.create( read.getLoc(), newType, read.getSource(), read.getIndices(), - AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(), - inBoundsAttr); + AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr); rewriter.replaceOpWithNewOp(read, oldType, newRead); return success(); @@ -256,9 +263,6 @@ struct CastAwayTransferWriteLeadingOneDim if (write.getTransferRank() == 0) return failure(); - if (write.getMask()) - return failure(); - auto shapedType = dyn_cast(write.getSource().getType()); if (shapedType.getElementType() != write.getVectorType().getElementType()) return failure(); @@ -283,10 +287,21 @@ struct CastAwayTransferWriteLeadingOneDim auto newVector = rewriter.create( write.getLoc(), write.getVector(), splatZero(dropDim)); + + if (write.getMask()) { + // The mask shape must always match the shape of the written vector, so we + // can safely use the same extraction indices. + auto newMask = rewriter.create( + write.getLoc(), write.getMask(), splatZero(dropDim)); + rewriter.replaceOpWithNewOp( + write, newVector, write.getSource(), write.getIndices(), + AffineMapAttr::get(newMap), newMask, inBoundsAttr); + return success(); + } + rewriter.replaceOpWithNewOp( write, newVector, write.getSource(), write.getIndices(), AffineMapAttr::get(newMap), inBoundsAttr); - return success(); } }; @@ -467,6 +482,40 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern { } }; +// Drops leading 1 dimensions from vector.constant_mask and inserts a +// vector.broadcast back to the original shape. 
+struct CastAwayConstantMaskLeadingOneDim + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, + PatternRewriter &rewriter) const override { + VectorType oldType = mask.getType(); + VectorType newType = trimLeadingOneDims(oldType); + + if (newType == oldType) + return failure(); + + int64_t dropDim = oldType.getRank() - newType.getRank(); + SmallVector dimSizes; + for (auto attr : mask.getMaskDimSizes()) + dimSizes.push_back(llvm::cast(attr).getInt()); + + // If any of the dropped unit dims has a size of `0`, the entire mask is a + // zero mask, else the unit dim has no effect on the mask. + int64_t flatLeadingSize = + std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1, + static_cast(1), std::multiplies()); + SmallVector newDimSizes({flatLeadingSize}); + newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); + + auto newMask = rewriter.create( + mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes)); + rewriter.replaceOpWithNewOp(mask, oldType, newMask); + return success(); + } +}; + } // namespace void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( @@ -474,7 +523,7 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( patterns .add(patterns.getContext(), benefit); populateShapeCastFoldingPatterns(patterns, benefit); diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir index e5b27b04dcc80..5de30206927db 100644 --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -209,6 +209,20 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) return %0: vector<1x4xf16> } +// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims +func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, 
%arg1: vector<1x4xi1>) -> vector<1x4xf16> { + // CHECK: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16 + %f0 = arith.constant 0. : f16 + // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> + // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16> + // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16> + %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16> + // CHECK: return %[[CAST]] + return %0: vector<1x4xf16> +} + // CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> { %c0 = arith.constant 0 : index @@ -229,6 +243,18 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16> return } +// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims +func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) { + // CHECK: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16> + // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> + // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16> + + vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16> + return +} + // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element func.func 
@cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) { %c0 = arith.constant 0 : index @@ -410,3 +436,12 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1> return %0: vector<1x1x8x1x[8]xi1> } + +// CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> { +// CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1> +// CHECK: return %[[BCAST]] : vector<1x1x8x2x1xi1> +func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> { + %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1> + return %0: vector<1x1x8x2x1xi1> +}