
Commit fe638b0

[Codegen][CPU] Eliminate all-true vector masks after vectorization (#18190)
This enables an upstream transform that eliminates all-true `vector.create_mask` ops. This is particularly beneficial for scalable vectors, which use dynamic tensor types and therefore produce masks that would otherwise not fold away until much later, blocking some optimizations.

Depends on llvm/llvm-project#99314.

Signed-off-by: Benjamin Maxwell <[email protected]>
1 parent c71fe1a commit fe638b0
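
To make the effect concrete, here is a minimal hand-written sketch (not taken from this commit; the function name and values are hypothetical) of the pattern the transform targets: the mask length is exactly 4 × vscale, the full size of a `vector<[4]xi1>`, so the mask is provably all-true.

```mlir
// Hypothetical example (not from this commit): the mask covers the whole
// scalable vector, so every lane is enabled regardless of the runtime vscale.
func.func @all_true_mask_sketch(%src: tensor<?xf32>) -> vector<[4]xf32> {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %pad = arith.constant 0.0 : f32
  %vscale = vector.vscale
  // The mask size is exactly 4 * vscale, i.e. the number of lanes in
  // vector<[4]xf32>, so vector.create_mask produces an all-true mask.
  %size = arith.muli %vscale, %c4 : index
  %mask = vector.create_mask %size : vector<[4]xi1>
  %v = vector.transfer_read %src[%c0], %pad, %mask : tensor<?xf32>, vector<[4]xf32>
  return %v : vector<[4]xf32>
}

// After the mask elimination added here (and the mask canonicalizations that
// already run later in the pass), the expectation is roughly an unmasked read:
//   %v = vector.transfer_read %src[%c0], %pad {in_bounds = [true]}
//          : tensor<?xf32>, vector<[4]xf32>
```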

File tree: 2 files changed, +77 -0 lines


compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp

Lines changed: 19 additions & 0 deletions
```diff
@@ -325,6 +325,14 @@ class GenericVectorizationPass final
   void runOnOperation() override;
 };
 
+/// Converts from iree_compiler::VscaleRange to vector::VscaleRange.
+static std::optional<vector::VscaleRange>
+toVectorVscaleRange(std::optional<iree_compiler::VscaleRange> vscaleRange) {
+  if (!vscaleRange.has_value())
+    return std::nullopt;
+  return vector::VscaleRange{vscaleRange->min, vscaleRange->max};
+}
+
 void GenericVectorizationPass::runOnOperation() {
   MLIRContext *context = &getContext();
   auto funcOp = getOperation();
@@ -377,6 +385,17 @@ void GenericVectorizationPass::runOnOperation() {
                           vectorizeGatherAccesses);
   };
 
+  {
+    // Eliminate (all-true) vector masks as early as possible (to avoid missing
+    // optimizations/folds). This is particularly beneficial for scalable
+    // vectors that use dynamic tensor shapes.
+    auto targetAttr =
+        iree_compiler::IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
+    auto vscaleRange = iree_compiler::getDefaultVscaleRange(targetAttr);
+    vector::eliminateVectorMasks(rewriter, funcOp,
+                                 toVectorVscaleRange(vscaleRange));
+  }
+
   {
     // Canonicalize mask related ops before we lower them.
     RewritePatternSet maskCanonPatterns(funcOp.getContext());
```
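
The `toVectorVscaleRange` helper simply adapts IREE's `VscaleRange` (looked up from the executable target via `getDefaultVscaleRange`) to the type that `vector::eliminateVectorMasks` expects. An upper bound on vscale is the kind of information needed to prove masks all-true when their size is a plain constant rather than a multiple of `vector.vscale`; a hypothetical sketch (the bound of 16 is an assumed example, not something stated in this patch):

```mlir
// Hypothetical fragment: assuming the target guarantees vscale <= 16, a
// vector<[4]xi1> has at most 4 * 16 = 64 lanes, so a constant mask size of 64
// already covers every lane and this mask is provably all-true.
func.func @constant_size_all_true_mask() -> vector<[4]xi1> {
  %c64 = arith.constant 64 : index
  %mask = vector.create_mask %c64 : vector<[4]xi1>
  return %mask : vector<[4]xi1>
}
```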

compiler/src/iree/compiler/Codegen/Common/test/generic_vectorization.mlir

Lines changed: 58 additions & 0 deletions
```diff
@@ -445,3 +445,61 @@ func.func @dynamic_fill_with_scalable_tiling_infer_remainder_vector_size(%arg0:
 // CHECK-MASK: scf.for
 // CHECK-MASK: scf.for
 // CHECK-MASK: vector.transfer_write %[[CST]], {{.*}} {in_bounds = [true, true, true, true]} : vector<1x1x4x[4]xf32>, tensor<1x1x4x?xf32>
+
+// -----
+
+#aarch64_sve = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu_features = "+sve", target_triple = "aarch64-none-elf"}>
+#config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0, 0], [1, 4, [4], 0], [0, 0, 0, 3], [0, 0, 0, 0]]>
+#map = affine_map<()[s0] -> (-(96 mod s0) + 96)>
+#map1 = affine_map<(d0) -> (d0 * 2)>
+
+func.func @depthwise_conv_fold_away_masking(%arg0: tensor<1x68x120x96xf32>, %arg1: tensor<1x137x241x96xf32>, %arg2: tensor<3x3x96xf32>) -> tensor<1x68x120x96xf32>
+  attributes {hal.executable.target = #aarch64_sve}
+{
+  %c3 = arith.constant 3 : index
+  %c120 = arith.constant 120 : index
+  %c68 = arith.constant 68 : index
+  %c4 = arith.constant 4 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  %0 = scf.for %arg3 = %c0 to %c68 step %c1 iter_args(%arg4 = %arg0) -> (tensor<1x68x120x96xf32>) {
+    %1 = scf.for %arg5 = %c0 to %c120 step %c4 iter_args(%arg6 = %arg4) -> (tensor<1x68x120x96xf32>) {
+      %2 = affine.apply #map()[%c4_vscale]
+      %3 = scf.for %arg7 = %c0 to %2 step %c4_vscale iter_args(%arg8 = %arg6) -> (tensor<1x68x120x96xf32>) {
+        %4 = affine.apply #map1(%arg3)
+        %5 = affine.apply #map1(%arg5)
+        %extracted_slice = tensor.extract_slice %arg1[0, %4, %5, %arg7] [1, 3, 9, %c4_vscale] [1, 1, 1, 1] : tensor<1x137x241x96xf32> to tensor<1x3x9x?xf32>
+        %extracted_slice_0 = tensor.extract_slice %arg2[0, 0, %arg7] [3, 3, %c4_vscale] [1, 1, 1] : tensor<3x3x96xf32> to tensor<3x3x?xf32>
+        %extracted_slice_1 = tensor.extract_slice %arg8[0, %arg3, %arg5, %arg7] [1, 1, 4, %c4_vscale] [1, 1, 1, 1] : tensor<1x68x120x96xf32> to tensor<1x1x4x?xf32>
+        %6 = linalg.fill ins(%cst : f32) outs(%extracted_slice_1 : tensor<1x1x4x?xf32>) -> tensor<1x1x4x?xf32>
+        %7 = scf.for %arg9 = %c0 to %c3 step %c1 iter_args(%arg10 = %6) -> (tensor<1x1x4x?xf32>) {
+          %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, %arg9, 0, 0] [1, 1, 9, %c4_vscale] [1, 1, 1, 1] : tensor<1x3x9x?xf32> to tensor<1x1x9x?xf32>
+          %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[%arg9, 0, 0] [1, 3, %c4_vscale] [1, 1, 1] : tensor<3x3x?xf32> to tensor<1x3x?xf32>
+          %extracted_slice_4 = tensor.extract_slice %arg10[0, 0, 0, 0] [1, 1, 4, %c4_vscale] [1, 1, 1, 1] : tensor<1x1x4x?xf32> to tensor<1x1x4x?xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[0, 0, 0, 0] [1, 1, 9, %c4_vscale] [1, 1, 1, 1] : tensor<1x1x9x?xf32> to tensor<1x9x?xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, 0, 0] [1, 3, %c4_vscale] [1, 1, 1] : tensor<1x3x?xf32> to tensor<3x?xf32>
+          %extracted_slice_7 = tensor.extract_slice %extracted_slice_4[0, 0, 0, 0] [1, 1, 4, %c4_vscale] [1, 1, 1, 1] : tensor<1x1x4x?xf32> to tensor<1x4x?xf32>
+          %8 = linalg.depthwise_conv_1d_nwc_wc {dilations = dense<1> : vector<1xi64>, lowering_config = #config, strides = dense<2> : vector<1xi64>} ins(%extracted_slice_5, %extracted_slice_6 : tensor<1x9x?xf32>, tensor<3x?xf32>) outs(%extracted_slice_7 : tensor<1x4x?xf32>) -> tensor<1x4x?xf32>
+          %inserted_slice_8 = tensor.insert_slice %8 into %extracted_slice_4[0, 0, 0, 0] [1, 1, 4, %c4_vscale] [1, 1, 1, 1] : tensor<1x4x?xf32> into tensor<1x1x4x?xf32>
+          %inserted_slice_9 = tensor.insert_slice %inserted_slice_8 into %arg10[0, 0, 0, 0] [1, 1, 4, %c4_vscale] [1, 1, 1, 1] : tensor<1x1x4x?xf32> into tensor<1x1x4x?xf32>
+          scf.yield %inserted_slice_9 : tensor<1x1x4x?xf32>
+        }
+        %inserted_slice = tensor.insert_slice %7 into %arg8[0, %arg3, %arg5, %arg7] [1, 1, 4, %c4_vscale] [1, 1, 1, 1] : tensor<1x1x4x?xf32> into tensor<1x68x120x96xf32>
+        scf.yield %inserted_slice : tensor<1x68x120x96xf32>
+      }
+      scf.yield %3 : tensor<1x68x120x96xf32>
+    }
+    scf.yield %1 : tensor<1x68x120x96xf32>
+  }
+  return %0 : tensor<1x68x120x96xf32>
+}
+
+/// This checks that the masks (introduced by the vectorizer) are eliminated by
+/// the end of the iree-codegen-generic-vectorization pass.
+
+// CHECK-MASK-LABEL: func.func @depthwise_conv_fold_away_masking
+// CHECK-MASK-NOT: vector.create_mask
+// CHECK-MASK-NOT: vector.constant_mask
```
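
In this test the masks introduced by the vectorizer fold away because every scalable tile is full: `#map()[%c4_vscale]` evaluates to `96 - (96 mod %c4_vscale)`, the largest multiple of `%c4_vscale` that is at most 96, and the innermost `scf.for` steps by `%c4_vscale` up to exactly that bound. Each extracted slice therefore has exactly `%c4_vscale` elements in the scalable dimension, so the masks generated for it are all-true, and the new elimination step (together with the existing mask canonicalizations) removes them, which is what the `CHECK-MASK-NOT` lines verify.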
