Skip to content

TosaToLinalgNamed: add option to prefer HWCF kernel layout for Conv2D ops. #70482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 27, 2023

Conversation

bjacob
Copy link
Contributor

@bjacob bjacob commented Oct 27, 2023

Switching to FHWC happened in #68304 and is fine in itself but caused downstream performance regression iree-org/iree#15296 (comment) , so this PR makes this optional.

@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (bjacob)

Changes

Switching to FHWC happened in #68304 and is fine in itself but caused downstream performance regression iree-org/iree#15296 (comment) , so this PR makes this optional.


Full diff: https://github.com/llvm/llvm-project/pull/70482.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+6)
  • (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+9-2)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+42-2)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (+11-3)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (+5-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+7)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f05e5a8ae667dab..336f0d3af951b9a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1126,6 +1126,12 @@ def TosaToLinalgNamed
     Linalg named operations.
   }];
 
+  let options = [
+      Option<"preferConv2DKernelLayoutHWCF", "prefer-conv2d-kernel-layout-hwcf",
+           "bool", /*default=*/"false",
+           "Prefer generating linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc">
+  ];
+
   let constructor = "tosa::createTosaToLinalgNamed()";
 }
 
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index b4c4eb8651a6f00..7497a716e048d95 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -26,7 +26,8 @@ namespace mlir {
 namespace tosa {
 
 std::unique_ptr<Pass> createTosaToLinalg();
-std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgNamed(
+    const TosaToLinalgNamedOptions &options = TosaToLinalgNamedOptions());
 
 /// Populates passes to convert from TOSA to Linalg on buffers. At the end of
 /// the pass, the function will only contain linalg ops or standard ops if the
@@ -34,6 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
 /// benchmarking performance improvements from the canonicalizations.
 void addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
+    const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions =
+        TosaToLinalgNamedOptions(),
     // Note: Default to 'none' level unless otherwise specified.
     tosa::TosaValidationOptions const &validationOptions = {
         tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
@@ -45,8 +48,12 @@ void registerTosaToLinalgPipelines();
 /// Populates conversion passes from TOSA dialect to Linalg dialect.
 void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
 
+enum class Conv2DKernelLayout { FHWC, HWCF };
+
 /// Populates conversion passes from TOSA dialect to Linalg named operations.
-void populateTosaToLinalgNamedConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgNamedConversionPatterns(
+    RewritePatternSet *patterns,
+    Conv2DKernelLayout conv2DKernelLayout = Conv2DKernelLayout::FHWC);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index ee8f52deadbd152..ae0b58acfd295b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -248,6 +249,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
+    if (4 == inputTy.getRank()) {
+      // For 2D convolutions, we need to check if the target convolution op
+      // wants a HWCF kernel layout.
+      bool wantHwcf =
+          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+      if (wantHwcf) {
+        // Transpose the kernel to match dimension ordering of the linalg
+        // convolution operation.
+        // TODO(suderman): See if this can be efficiently folded - check whether
+        // the input is used anywhere else, if not fold the constant.
+        SmallVector<int64_t> weightPerm;
+        for (int i = 1; i < resultTy.getRank(); i++)
+          weightPerm.push_back(i);
+        weightPerm.push_back(0);
+
+        SmallVector<int64_t> newWeightShape;
+        for (auto dim : weightPerm)
+          newWeightShape.push_back(weightShape[dim]);
+        auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+        Value weightPermValue =
+            rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+        Type newWeightTy =
+            RankedTensorType::get(newWeightShape, weightTy.getElementType());
+        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                    weightPermValue);
+      }
+    }
+
     // For Conv3D transpose the kernel to match dimension ordering of the linalg
     // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
     // map directly and then transpose later if desired.
@@ -977,10 +1007,20 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
-    RewritePatternSet *patterns) {
+    RewritePatternSet *patterns, Conv2DKernelLayout conv2DKernelLayout) {
+  if (conv2DKernelLayout == Conv2DKernelLayout::FHWC) {
+    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
+                                linalg::Conv2DNhwcFhwcQOp>>(
+        patterns->getContext());
+  } else if (conv2DKernelLayout == Conv2DKernelLayout::HWCF) {
+    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
+                                linalg::Conv2DNhwcHwcfQOp>>(
+        patterns->getContext());
+  } else {
+    assert(false);
+  }
   patterns->add<
       // clang-format off
-      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 4c941a109ed845e..e330c9cff141e40 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -37,6 +37,9 @@ namespace {
 struct TosaToLinalgNamed
     : public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
 public:
+  TosaToLinalgNamed(const TosaToLinalgNamedOptions &options)
+      : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {}
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry
         .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
@@ -61,13 +64,18 @@ struct TosaToLinalgNamed
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
     FunctionOpInterface func = getOperation();
-    mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
+    tosa::Conv2DKernelLayout conv2DKernelLayout =
+        preferConv2DKernelLayoutHWCF ? tosa::Conv2DKernelLayout::HWCF
+                                     : tosa::Conv2DKernelLayout::FHWC;
+    tosa::populateTosaToLinalgNamedConversionPatterns(&patterns,
+                                                      conv2DKernelLayout);
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();
   }
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() {
-  return std::make_unique<TosaToLinalgNamed>();
+std::unique_ptr<Pass>
+mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) {
+  return std::make_unique<TosaToLinalgNamed>(options);
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index a486e28c50c7129..687477810030d4c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,6 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 
 void mlir::tosa::addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
+    const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
     tosa::TosaValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
   if (!options.disableTosaDecompositions)
@@ -84,7 +85,8 @@ void mlir::tosa::addTosaToLinalgPasses(
 
   pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
+  pm.addNestedPass<func::FuncOp>(
+      tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
@@ -106,7 +108,9 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
       "named operations.",
       [](OpPassManager &pm) {
         TosaToLinalgOptions tosaToLinalgOptions;
+        TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
+                                    tosaToLinalgNamedOptions,
                                     /* validationOptions = */
                                     {tosa::TosaProfileEnum::BaseInference,
                                      /* StrictOperationSpecAlignment = */ true,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index b601bfb28a4f280..1cf7c8dee606899 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -363,11 +364,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK:   arith.extsi
   // CHECK:   arith.addi
@@ -383,11 +387,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK:   arith.addf
   // CHECK:   linalg.yield

@bjacob
Copy link
Contributor Author

bjacob commented Oct 27, 2023

@FranklandJack

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have one small comment. Thanks!

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me.

@bjacob bjacob merged commit acc6f3e into llvm:main Oct 27, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants