
[mlir][tosa] Enhance the conv2d verifier #128693


Merged: 1 commit into llvm:main on Feb 26, 2025

Conversation

lhutton1
Contributor

This commit adds additional checks to the conv2d verifier covering ERROR_IF conditions from the TOSA specification. Notably, it adds invalid-value checking for padding, stride, and dilation, output height and width checking, and bias size checking.
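
For reference, the output-size rule behind these checks can be written as a small standalone sketch (a sketch only, with illustrative names such as `expectedOutHeight`; the actual verifier changes are in the diff below):

```cpp
#include <cstdint>
#include <optional>

// Exact-division helper mirroring the idivCheck utility added by this patch;
// assumes rhs >= 1, which the new stride check guarantees before this runs.
std::optional<int64_t> idivCheck(int64_t lhs, int64_t rhs) {
  if (lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}

// Expected conv2d output height per the TOSA spec formula, or nullopt when
// the ERROR_IF divisibility condition fires.
std::optional<int64_t> expectedOutHeight(int64_t ih, int64_t kh,
                                         int64_t padTop, int64_t padBottom,
                                         int64_t strideY, int64_t dilationY) {
  auto ohMinusOne = idivCheck(
      (ih - 1) + padTop + padBottom - (kh - 1) * dilationY, strideY);
  if (!ohMinusOne)
    return std::nullopt;
  return *ohMinusOne + 1;
}

// Example: ih=4, kh=1, no padding, stride_y=2, dilation_y=1 gives
// ((4 - 1) + 0 + 0 - 0) / 2 = 3 / 2, which is not exact, so the verifier
// rejects the op (matching one of the new invalid.mlir tests).
```

The width check is identical with iw/kw, pad_left/pad_right, stride_x, and dilation_x substituted.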

@llvmbot
Member

llvmbot commented Feb 25, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit adds additional checks to the conv2d verifier covering ERROR_IF conditions from the TOSA specification. Notably, it adds invalid-value checking for padding, stride, and dilation, output height and width checking, and bias size checking.


Patch is 26.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128693.diff

7 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+102-7)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+6-6)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+6-6)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+72)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+24-24)
  • (modified) mlir/test/Dialect/Tosa/quant-test.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+3-3)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 295b010da0ee0..5920079118ab5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -214,6 +214,16 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Tosa utilities.
+//===----------------------------------------------------------------------===//
+
+std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
+  if (lhs % rhs != 0)
+    return std::nullopt;
+  return lhs / rhs;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -1666,13 +1676,6 @@ LogicalResult tosa::ResizeOp::verify() {
   const int64_t borderY = borderValues[0];
   const int64_t borderX = borderValues[1];
 
-  auto idivCheck = [](const int64_t lhs,
-                      const int64_t rhs) -> std::optional<int64_t> {
-    if (lhs % rhs != 0)
-      return std::nullopt;
-    return lhs / rhs;
-  };
-
   // Don't check with input height that could be broadcast (ih != 1)
   // since Linalg, a consumer of TOSA, expects broadcasting support
   // in resize to be available. Taking the cautious approach for now,
@@ -2012,6 +2015,98 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
 LogicalResult Conv2DOp::verify() {
   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
     return failure();
+
+  llvm::ArrayRef<int64_t> padding = getPad();
+  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
+    return emitOpError("expect all padding values to be >= 0, got ") << padding;
+
+  llvm::ArrayRef<int64_t> strides = getStride();
+  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
+    return emitOpError("expect all stride values to be >= 1, got ") << strides;
+
+  llvm::ArrayRef<int64_t> dilations = getDilation();
+  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
+    return emitOpError("expect all dilation values to be >= 1, got ")
+           << dilations;
+
+  const RankedTensorType outputType =
+      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
+  if (!outputType)
+    // Skip following checks if output is not ranked
+    return success();
+
+  const RankedTensorType inputType =
+      llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  const RankedTensorType weightType =
+      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
+
+  if (inputType && weightType) {
+    const int64_t ih = inputType.getDimSize(1);
+    const int64_t kh = weightType.getDimSize(1);
+    const int64_t oh = outputType.getDimSize(1);
+
+    const int64_t pad_top = padding[0];
+    const int64_t pad_bottom = padding[1];
+    const int64_t stride_y = strides[0];
+    const int64_t dilation_y = dilations[0];
+
+    if (ih != ShapedType::kDynamic) {
+      const std::optional<int64_t> calculatedOutHeightMinusOne = idivCheck(
+          (ih - 1) + pad_top + pad_bottom - (kh - 1) * dilation_y, stride_y);
+      if (!calculatedOutHeightMinusOne.has_value())
+        return emitOpError("expected (input_height - 1) + pad_top + pad_bottom "
+                           "- (kernel_height - 1) * dilation_y ")
+               << "to be wholly divisible by stride_y, got ((" << ih
+               << " - 1) + " << pad_top << " + " << pad_bottom << " - (" << kh
+               << " - 1) * " << dilation_y << ") / " << stride_y;
+      const int64_t calculatedOutHeight =
+          calculatedOutHeightMinusOne.value() + 1;
+      if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
+        return emitOpError("calculated output height did not match expected: ")
+               << "calculated=" << calculatedOutHeight << ", expected=" << oh;
+    }
+
+    const int64_t iw = inputType.getDimSize(2);
+    const int64_t kw = weightType.getDimSize(2);
+    const int64_t ow = outputType.getDimSize(2);
+
+    const int64_t pad_left = padding[2];
+    const int64_t pad_right = padding[3];
+    const int64_t stride_x = strides[1];
+    const int64_t dilation_x = dilations[1];
+
+    if (iw != ShapedType::kDynamic) {
+      const std::optional<int64_t> calculatedOutWidthMinusOne = idivCheck(
+          (iw - 1) + pad_left + pad_right - (kw - 1) * dilation_x, stride_x);
+      if (!calculatedOutWidthMinusOne.has_value())
+        return emitOpError("expected (input_width - 1) + pad_left + pad_right "
+                           "- (kernel_width - 1) * dilation_x ")
+               << "to be wholly divisible by stride_x, got ((" << iw
+               << " - 1) + " << pad_left << " + " << pad_right << " - (" << kw
+               << " - 1) * " << dilation_x << ") / " << stride_x;
+      const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
+      if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
+        return emitOpError("calculated output width did not match expected: ")
+               << "calculated=" << calculatedOutWidth << ", expected=" << ow;
+    }
+  }
+
+  const RankedTensorType biasType =
+      llvm::dyn_cast<RankedTensorType>(getBias().getType());
+  if (!biasType)
+    // Skip following checks if bias is not ranked
+    return success();
+
+  const int64_t bc = biasType.getDimSize(0);
+  const int64_t oc = outputType.getDimSize(3);
+  if (bc == ShapedType::kDynamic || oc == ShapedType::kDynamic)
+    // Skip following checks if bc or oc is dynamic dim
+    return success();
+
+  if (bc != oc && bc != 1)
+    return emitOpError(
+               "bias channels expected to be equal to output channels (")
+           << oc << ") or 1, got " << bc;
   return success();
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a524359b49759..0e7ac8655689c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -465,16 +465,16 @@ func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tens
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
   // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
-  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x49x42x28xi32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x49x42x28xi32>) {
   // CHECK:   arith.extsi
   // CHECK:   linalg.yield
-  // CHECK: } -> tensor<1x45x40x28xi32>
-  // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
-  // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // CHECK: } -> tensor<1x49x42x28xi32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x49x42x28xi32>) -> tensor<1x49x42x28xi32>
+  // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x49x42x28xi32>) -> tensor<1x49x42x28xi32>
 
   %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  %0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32>
+  %0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x49x42x28xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 0e177a076ee7a..3f37c2e9cc6ff 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -201,23 +201,23 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
 // -----
 
 // CHECK-LABEL: @conv2d_stride_2
-func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
+func.func @conv2d_stride_2(%arg0: tensor<4x11x11x2xf32>) -> tensor<4x6x6x3xf32> {
   // CHECK: tosa.conv2d
   %weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32>
   %bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32>
-  %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
+  %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x11x11x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x6x6x3xf32>
+  return %0 : tensor<4x6x6x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @conv2d_weight_2x2
-func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
+func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x9x9x1xf32> {
   // CHECK: tosa.conv2d
   %weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32>
   %bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32>
-  %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32>
-  return %0 : tensor<4x10x10x1xf32>
+  %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x9x9x1xf32>
+  return %0 : tensor<4x9x9x1xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1aa8547cb2fdb..e63cbdc08ff10 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1197,3 +1197,75 @@ func.func @broadcast_resize_bilinear_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x
 
   return %resize : tensor<3x4x5x7xi32>
 }
+
+// -----
+
+func.func @test_conv2d_invalid_padding(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expect all padding values to be >= 0, got 0, 0, -1, 0}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, -1, 0>, stride = array<i64: 1, 1>, local_bound = true}
+    : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_invalid_stride(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expect all stride values to be >= 1, got 0, 1}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 0, 1>, local_bound = true}
+    : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_invalid_dilation(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expect all dilation values to be >= 1, got 1, 0}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 0>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
+    : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_wholly_divisible_height(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expected (input_height - 1) + pad_top + pad_bottom - (kernel_height - 1) * dilation_y to be wholly divisible by stride_y, got ((4 - 1) + 0 + 0 - (1 - 1) * 1) / 2}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 1>, local_bound = true}
+    : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_wholly_divisible_width(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expected (input_width - 1) + pad_left + pad_right - (kernel_width - 1) * dilation_x to be wholly divisible by stride_x, got ((4 - 1) + 0 + 0 - (1 - 1) * 1) / 2}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 2>, local_bound = true}
+    : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_unexpected_output_height(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x6x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op calculated output height did not match expected: calculated=4, expected=6}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
+    : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x6x4x8xf32>
+  return %0 : tensor<1x6x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_unexpected_output_width(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x6x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op calculated output width did not match expected: calculated=4, expected=6}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
+    : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x6x8xf32>
+  return %0 : tensor<1x4x6x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<7xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op bias channels expected to be equal to output channels (8) or 1, got 7}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
+    : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 90c4551564d1e..ccc8119271f59 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -226,74 +226,74 @@ func.func @test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32
 
 // -----
 
-func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<*xf32> {
   // expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}}
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
-            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
-  return %0 : tensor<1x32x32x16xf32>
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
 }
 
 // -----
 
-func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<*xf32> {
   // expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}}
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
-            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
-  return %0 : tensor<1x32x32x16xf32>
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
 }
 
 // -----
 
-func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x8225x32x16xf32> {
   // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} :
-            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
-  return %0 : tensor<1x32x32x16xf32>
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x8225x32x16xf32>
+  return %0 : tensor<1x8225x32x16xf32>
 }
 
 // -----
 
-func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x8224x32x16xf32> {
   // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>...
[truncated]

@GeorgeARM GeorgeARM (Contributor) left a comment

Minor comment

      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  if (inputType && weightType) {
    const int64_t ih = inputType.getDimSize(1);
Contributor

This logic is replicated for W and H. Should we refactor it into a common function?
I presume the conv3d verifier will use this across three dimensions as well?

Contributor Author

Thanks, I've refactored this into a common lambda function; we can pull it out into a shared helper when similar checks are added for conv3d.
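
For illustration, the shared per-dimension check might look roughly like the following (a hedged sketch reconstructed from the per-dimension logic in the diff; names like `verifyOutputSize` and the exact diagnostics are illustrative, and the merged commit may differ):

```cpp
// Inside Conv2DOp::verify(): one lambda shared by the height and width
// checks (and reusable for a third dimension if conv3d gains the same
// verification).
const auto verifyOutputSize =
    [&](const int64_t inputSize, const int64_t kernelSize,
        const int64_t outputSize, const int64_t padBefore,
        const int64_t padAfter, const int64_t stride, const int64_t dilation,
        const llvm::StringRef dimName) -> LogicalResult {
  if (inputSize == ShapedType::kDynamic)
    return success();
  // ERROR_IF: the padded, dilated extent must divide exactly by the stride.
  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
      (inputSize - 1) + padBefore + padAfter - (kernelSize - 1) * dilation,
      stride);
  if (!calculatedOutSizeMinusOne.has_value())
    return emitOpError("expected output ")
           << dimName << " calculation to be wholly divisible by the stride";
  // ERROR_IF: a statically known output dim must match the computed size.
  const int64_t calculatedOutSize = *calculatedOutSizeMinusOne + 1;
  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
    return emitOpError("calculated output ")
           << dimName << " did not match expected: calculated="
           << calculatedOutSize << ", expected=" << outputSize;
  return success();
};

// Usage, once per spatial dimension:
//   if (failed(verifyOutputSize(ih, kh, oh, padding[0], padding[1],
//                               strides[0], dilations[0], "height")))
//     return failure();
```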

This commit adds additional checks to the conv2d verifier that check
error_if conditions from the tosa specification. Notably, it adds
padding, stride and dilation invalid value checking, output height and
width checking and bias size checking.

Change-Id: Ic5b2a459587bd781b9c8a55a912eb4b02eeb963d
Signed-off-by: Luke Hutton <[email protected]>
@lhutton1 force-pushed the conv2d-verifier-enhancement branch from 58e214b to 6a9e886 on February 25, 2025 at 19:02
@GeorgeARM merged commit 0ba2000 into llvm:main on Feb 26, 2025
11 checks passed
@lhutton1 deleted the conv2d-verifier-enhancement branch on February 26, 2025 at 19:15