TosaToLinalgNamed: add option to prefer HWCF kernel layout for Conv2D ops. #70482
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: None (bjacob)

Changes

Switching to FHWC happened in #68304 and is fine in itself, but it caused a downstream performance regression (iree-org/iree#15296 (comment)), so this PR makes the kernel layout optional.

Full diff: https://github.com/llvm/llvm-project/pull/70482.diff

6 Files Affected:
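Before the per-file diff, a quick usage sketch (not part of the patch): a downstream pipeline would opt into the new behavior roughly as follows, using the entry points this PR touches. It assumes the tablegen-generated TosaToLinalgNamedOptions struct exposes the option as a field named after its C++ variable, preferConv2DKernelLayoutHWCF.

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical downstream pipeline setup; a sketch against this PR's API,
// not code from the patch itself.
void buildTosaToLinalgPipeline(mlir::OpPassManager &pm) {
  mlir::tosa::TosaToLinalgOptions tosaToLinalgOptions;
  mlir::tosa::TosaToLinalgNamedOptions namedOptions;
  // Keep the pre-#68304 lowering: tosa.conv2d -> linalg.conv_2d_nhwc_hwcf
  // instead of linalg.conv_2d_nhwc_fhwc.
  namedOptions.preferConv2DKernelLayoutHWCF = true;
  // Validation options keep their defaults (see the header change below).
  mlir::tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, namedOptions);
}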
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f05e5a8ae667dab..336f0d3af951b9a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1126,6 +1126,12 @@ def TosaToLinalgNamed
Linalg named operations.
}];
+ let options = [
+ Option<"preferConv2DKernelLayoutHWCF", "prefer-conv2d-kernel-layout-hwcf",
+ "bool", /*default=*/"false",
+ "Prefer generating linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc">
+ ];
+
let constructor = "tosa::createTosaToLinalgNamed()";
}
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index b4c4eb8651a6f00..7497a716e048d95 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -26,7 +26,8 @@ namespace mlir {
namespace tosa {
std::unique_ptr<Pass> createTosaToLinalg();
-std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgNamed(
+ const TosaToLinalgNamedOptions &options = TosaToLinalgNamedOptions());
/// Populates passes to convert from TOSA to Linalg on buffers. At the end of
/// the pass, the function will only contain linalg ops or standard ops if the
@@ -34,6 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
/// benchmarking performance improvements from the canonicalizations.
void addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
+ const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions =
+ TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
tosa::TosaValidationOptions const &validationOptions = {
tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
@@ -45,8 +48,12 @@ void registerTosaToLinalgPipelines();
/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
+enum class Conv2DKernelLayout { FHWC, HWCF };
+
/// Populates conversion passes from TOSA dialect to Linalg named operations.
-void populateTosaToLinalgNamedConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgNamedConversionPatterns(
+ RewritePatternSet *patterns,
+ Conv2DKernelLayout conv2DKernelLayout = Conv2DKernelLayout::FHWC);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index ee8f52deadbd152..ae0b58acfd295b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <numeric>
+#include <type_traits>
using namespace mlir;
using namespace mlir::tosa;
@@ -248,6 +249,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
pad.resize(pad.size() + 2, 0);
input = applyPad(loc, input, pad, zeroAttr, rewriter);
+ if (4 == inputTy.getRank()) {
+ // For 2D convolutions, we need to check if the target convolution op
+ // wants a HWCF kernel layout.
+ bool wantHwcf =
+ isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+ if (wantHwcf) {
+ // Transpose the kernel to match dimension ordering of the linalg
+ // convolution operation.
+ // TODO(suderman): See if this can be efficiently folded - check whether
+ // the input is used anywhere else, if not fold the constant.
+ SmallVector<int64_t> weightPerm;
+ for (int i = 1; i < resultTy.getRank(); i++)
+ weightPerm.push_back(i);
+ weightPerm.push_back(0);
+
+ SmallVector<int64_t> newWeightShape;
+ for (auto dim : weightPerm)
+ newWeightShape.push_back(weightShape[dim]);
+ auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+ Value weightPermValue =
+ rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ Type newWeightTy =
+ RankedTensorType::get(newWeightShape, weightTy.getElementType());
+ weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+ weightPermValue);
+ }
+ }
+
// For Conv3D transpose the kernel to match dimension ordering of the linalg
// convolution operation. Conv2D has a 1-1 mapping in linalg so better to
// map directly and then transpose later if desired.
@@ -977,10 +1007,20 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
- RewritePatternSet *patterns) {
+ RewritePatternSet *patterns, Conv2DKernelLayout conv2DKernelLayout) {
+ if (conv2DKernelLayout == Conv2DKernelLayout::FHWC) {
+ patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
+ linalg::Conv2DNhwcFhwcQOp>>(
+ patterns->getContext());
+ } else if (conv2DKernelLayout == Conv2DKernelLayout::HWCF) {
+ patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
+ linalg::Conv2DNhwcHwcfQOp>>(
+ patterns->getContext());
+ } else {
+ assert(false);
+ }
patterns->add<
// clang-format off
- ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
DepthwiseConvConverter,
MatMulConverter,
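For direct users of the pattern API, the new entry point shown above takes the layout enum; a minimal sketch, assuming an existing MLIRContext and using names from this PR's header change:

// Sketch: populate the named-conversion patterns with an explicit Conv2D
// kernel layout (hypothetical snippet, not code from the patch).
mlir::RewritePatternSet patterns(&context);
mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    &patterns, mlir::tosa::Conv2DKernelLayout::HWCF);
// With HWCF selected, ConvConverter inserts a tosa.transpose with
// permutation [1, 2, 3, 0], e.g. turning an FHWC kernel of type
// tensor<28x3x3x27xf32> into an HWCF kernel of type tensor<3x3x27x28xf32>,
// as the updated tests below check.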
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 4c941a109ed845e..e330c9cff141e40 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -37,6 +37,9 @@ namespace {
struct TosaToLinalgNamed
: public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
public:
+ TosaToLinalgNamed(const TosaToLinalgNamedOptions &options)
+ : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {}
+
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
@@ -61,13 +64,18 @@ struct TosaToLinalgNamed
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
FunctionOpInterface func = getOperation();
- mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
+ tosa::Conv2DKernelLayout conv2DKernelLayout =
+ preferConv2DKernelLayoutHWCF ? tosa::Conv2DKernelLayout::HWCF
+ : tosa::Conv2DKernelLayout::FHWC;
+ tosa::populateTosaToLinalgNamedConversionPatterns(&patterns,
+ conv2DKernelLayout);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
-std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() {
- return std::make_unique<TosaToLinalgNamed>();
+std::unique_ptr<Pass>
+mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) {
+ return std::make_unique<TosaToLinalgNamed>(options);
}
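The standalone pass can likewise be constructed with options when nesting it manually rather than through addTosaToLinalgPasses; a sketch, again assuming the generated options struct field is named after the tablegen C++ variable:

// Sketch: create the standalone pass with the HWCF preference enabled
// (hypothetical snippet, not code from the patch); pm is an OpPassManager.
mlir::tosa::TosaToLinalgNamedOptions options;
options.preferConv2DKernelLayoutHWCF = true;
pm.addNestedPass<mlir::func::FuncOp>(
    mlir::tosa::createTosaToLinalgNamed(options));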
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index a486e28c50c7129..687477810030d4c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,6 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
void mlir::tosa::addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
+ const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
tosa::TosaValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
if (!options.disableTosaDecompositions)
@@ -84,7 +85,8 @@ void mlir::tosa::addTosaToLinalgPasses(
pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
- pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
+ pm.addNestedPass<func::FuncOp>(
+ tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// TODO: Remove pass that operates on const tensor and enable optionality
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
@@ -106,7 +108,9 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
"named operations.",
[](OpPassManager &pm) {
TosaToLinalgOptions tosaToLinalgOptions;
+ TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
+ tosaToLinalgNamedOptions,
/* validationOptions = */
{tosa::TosaProfileEnum::BaseInference,
/* StrictOperationSpecAlignment = */ true,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index b601bfb28a4f280..1cf7c8dee606899 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -363,11 +364,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+ // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
// CHECK: arith.extsi
// CHECK: arith.addi
@@ -383,11 +387,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+ // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
// CHECK: arith.addf
// CHECK: linalg.yield
Have one small comment. Thanks!
Looks good to me.