[mlir][linalg] Add TransposeConv2D Transform Op #68567
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Full diff: https://github.com/llvm/llvm-project/pull/68567.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 5f46affe592a2da..96c809f10323922 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -65,6 +65,10 @@ std::unique_ptr<Pass> createLinalgGeneralizationPass();
/// work on primitive types, if possible.
std::unique_ptr<Pass> createLinalgDetensorizePass();
+/// Create a pass to convert linalg.conv_2d_nhwc_fhwc(_q) to
+/// linalg.conv_2d_nhwc_hwcf(_q).
+std::unique_ptr<Pass> createLinalgTransposeConv2DPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 3093604af63e338..74cbe0c354f9018 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -145,4 +145,10 @@ def LinalgDetensorize : InterfacePass<"linalg-detensorize", "FunctionOpInterface
];
}
+def LinalgTransposeConv2D : Pass<"linalg-transpose-conv2d-ops"> {
+ let summary = "Convert conv_2d_nhwc_fhwc to conv_2d_nhwc_hwcf by transposing the weights.";
+ let constructor = "mlir::createLinalgTransposeConv2DPass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4e094609afa6a03..823b7bfd9810804 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Tiling.cpp
TilingInterfaceImpl.cpp
Transforms.cpp
+ TransposeConv2D.cpp
Vectorization.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
new file mode 100644
index 000000000000000..a8dee1126031601
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
@@ -0,0 +1,116 @@
+//===- TransposeConv2D.cpp - Convolution transposition -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSECONV2D
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+// Convolution converter that rewrites linalg.conv_2d_nhwc_fhwc and
+// linalg.conv_2d_nhwc_fhwc_q into linalg.transpose + linalg.conv_2d_nhwc_hwcf
+// and linalg.transpose + linalg.conv_2d_nhwc_hwcf_q respectively.
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+public:
+ using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(FHWCConvOp op,
+ PatternRewriter &rewriter) const final {
+ // Transpose the weights.
+ //
+ // To do this we first need to construct a permutation of the weight tensor
+ // dimensions. For a 2D convolution this is statically known to be [1, 2,
+ // 3, 0]; however, we construct the vector dynamically to future-proof this
+ // logic so it can be extended to convolutions of higher dimensions.
+ auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+ auto weightPerm = SmallVector<int64_t>(resultTy.getRank() - 1);
+ std::iota(std::begin(weightPerm), std::end(weightPerm), 1);
+ weightPerm.push_back(0);
+
+ // Create the type for the transposed weight tensor since this will be
+ // different from the original weight type.
+ auto weight = op->getOperand(1);
+ auto weightTy = cast<ShapedType>(weight.getType());
+ auto newWeightShape = SmallVector<int64_t>(weightPerm.size());
+ std::generate(std::begin(newWeightShape), std::end(newWeightShape),
+ [dim = 0, &weightTy, &weightPerm]() mutable {
+ return weightTy.getShape()[weightPerm[dim++]];
+ });
+ auto newWeightTy =
+ RankedTensorType::get(newWeightShape, weightTy.getElementType());
+
+ // Because linalg.transpose expects an "out" parameter, we need to pass it a
+ // tensor of zeros of the result type, so here we construct that tensor.
+ auto resultETy = resultTy.getElementType();
+ auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+ auto loc = op->getLoc();
+ auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+ loc, newWeightTy.getShape(), resultETy);
+ auto zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
+ auto zeroTensor = rewriter
+ .create<linalg::FillOp>(loc, ValueRange{zero},
+ ValueRange{emptyTensor})
+ .result();
+
+ // We can then construct the transposition on our weights.
+ weight =
+ rewriter
+ .create<linalg::TransposeOp>(loc, weight, zeroTensor, weightPerm)
+ .getResult()[0];
+
+ // Create the convolution.
+ //
+ // The weights are always the second input argument.
+ auto newInputs = SmallVector<Value>{op.getInputs()};
+ newInputs[1] = weight;
+ rewriter.template replaceOpWithNewOp<HWCFConvOp>(
+ op, resultTy, newInputs, op.getOutputs(), op.getStrides(),
+ op.getDilations());
+ return success();
+ }
+};
+
+// This pass converts NHWC Conv2D operations with FHWC channel orderings to NHWC
+// Conv2D operations with HWCF channel orderings.
+struct LinalgTransposeConv2D
+ : public impl::LinalgTransposeConv2DBase<LinalgTransposeConv2D> {
+public:
+ void runOnOperation() override {
+ auto *ctx = getOperation()->getContext();
+ auto patternSet = RewritePatternSet{ctx};
+ patternSet.add<
+ ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
+ ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
+ ctx);
+
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patternSet))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createLinalgTransposeConv2DPass() {
+ return std::make_unique<LinalgTransposeConv2D>();
+}
diff --git a/mlir/test/Dialect/Linalg/transpose-conv2d.mlir b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
new file mode 100644
index 000000000000000..22019029a02743d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(linalg-transpose-conv2d-ops))' | FileCheck %s
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[WEIGHTS:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+func.func @conv_2d_nhwc_fhwc(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+ // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+ // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[WEIGHTS]] : tensor<8x2x2x6xf32>) outs(%[[FILL]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+ outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+ // CHECK: return %[[CONV]] : tensor<1x2x2x8xf32>
+ return %0 : tensor<1x2x2x8xf32>
+}
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_q
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[WEIGHTS:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>, %[[A:.+]]: i32, %[[B:.+]]: i32) -> tensor<1x2x2x8xf32> {
+ func.func @conv_2d_nhwc_fhwc_q(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>, %a: i32, %b: i32) -> tensor<1x2x2x8xf32> {
+ // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+ // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[WEIGHTS]] : tensor<8x2x2x6xf32>) outs(%[[FILL]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]], %[[A]], %[[B]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>, i32, i32) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+ %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter, %a, %b: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>, i32, i32)
+ outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+ // CHECK: return %[[CONV]] : tensor<1x2x2x8xf32>
+ return %0 : tensor<1x2x2x8xf32>
+}
Thanks for the contribution! My main comment would be that this pass is a bit limited ATM:

> Convert conv_2d_nhwc_fhwc to conv_2d_nhwc_hwcf

Why not support conversions in both directions?
Sure, although I'm not 100% clear what "both directions" means here. Should the pass also run on `linalg.conv_2d_nhwc_hwcf` ops and convert them to `linalg.conv_2d_nhwc_fhwc`?
Yes, that's what I had in mind. And in general, it would be helpful if you expanded a bit on the use cases for this in the summary (and how it would be used). Otherwise it feels a bit like some arbitrary transformation. I am sure it isn't, but it's just not clear from the summary. More specifically, is there hardware that works better with one layout than the other? And do some frameworks tend to generate the "wrong" layout? Just curious, but I suspect other people might be too. Also, apologies for the delay in responding, and thanks for the updates :)
Looks good! Just got some nitpicks and a few questions:
Could you also add an integration test similar to https://github.com/llvm/llvm-project/blob/main/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir?
No worries about the delay! Thanks for the thorough review! 🤗 Okay, so the motivation behind this is that we made some changes in the TOSA to LinAlg lowering to allow a more direct mapping. Previously `tosa.conv2d` was lowered to `linalg.conv_2d_nhwc_hwcf`, which meant materializing a transpose of the weights as part of the lowering. We felt that a more direct mapping was appropriate as there is a 1-1 correspondence between `tosa.conv2d` and `linalg.conv_2d_nhwc_fhwc`, since both use the FHWC filter channel ordering.
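To make the correspondence concrete, here is a minimal sketch using the shapes from the lit test in this PR (the comments about `tosa.conv2d`'s weight layout are my own gloss; exact TOSA syntax is omitted to keep the sketch version-agnostic):

```mlir
// tosa.conv2d carries its weights as [F, H, W, C], i.e. output channels
// first, which is exactly the layout linalg.conv_2d_nhwc_fhwc expects:
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                               strides = dense<2> : tensor<2xi64>}
       ins(%input, %filter : tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
       outs(%init : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
// The HWCF variant instead wants the weights as [H, W, C, F], i.e.
// tensor<2x2x6x8xf32>, hence the [1, 2, 3, 0] permutation in this pass.
```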
So I guess in summary, this pass allows anyone who wants to restore the old behaviour to do so, but in a way such that they know they are materializing the extra transposition. I hope this helps, sorry for rambling. I agree this is quite a specific transformation, but I'm not sure how else to add it without a whole pass. I guess I'm still a little confused about the utility of converting in the other direction, i.e. `linalg.conv_2d_nhwc_hwcf` into `linalg.conv_2d_nhwc_fhwc`.
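Concretely, the pass rewrites the FHWC op above into the following sequence (assembled from the CHECK lines of the lit test in this PR):

```mlir
%zero = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<2x2x6x8xf32>
%fill = linalg.fill ins(%zero : f32)
          outs(%empty : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
// Materialize the weight transposition explicitly...
%transposed = linalg.transpose ins(%filter : tensor<8x2x2x6xf32>)
                outs(%fill : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// ...then emit the HWCF convolution over the permuted weights.
%conv = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<2> : tensor<2xi64>}
          ins(%input, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>)
          outs(%init : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
```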
Sure, I'm happy to add a test, although I'm not sure if this is testing the transformation or the nhwc-fhwc conv2d operation itself? I see we don't have tests for that convolution channel ordering, but it existed before this PR.
Thanks for the explanation, @FranklandJack!
It sounds like you are only interested in one specific case, which is totally fine :) However, could you update the summary to make the justification a bit more visible? In particular, the TOSA context makes a lot of sense and IMHO is key (it gives the rationale for this pass). Otherwise it's not clear why one would want to materialize the extra transpose.
What happened to that test? Ideally there would be one test for 2D convolutions and you'd simply add another configuration to run (that would include your pass). In fact, I would start by parametrizing test-conv-2d-nhwc-hwcf-call.mlir so that it works for both formats. And take it from there. But this could be a task for a separate patch.
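Something like the following, perhaps (purely hypothetical RUN lines; `<existing lowering pipeline>` stands in for whatever test-conv-2d-nhwc-hwcf-call.mlir currently runs):

```mlir
// Hypothetical parametrization: the first configuration is today's HWCF path;
// the second feeds the FHWC form through the new pass first, so both
// configurations share the same lowering and FileCheck expectations.
// RUN: mlir-opt %s <existing lowering pipeline> | FileCheck %s
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(linalg-transpose-conv2d-ops))' \
// RUN:   <existing lowering pipeline> | FileCheck %s
```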
Ha, makes sense :) Naming is hard!
✅ With the latest revision this PR passed the C/C++ code formatter.
Yep, great idea!
Would it be acceptable to do this as part of another patch? The way I see it is that we have lit tests for this transformation, so in that respect it should be "correct". The CPU tests in question seemed to validate the behaviour of convolutions on the CPU runner. I agree it would be beneficial to test the code generated from this transformation, but since that would amount to testing the behaviour of the conv2d operation itself on the CPU runner, perhaps it could wait for a follow-up patch?
Right, it is testing the conv2d operation itself, but I was thinking along the lines of testing the transformation. For now, I think you can leave the integration test for another patch.
Generally looks good, thanks!
Approving eagerly, please address the last 2 comments in the tests.
LGTM, thanks for addressing my comments!
Please address the final comments from Nicolas before landing this :)
* Add a linalg transform op to convert 2D convolutions and quantized 2D convolutions that have the `FHWC` filter channel ordering into a transpose followed by 2D convolutions that have the `HWCF` channel ordering.
* Add a lit test to check the semantics of the transformation are correct for both quantized and unquantized variants.

Signed-off-by: Jack Frankland <[email protected]>
Add a LinAlg pass to convert 2D convolutions and quantized 2D convolutions that have the `FHWC` filter channel ordering into a transpose followed by 2D convolutions that have the `HWCF` channel ordering.

Add a lit test to check the semantics of the transformation are correct for both quantized and unquantized variants.