
[mlir][linalg] Add TransposeConv2D Transform Op #68567


Merged

Conversation

FranklandJack (Contributor)

  • Add a LinAlg pass to convert 2D convolutions and quantized 2D convolutions that have the FHWC filter channel ordering into a transpose followed by 2D convolutions that have the HWCF channel ordering.

  • Add a lit test to check the semantics of the transformation are correct for both quantized and unquantized variants.
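The transformation's correctness claim — a `conv_2d_nhwc_fhwc` result equals a `conv_2d_nhwc_hwcf` run on weights transposed by the permutation `[1, 2, 3, 0]` — can be sanity-checked numerically. The sketch below is not part of the patch; it uses a hypothetical naive pure-Python convolution whose filter layout is abstracted behind an accessor:

```python
# Hypothetical numeric sanity check (not from the PR): an NHWC/FHWC conv2d
# should equal an NHWC/HWCF conv2d whose weights were transposed by the
# permutation [1, 2, 3, 0].
import itertools
import random

def conv2d(x, get_w, dims, stride=1):
    # x is [N][H][W][C]; get_w(f, kh, kw, c) abstracts the filter layout so
    # the same loop nest serves both FHWC and HWCF filters.
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    F, KH, KW = dims
    OH, OW = (H - KH) // stride + 1, (W - KW) // stride + 1
    out = [[[[0.0] * F for _ in range(OW)] for _ in range(OH)] for _ in range(N)]
    for n, oh, ow, f in itertools.product(range(N), range(OH), range(OW), range(F)):
        out[n][oh][ow][f] = sum(
            x[n][oh * stride + kh][ow * stride + kw][c] * get_w(f, kh, kw, c)
            for kh, kw, c in itertools.product(range(KH), range(KW), range(C)))
    return out

random.seed(0)
N, H, W, C, F, KH, KW = 1, 4, 4, 3, 2, 2, 2
x = [[[[random.random() for _ in range(C)] for _ in range(W)]
      for _ in range(H)] for _ in range(N)]
w_fhwc = [[[[random.random() for _ in range(C)] for _ in range(KW)]
           for _ in range(KH)] for _ in range(F)]

# Transpose FHWC -> HWCF with permutation [1, 2, 3, 0].
w_hwcf = [[[[w_fhwc[f][kh][kw][c] for f in range(F)] for c in range(C)]
           for kw in range(KW)] for kh in range(KH)]

a = conv2d(x, lambda f, kh, kw, c: w_fhwc[f][kh][kw][c], (F, KH, KW), stride=2)
b = conv2d(x, lambda f, kh, kw, c: w_hwcf[kh][kw][c][f], (F, KH, KW), stride=2)
assert a == b  # identical floating-point ops in identical order
```

Both calls perform the same multiplications and additions in the same order, so the results match exactly, not just within tolerance.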

@llvmbot (Member)

llvmbot commented Oct 9, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Changes
  • Add a LinAlg pass to convert 2D convolutions and quantized 2D convolutions that have the FHWC filter channel ordering into a transpose followed by 2D convolutions that have the HWCF channel ordering.

  • Add a lit test to check the semantics of the transformation are correct for both quantized and unquantized variants.


Full diff: https://github.com/llvm/llvm-project/pull/68567.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.h (+4)
  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp (+116)
  • (added) mlir/test/Dialect/Linalg/transpose-conv2d.mlir (+33)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 5f46affe592a2da..96c809f10323922 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -65,6 +65,10 @@ std::unique_ptr<Pass> createLinalgGeneralizationPass();
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();
 
+/// Create a pass to convert linalg.conv_2d_nhwc_fhwc(_q) to
+/// linalg.conv_2d_nhwc_hwcf(_q).
+std::unique_ptr<Pass> createLinalgTransposeConv2DPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 3093604af63e338..74cbe0c354f9018 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -145,4 +145,10 @@ def LinalgDetensorize : InterfacePass<"linalg-detensorize", "FunctionOpInterface
   ];
 }
 
+def LinalgTransposeConv2D : Pass<"linalg-transpose-conv2d-ops"> {
+  let summary = "Convert conv_2d_nhwc_fhwc to conv_2d_nhwc_hwcf by transposing the weights.";
+  let constructor = "mlir::createLinalgTransposeConv2DPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4e094609afa6a03..823b7bfd9810804 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Tiling.cpp
   TilingInterfaceImpl.cpp
   Transforms.cpp
+  TransposeConv2D.cpp
   Vectorization.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
new file mode 100644
index 000000000000000..a8dee1126031601
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
@@ -0,0 +1,116 @@
+//===- TransposeConv2D.cpp - Convolution transposition --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSECONV2D
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+// Convolution converter that matches linalg.conv_2d_nhwc_fhwc and
+// linalg.conv_2d_nhwc_fhwc_q to linalg.transpose + linalg.conv_2d_nhwc_hwcf and
+// linalg.transpose + linalg.conv_2d_nhwc_hwcf_q respectively.
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+public:
+  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(FHWCConvOp op,
+                                PatternRewriter &rewriter) const final {
+    // Transpose the weights.
+    //
+    // To do this we first need to construct a permutation of the weight tensor
+    // dimensions. For a 2D convolution this will be known statically as
+    // [1, 2, 3, 0]; however, we construct the vector dynamically to
+    // future-proof this logic so it can be extended to convolutions of higher
+    // dimensions.
+    auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+    auto weightPerm = SmallVector<int64_t>(resultTy.getRank() - 1);
+    std::iota(std::begin(weightPerm), std::end(weightPerm), 1);
+    weightPerm.push_back(0);
+
+    // Create the type for the transposed weight tensor since this will be
+    // different from the original weight type.
+    auto weight = op->getOperand(1);
+    auto weightTy = cast<ShapedType>(weight.getType());
+    auto newWeightShape = SmallVector<int64_t>(weightPerm.size());
+    std::generate(std::begin(newWeightShape), std::end(newWeightShape),
+                  [dim = 0, &weightTy, &weightPerm]() mutable {
+                    return weightTy.getShape()[weightPerm[dim++]];
+                  });
+    auto newWeightTy =
+        RankedTensorType::get(newWeightShape, weightTy.getElementType());
+
+    // Because linalg.transpose expects an "out" parameter, we need to pass it a
+    // tensor of zeros of the result type so here we construct that tensor.
+    auto resultETy = resultTy.getElementType();
+    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto loc = op->getLoc();
+    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, newWeightTy.getShape(), resultETy);
+    auto zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
+    auto zeroTensor = rewriter
+                          .create<linalg::FillOp>(loc, ValueRange{zero},
+                                                  ValueRange{emptyTensor})
+                          .result();
+
+    // We can then construct the transposition on our weights.
+    weight =
+        rewriter
+            .create<linalg::TransposeOp>(loc, weight, zeroTensor, weightPerm)
+            .getResult()[0];
+
+    // Create the convolution.
+    //
+    // The weights are always the second input argument.
+    auto newInputs = SmallVector<Value>{op.getInputs()};
+    newInputs[1] = weight;
+    rewriter.template replaceOpWithNewOp<HWCFConvOp>(
+        op, resultTy, newInputs, op.getOutputs(), op.getStrides(),
+        op.getDilations());
+    return success();
+  }
+};
+
+// This pass converts NHWC Conv2D operations with FHWC channel orderings to NHWC
+// Conv2D operations with HWCF channel orderings.
+struct LinalgTransposeConv2D
+    : public impl::LinalgTransposeConv2DBase<LinalgTransposeConv2D> {
+public:
+  void runOnOperation() override {
+    auto *ctx = getOperation()->getContext();
+    auto patternSet = RewritePatternSet{ctx};
+    patternSet.add<
+        ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
+        ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
+        ctx);
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patternSet))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createLinalgTransposeConv2DPass() {
+  return std::make_unique<LinalgTransposeConv2D>();
+}
diff --git a/mlir/test/Dialect/Linalg/transpose-conv2d.mlir b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
new file mode 100644
index 000000000000000..22019029a02743d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(linalg-transpose-conv2d-ops))' | FileCheck %s
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[WEIGHTS:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+func.func @conv_2d_nhwc_fhwc(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+  // CHECK-DAG:    %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+  // CHECK:    %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+  // CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[WEIGHTS]] : tensor<8x2x2x6xf32>) outs(%[[FILL]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+  // CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  // CHECK:    return %[[CONV]] : tensor<1x2x2x8xf32>
+  return %0 : tensor<1x2x2x8xf32>
+}
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_q
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[WEIGHTS:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>, %[[A:.+]]: i32, %[[B:.+]]: i32) -> tensor<1x2x2x8xf32> {
+func.func @conv_2d_nhwc_fhwc_q(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>, %a: i32, %b: i32) -> tensor<1x2x2x8xf32> {
+  // CHECK-DAG:    %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+  // CHECK:    %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+  // CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[WEIGHTS]] : tensor<8x2x2x6xf32>) outs(%[[FILL]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+  // CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]], %[[A]], %[[B]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>, i32, i32) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter, %a, %b: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>, i32, i32)
+    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  // CHECK:    return %[[CONV]] : tensor<1x2x2x8xf32>
+  return %0 : tensor<1x2x2x8xf32>
+}

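The weight-permutation logic in `TransposeConv2D.cpp` — `std::iota` fills `[1, 2, ..., rank - 1]`, `0` is appended, and the transposed weight shape is the old shape indexed through that permutation — can be mirrored in a few lines. This is an illustrative sketch, not code from the patch:

```python
def fhwc_to_hwcf_perm(rank):
    # Mirrors the pass: iota fills [1, 2, ..., rank - 1], then 0 is appended,
    # moving the leading F dimension to the back.
    perm = list(range(1, rank))
    perm.append(0)
    return perm

def permuted_shape(shape, perm):
    # The transposed weight type's shape is the old shape indexed through perm.
    return [shape[p] for p in perm]

perm = fhwc_to_hwcf_perm(4)
print(perm)                               # [1, 2, 3, 0]
print(permuted_shape([8, 2, 2, 6], perm)) # [2, 2, 6, 8]
```

The dynamic construction is what makes the "future-proofing" comment in the pass meaningful: for a hypothetical rank-5 filter the same code yields `[1, 2, 3, 4, 0]` with no special casing.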
@banach-space (Contributor) left a comment


Thanks for the contribution! My main comment would be that this pass is a bit limited ATM:

> Convert conv_2d_nhwc_fhwc to conv_2d_nhwc_hwcf

Why not support conversions in both directions?

@FranklandJack FranklandJack force-pushed the jacfra01/tranpose-conv2d-nhwc+fhwc branch from 20cb7f1 to 3f7ee74 on October 17, 2023 at 15:22
@FranklandJack (Contributor, Author)

> Thanks for the contribution! My main comment would be that this pass is a bit limited ATM:
>
> > Convert conv_2d_nhwc_fhwc to conv_2d_nhwc_hwcf
>
> Why not support conversions in both directions?

Sure, although I'm not 100% clear on what "both directions" means here. Should the pass also run on conv_2d_nhwc_hwcf and convert it to fhwc?

@banach-space (Contributor)

> Thanks for the contribution! My main comment would be that this pass is a bit limited ATM:
>
> > Convert conv_2d_nhwc_fhwc to conv_2d_nhwc_hwcf
>
> Why not support conversions in both directions?

> Sure, although I'm not 100% clear on what "both directions" means here. Should the pass also run on conv_2d_nhwc_hwcf and convert it to fhwc?

Yes, that's what I had in mind. And in general, it would be helpful if you expanded a bit on the use cases for this in the summary (and how it would be used). Otherwise it feels a bit like an arbitrary transformation. I am sure it isn't, but that's just not clear from the summary. More specifically, is there hardware that works better with one layout than the other? Do some frameworks tend to generate the "wrong" layout? Just curious, but I suspect other people might be too.

Also, apologies for the delay in responding, and thanks for the updates :)

@MacDue (Member) left a comment


Looks good! Just got some nitpicks and a few questions:

@cfRod (Contributor)

cfRod commented Oct 19, 2023

@FranklandJack (Contributor, Author)

> Thanks for the contribution! My main comment would be that this pass is a bit limited ATM:
>
> > Convert conv_2d_nhwc_fhwc to conv_2d_nhwc_hwcf
>
> Why not support conversions in both directions?

> Sure, although I'm not 100% clear on what "both directions" means here. Should the pass also run on conv_2d_nhwc_hwcf and convert it to fhwc?

> Yes, that's what I had in mind. And in general, it would be helpful if you expanded a bit on the use cases for this in the summary (and how it would be used). Otherwise it feels a bit like an arbitrary transformation. I am sure it isn't, but that's just not clear from the summary. More specifically, is there hardware that works better with one layout than the other? Do some frameworks tend to generate the "wrong" layout? Just curious, but I suspect other people might be too.
>
> Also, apologies for the delay in responding, and thanks for the updates :)

No worries about the delay! Thanks for the thorough review! 🤗

Okay, so the motivation behind this is that we made some changes in the TOSA to LinAlg lowering to allow a more direct mapping. Previously tosa.conv2d (which has the {f,h,w,c} ordering) was lowered to linalg's linalg.conv_2d_nhwc_hwcf. In order to reorder the filter channels, a transpose was materialized (essentially, what this pass does was previously part of the TOSA->LinAlg lowering). (As an aside, the TOSA documentation refers to the filter as "weights", so I guess that's where I got that from 🤦)

We felt that a more direct mapping was appropriate, as there is a 1-1 correspondence between tosa.conv2d and LinAlg's linalg.conv_2d_nhwc_fhwc, and that it should be the responsibility of further transformations to introduce the transpose if the pipeline in question wants that ordering.

{h,w,c,f} can be a useful ordering for things like img2col, where the contiguous {c,f} filter dimensions allow you to employ a GEMM (which can be useful if you support optimized routines / BLAS libraries). On the other hand, we now also support an img2col-style transformation for the {f,h,w,c} filter ordering, which could in theory allow better cache coherency, since all memory accesses happen contiguously.

So, in summary, this pass allows anyone who wants to restore the old behaviour to do so, but in a way where they know they are materializing the extra transposition. I hope this helps; sorry for rambling. I agree this is quite a specific transformation, but I'm not sure how else to add it without a whole pass.

I'm still a little confused about the utility of converting {h,w,c,f} to {f,h,w,c} here, when the purpose is to align on the {h,w,c,f} ordering for the filters. Perhaps we could add an option to define which way you want to go?
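For readers unfamiliar with the img2col point above: with an {h,w,c,f} filter, each filter flattens into a column of a [KH*KW*C, F] matrix, and the convolution becomes one matrix multiply against the im2col patch matrix. A hypothetical pure-Python sketch (not from the PR; stride 1, no padding, tiny sizes):

```python
import itertools
import random

random.seed(1)
H, W, C, F, KH, KW = 3, 3, 2, 2, 2, 2
OH, OW = H - KH + 1, W - KW + 1  # stride 1, no padding

x = [[[random.random() for _ in range(C)] for _ in range(W)] for _ in range(H)]
w_hwcf = [[[[random.random() for _ in range(F)] for _ in range(C)]
           for _ in range(KW)] for _ in range(KH)]

# im2col: one row per output pixel, flattening the (kh, kw, c) patch.
patches = [[x[oh + kh][ow + kw][c]
            for kh, kw, c in itertools.product(range(KH), range(KW), range(C))]
           for oh, ow in itertools.product(range(OH), range(OW))]

# The {h,w,c,f} filter flattens to a [KH*KW*C, F] matrix in the same
# (kh, kw, c) row order, so the convolution is patches @ filter_matrix.
filt = [[w_hwcf[kh][kw][c][f] for f in range(F)]
        for kh, kw, c in itertools.product(range(KH), range(KW), range(C))]

gemm = [[sum(p[i] * filt[i][f] for i in range(len(p))) for f in range(F)]
        for p in patches]

# Direct convolution for reference.
direct = [[sum(x[oh + kh][ow + kw][c] * w_hwcf[kh][kw][c][f]
               for kh, kw, c in itertools.product(range(KH), range(KW), range(C)))
           for f in range(F)]
          for oh, ow in itertools.product(range(OH), range(OW))]

assert all(abs(gemm[r][f] - direct[r][f]) < 1e-12
           for r in range(OH * OW) for f in range(F))
```

The key layout observation is that with {h,w,c,f} the innermost {c,f} dimensions are already contiguous in the flattened matrix, which is what lets an optimized GEMM routine be used directly.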

@FranklandJack (Contributor, Author)

> Could you also add an integration test similar to https://github.com/llvm/llvm-project/blob/main/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir ?

Sure, I'm happy to add a test, although I'm not sure whether this is testing the transformation or the nhwc-fhwc conv2d operation itself? I see we don't have tests for that convolution channel ordering, but it existed before this PR.

@FranklandJack FranklandJack force-pushed the jacfra01/tranpose-conv2d-nhwc+fhwc branch from 3f7ee74 to 944c68b on October 19, 2023 at 11:09
@banach-space (Contributor)

Thanks for the explanation @FranklandJack!

> I'm still a little confused about the utility of converting {h,w,c,f} to {f,h,w,c} here, when the purpose is to align on the {h,w,c,f} ordering for the filters. Perhaps we could add an option to define which way you want to go?

It sounds like you are only interested in one specific case, which is totally fine :)

However, could you update the summary to make the justification a bit more visible? In particular, the TOSA context makes a lot of sense and IMHO is key (it gives the rationale for this pass). Otherwise it's not clear why LinalgTransposeConv2D would work for only one very specific case. You could also add a comment that while the opposite transformation (hwcf -> fhwc) is not supported yet, it could be added in the future.

> Sure, I'm happy to add a test, although I'm not sure whether this is testing the transformation or the nhwc-fhwc conv2d operation itself? I see we don't have tests for that convolution channel ordering, but it existed before this PR.

What happened to that test? Ideally there would be one test for 2D convolutions and you'd simply add another configuration to run (one that includes your pass). In fact, I would start by parametrizing test-conv-2d-nhwc-hwcf-call.mlir so that it works for both formats, and take it from there. But this could be a task for a separate patch.

> By the way, as an aside, the TOSA documentation refers to the filter as "weights", so I guess that's where I got that from

Ha, makes sense :) Naming is hard!

@github-actions

github-actions bot commented Oct 19, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@FranklandJack FranklandJack force-pushed the jacfra01/tranpose-conv2d-nhwc+fhwc branch 2 times, most recently from 98400d5 to 80d4c10 on October 19, 2023 at 16:32
@FranklandJack (Contributor, Author)

> However, could you update the summary to make the justification a bit more visible? In particular, the TOSA context makes a lot of sense and IMHO is key (it gives the rationale for this pass). Otherwise it's not clear why LinalgTransposeConv2D would work for only one very specific case. You could also add a comment that while the opposite transformation (hwcf -> fhwc) is not supported yet, it could be added in the future.

Yep, great idea!

> What happened to that test? Ideally there would be one test for 2D convolutions and you'd simply add another configuration to run (one that includes your pass). In fact, I would start by parametrizing test-conv-2d-nhwc-hwcf-call.mlir so that it works for both formats, and take it from there. But this could be a task for a separate patch.

Would it be acceptable to do this as part of another patch? The way I see it, we have lit tests for this transformation, so in that respect it should be "correct". The CPU tests in question seem to validate the behaviour of convolutions on the CPU runner. I agree it would be beneficial to test the code generated from this transformation, but since that would amount to testing the behaviour of linalg.transpose and linalg.conv2d_nhwc_hwcf rather than the output of this transformation, it seems like an orthogonal concern to the idea behind this PR. It would be interesting to know what @cfRod thinks as well.

@FranklandJack FranklandJack force-pushed the jacfra01/tranpose-conv2d-nhwc+fhwc branch from 80d4c10 to 923b8b7 on October 20, 2023 at 14:53
@cfRod (Contributor)

cfRod commented Oct 20, 2023

> Could you also add an integration test similar to https://github.com/llvm/llvm-project/blob/main/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir ?

> Sure, I'm happy to add a test, although I'm not sure whether this is testing the transformation or the nhwc-fhwc conv2d operation itself? I see we don't have tests for that convolution channel ordering, but it existed before this PR.

Right, it is testing the conv2d operation itself, but I was thinking along the lines of testing the transformation. For now, I think you can leave the integration test for another patch.

@FranklandJack FranklandJack force-pushed the jacfra01/tranpose-conv2d-nhwc+fhwc branch from 923b8b7 to a7c5def on October 23, 2023 at 15:38
@FranklandJack FranklandJack changed the title from "[mlir][linalg] Add TransposeConv2D Pass" to "[mlir][linalg] Add TransposeConv2D Transform Op" on Oct 23, 2023
@nicolasvasilache (Contributor) left a comment


Generally looks good thanks!

Approving eagerly, please address the last 2 comments in the tests.

@banach-space (Contributor) left a comment


LGTM, thanks for addressing my comments!

Please address the final comments from Nicolas before landing this :)

* Add a linalg transform op to convert 2D convolutions and quantized 2D
  convolutions that have the `FHWC` filter channel ordering into a
  transpose followed by 2D convolutions that have the `HWCF` channel
  ordering.

* Add a lit test to check the semantics of the transformation are
  correct for both quantized and unquantized variants.

Signed-off-by: Jack Frankland <[email protected]>
@FranklandJack FranklandJack force-pushed the jacfra01/tranpose-conv2d-nhwc+fhwc branch from a7c5def to 4ba5e31 on November 27, 2023 at 17:13
@banach-space banach-space merged commit 4a3d208 into llvm:main on Nov 28, 2023