From 4ba5e31c4cef900b97a48f7f9e467462570cee37 Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Mon, 25 Sep 2023 19:55:39 +0100 Subject: [PATCH] [mlir][linalg] Add TransposeConv2D transform op * Add a linalg transform op to convert 2D convolutions and quantized 2D convolutions that have the `FHWC` filter channel ordering into a transpose followed by 2D convolutions that have the `HWCF` channel ordering. * Add a lit test to check the semantics of the transformation are correct for both quantized and unquantized variants. Signed-off-by: Jack Frankland --- .../Linalg/TransformOps/LinalgTransformOps.td | 49 +++++ .../Dialect/Linalg/Transforms/Transforms.h | 7 + .../TransformOps/LinalgTransformOps.cpp | 27 +++ .../Dialect/Linalg/Transforms/CMakeLists.txt | 1 + .../Linalg/Transforms/TransposeConv2D.cpp | 150 +++++++++++++++ .../test/Dialect/Linalg/transpose-conv2d.mlir | 177 ++++++++++++++++++ 6 files changed, 411 insertions(+) create mode 100644 mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp create mode 100644 mlir/test/Dialect/Linalg/transpose-conv2d.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index f1c3d717f1fa9..fb660c6461266 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2249,6 +2249,55 @@ def ConvertConv2DToImg2ColOp : Op { + let description = [{ + Convert linalg.conv_2d_nhwc_fhwc into linalg.conv_2d_nhwc_hwcf by introducing + a linalg.transpose on the filter tensor/memref. + + Whilst the fhwc filter channel ordering can be desirable for certain targets + and is a more direct mapping to higher level dialects such as TOSA (which only + supports this ordering) hwcf is better suited for transformations such as + img2col which can make use of optimized BLAS routines such as GEMM. + + Returns one handle: + - The final operation of the sequence that replaces the original + convolution. + + #### Return modes: + + Returns a definite failure if target is not isolated from above. + Returns a silenceable failure if the pattern application failed. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::linalg::LinalgOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + //===----------------------------------------------------------------------===// // InsertSliceToCopyOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 6547648f7495c..6c4e16bd94f47 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1225,6 +1225,13 @@ rewriteInIm2Col(RewriterBase &rewriter, FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp); +/// Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by +/// materializing transpose. +FailureOr transposeConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op); +FailureOr transposeConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcQOp op); + //===----------------------------------------------------------------------===// // Rewrite patterns wrapping transformations. // TODO: every single such pattern should be a close to noop wrapper around a diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index ef5d88d46dd28..14404d837ff74 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3169,6 +3169,33 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// TransposeConv2DOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne( + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + auto maybeTransformed = + TypeSwitch>(target) + .Case([&](linalg::Conv2DNhwcFhwcOp op) { + return transposeConv2D(rewriter, op); + }) + .Case([&](linalg::Conv2DNhwcFhwcQOp op) { + return transposeConv2D(rewriter, op); + }) + .Default([&](Operation *op) { + return rewriter.notifyMatchFailure(op, "not supported"); + }); + if (failed(maybeTransformed)) + return emitDefaultSilenceableFailure(target); + // Handle to the new Conv2D operation with transposed filters + results.push_back(*maybeTransformed); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // InsertSliceToCopyOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 2f7b556bb2460..4f47e3b871845 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Tiling.cpp TilingInterfaceImpl.cpp Transforms.cpp + TransposeConv2D.cpp Vectorization.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp new file mode 100644 index 0000000000000..9e0829ee67c01 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp @@ -0,0 +1,150 @@ +//===- TransposeConv2D.cpp - Convolution transposition -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/RWMutex.h" +#include +#include + +namespace mlir { +namespace linalg { +namespace { +// clang-format off +/// Convolution converter that applies the following rewrite: +/// +/// Before: +/// +/// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, +/// strides = dense<2> : tensor<2xi64>} +/// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>) +/// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> +/// +/// After: +/// +/// %cst = arith.constant 0.000000e+00 : f32 +/// %0 = tensor.empty() : tensor<2x2x6x8xf32> +/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32> +/// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>) +/// permutation = [1, 2, 3, 0] +/// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} +/// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>) +/// -> tensor<1x2x2x8xf32> +/// +/// with an analogous example for the quantized case. +// clang-format on +template +FailureOr transposeConv2DHelper(RewriterBase &rewriter, + FHWCConvOp op) { + // Construct a permutation of the filter tensor dimensions. For a 2D + // convolution this will be known statically as [1, 2, 3, 0]. + SmallVector filterPerm({1, 2, 3, 0}); + + // Create the type for the transposed filter tensor. + auto filter = op->getOperand(1); + auto filterTy = cast(filter.getType()); + SmallVector newFilterShape(filterPerm.size()); + std::generate(std::begin(newFilterShape), std::end(newFilterShape), + [dim = 0, &filterTy, &filterPerm]() mutable { + return filterTy.getShape()[filterPerm[dim++]]; + }); + + // Because linalg.transpose expects an "out" parameter we need to pass it a + // tensor of zeros of the result type so here we construct that tensor. + auto inputType = op->getOperand(0).getType(); + auto elementTy = cast(inputType).getElementType(); + auto loc = op->getLoc(); + + const auto isTensorOp = isa(inputType); + Value input; + if (isTensorOp) { + + input = rewriter.create(loc, newFilterShape, elementTy) + .getResult(); + } else { + input = rewriter + .create( + loc, MemRefType::get(newFilterShape, elementTy)) + .getResult(); + } + + // We can then construct the transposition on our filter. + auto transpose = + rewriter.create(loc, filter, input, filterPerm); + + Value newFilter; + if (isTensorOp) { + newFilter = transpose.getResult()[0]; + } else { + newFilter = input; + } + + SmallVector newInputs{op.getInputs()}; + // The filter is always the second input argument, the other inputs can be + // left as they are. + newInputs[1] = newFilter; + // It is possible the convolution doesn't define any results and its + // out argument is just used instead. + SmallVector resultTy; + if (op.getNumResults()) { + resultTy.push_back(op->getResult(0).getType()); + } + auto newConv = + rewriter.create(loc, resultTy, newInputs, op.getOutputs(), + op.getStrides(), op.getDilations()); + rewriter.replaceOp(op, newConv); + return newConv.getOperation(); +} + +template +class ConvConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(FHWCConvOp op, + PatternRewriter &rewriter) const final { + if (failed(transposeConv2DHelper(rewriter, op))) { + return failure(); + } + return success(); + } +}; +} // namespace + +FailureOr transposeConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op) { + + return transposeConv2DHelper(rewriter, op); +} + +FailureOr transposeConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcQOp op) { + + return transposeConv2DHelper(rewriter, op); +} + +void populateTranposeConv2DPatterns(RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.insert< + ConvConverter, + ConvConverter>( + context); +} +} // namespace linalg +} // namespace mlir diff --git a/mlir/test/Dialect/Linalg/transpose-conv2d.mlir b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir new file mode 100644 index 0000000000000..4655a261d986b --- /dev/null +++ b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir @@ -0,0 +1,177 @@ +// RUN: mlir-opt %s -transform-interpreter -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_f64 +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf64>, %[[FILTER:.+]]: tensor<8x2x2x6xf64>, %[[INIT:.+]]: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf64> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf64>) outs(%[[NEWF]] : tensor<2x2x6x8xf64>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf64>, tensor<2x2x6x8xf64>) outs(%[[INIT]] : tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xf64> +func.func @conv_2d_nhwc_fhwc_f64(%input: tensor<1x4x4x6xf64>, %filter: tensor<8x2x2x6xf64>, %init: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xf64>, tensor<8x2x2x6xf64>) + outs (%init: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> + return %0 : tensor<1x2x2x8xf64> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32 +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xf32> +func.func @conv_2d_nhwc_fhwc_f32(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>) + outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> + return %0 : tensor<1x2x2x8xf32> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_f16 +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf16>, %[[FILTER:.+]]: tensor<8x2x2x6xf16>, %[[INIT:.+]]: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf16> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf16>) outs(%[[NEWF]] : tensor<2x2x6x8xf16>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf16>, tensor<2x2x6x8xf16>) outs(%[[INIT]] : tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xf16> +func.func @conv_2d_nhwc_fhwc_f16(%input: tensor<1x4x4x6xf16>, %filter: tensor<8x2x2x6xf16>, %init: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xf16>, tensor<8x2x2x6xf16>) + outs (%init: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> + return %0 : tensor<1x2x2x8xf16> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_b16 +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xbf16>, %[[FILTER:.+]]: tensor<8x2x2x6xbf16>, %[[INIT:.+]]: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xbf16> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xbf16>) outs(%[[NEWF]] : tensor<2x2x6x8xbf16>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xbf16>, tensor<2x2x6x8xbf16>) outs(%[[INIT]] : tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xbf16> +func.func @conv_2d_nhwc_fhwc_b16(%input: tensor<1x4x4x6xbf16>, %filter: tensor<8x2x2x6xbf16>, %init: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xbf16>, tensor<8x2x2x6xbf16>) + outs (%init: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> + return %0 : tensor<1x2x2x8xbf16> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi64>, %[[FILTER:.+]]: tensor<8x2x2x6xi64>, %[[INIT:.+]]: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi64> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi64>) outs(%[[NEWF]] : tensor<2x2x6x8xi64>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi64>, tensor<2x2x6x8xi64>) outs(%[[INIT]] : tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xi64> +func.func @conv_2d_nhwc_fhwc_i64(%input: tensor<1x4x4x6xi64>, %filter: tensor<8x2x2x6xi64>, %init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xi64>, tensor<8x2x2x6xi64>) + outs (%init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> + return %0 : tensor<1x2x2x8xi64> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_i32 +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi32>, %[[FILTER:.+]]: tensor<8x2x2x6xi32>, %[[INIT:.+]]: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi32> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi32>) outs(%[[NEWF]] : tensor<2x2x6x8xi32>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi32>, tensor<2x2x6x8xi32>) outs(%[[INIT]] : tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xi32> +func.func @conv_2d_nhwc_fhwc_i32(%input: tensor<1x4x4x6xi32>, %filter: tensor<8x2x2x6xi32>, %init: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xi32>, tensor<8x2x2x6xi32>) + outs (%init: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> + return %0 : tensor<1x2x2x8xi32> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_i16 +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi16>, %[[FILTER:.+]]: tensor<8x2x2x6xi16>, %[[INIT:.+]]: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi16> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi16>) outs(%[[NEWF]] : tensor<2x2x6x8xi16>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi16>, tensor<2x2x6x8xi16>) outs(%[[INIT]] : tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xi16> +func.func @conv_2d_nhwc_fhwc_i16(%input: tensor<1x4x4x6xi16>, %filter: tensor<8x2x2x6xi16>, %init: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xi16>, tensor<8x2x2x6xi16>) + outs (%init: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> + return %0 : tensor<1x2x2x8xi16> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_i8 +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi8>, %[[FILTER:.+]]: tensor<8x2x2x6xi8>, %[[INIT:.+]]: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi8> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi8>) outs(%[[NEWF]] : tensor<2x2x6x8xi8>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi8>, tensor<2x2x6x8xi8>) outs(%[[INIT]] : tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xi8> +func.func @conv_2d_nhwc_fhwc_i8(%input: tensor<1x4x4x6xi8>, %filter: tensor<8x2x2x6xi8>, %init: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xi8>, tensor<8x2x2x6xi8>) + outs (%init: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> + return %0 : tensor<1x2x2x8xi8> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_q +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>, %[[A:.+]]: i32, %[[B:.+]]: i32) -> tensor<1x2x2x8xf32> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]], %[[A]], %[[B]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>, i32, i32) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xf32> + func.func @conv_2d_nhwc_fhwc_q(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>, %a: i32, %b: i32) -> tensor<1x2x2x8xf32> { + %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter, %a, %b: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>, i32, i32) + outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> + return %0 : tensor<1x2x2x8xf32> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_unit_stride +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> +// CHECK: return %[[CONV]] : tensor<1x3x3x8xf32> +func.func @conv_2d_nhwc_fhwc_f32_unit_stride(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>) + outs (%init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> + return %0 : tensor<1x3x3x8xf32> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_2_dialation +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> { +// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32> +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0] +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> +// CHECK: return %[[CONV]] : tensor<1x2x2x8xf32> +func.func @conv_2d_nhwc_fhwc_f32_2_dialation(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<2> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>) + outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> + return %0 : tensor<1x2x2x8xf32> +} + +// CHECK-LABEL: @conv_2d_nhwc_fhwc_memref +// CHECK-SAME: (%[[INPUT:.+]]: memref<1x4x4x6xf32>, %[[FILTER:.+]]: memref<8x2x2x6xf32>, %[[INIT:.+]]: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> { +// CHECK-DAG: %[[NEWF:.+]] = memref.alloc() : memref<2x2x6x8xf32> +// CHECK: linalg.transpose ins(%[[FILTER]] : memref<8x2x2x6xf32>) outs(%[[NEWF]] : memref<2x2x6x8xf32>) permutation = [1, 2, 3, 0] +// CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[NEWF]] : memref<1x4x4x6xf32>, memref<2x2x6x8xf32>) outs(%[[INIT]] : memref<1x2x2x8xf32>) +// CHECK: return %[[INIT]] : memref<1x2x2x8xf32> +func.func @conv_2d_nhwc_fhwc_memref(%input: memref<1x4x4x6xf32>, %filter: memref<8x2x2x6xf32>, %init: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> { + linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: memref<1x4x4x6xf32>, memref<8x2x2x6xf32>) + outs (%init: memref<1x2x2x8xf32>) + return %init : memref<1x2x2x8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc", "linalg.conv_2d_nhwc_fhwc_q"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +}