[mlir][Tosa] fix fp16/bf16 support for AvgPool2d #68718

Merged: 1 commit merged into llvm:main on Oct 13, 2023

Conversation

fabrizio-indirli
Contributor

Currently, the AvgPool2d operation in the TOSA MLIR dialect does not accept half-precision fp16 and bf16 tensors, contrary to what is stated in the TOSA specification (https://www.mlplatform.org/tosa/tosa_spec.html#_avg_pool2d).
This issue was previously raised in #63424 on GitHub; it is caused by a bug in the AvgPool2d verifier.

This patch fixes the AvgPool2d verifier to accept the fp16 and bf16 data types for input/output tensors and the accumulator, and it adds related LIT test cases in Tosa/ops.mlir.
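
For illustration, an op like the following (a minimal sketch with illustrative shapes, not taken from the patch) is rejected by the current verifier with "input/output element types are incompatible" and is accepted once this fix lands:

  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xbf16>) -> tensor<1x7x7x9xbf16>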

In the TOSA MLIR dialect, fix the AvgPool2d verifier to accept
the fp16 and bf16 data types for input/output tensors and the accumulator.
Add related test cases in Tosa/ops.mlir.
@llvmbot
Member

llvmbot commented Oct 10, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: None (fabrizio-indirli)

Changes

Currently, the AvgPool2d operation in the TOSA MLIR dialect does not accept half-precision fp16 and bf16 tensors, contrary to what is stated in the TOSA specification (https://www.mlplatform.org/tosa/tosa_spec.html#_avg_pool2d).
This issue was previously raised in #63424 on GitHub; it is caused by a bug in the AvgPool2d verifier.

This patch fixes the AvgPool2d verifier to accept the fp16 and bf16 data types for input/output tensors and the accumulator, and it adds related LIT test cases in Tosa/ops.mlir.


Full diff: https://github.com/llvm/llvm-project/pull/68718.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+10-8)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+14)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a719171b2b359d2..6db04fe38bcd356 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -247,18 +247,20 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
     return emitOpError("accumulator type for integer tensor is not i32");
 
-  if ((inputETy.isBF16() || inputETy.isF16()) &&
-      !(accType.isF16() || accType.isF32()))
-    return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
+  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
+    return emitOpError("accumulator type for f16 tensor is not f16/f32");
+
+  if (inputETy.isBF16() && !accType.isF32())
+    return emitOpError("accumulator type for bf16 tensor is not f32");
 
   if (inputETy.isF32() && !accType.isF32())
     return emitOpError("accumulator type for f32 tensor is not f32");
 
-  if (inputETy.isF32() && resultETy.isF32())
-    return success();
-  if (inputETy.isInteger(8) && resultETy.isInteger(8))
-    return success();
-  if (inputETy.isInteger(16) && resultETy.isInteger(16))
+  if ((inputETy.isF32() && resultETy.isF32()) ||
+      (inputETy.isF16() && resultETy.isF16()) ||
+      (inputETy.isBF16() && resultETy.isBF16()) ||
+      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
+      (inputETy.isInteger(16) && resultETy.isInteger(16)))
     return success();
 
   return emitOpError("input/output element types are incompatible.");
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 7d7f2d31a4244cd..e62bea515d06baa 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -16,6 +16,20 @@ func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32
   return %0 : tensor<1x7x7x9xf32>
 }
 
+// -----
+// CHECK-LABEL: avg_pool2d_f16
+func.func @test_avg_pool2d_f16(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
+  return %0 : tensor<1x7x7x9xf16>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f16_accumf32
+func.func @test_avg_pool2d_f16_accumf32(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
+  return %0 : tensor<1x7x7x9xf16>
+}
+
 // -----
 // CHECK-LABEL: avg_pool2d_i8
 func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {

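As an aside (a sketch, not part of the diff above): with the updated verifier, a bf16 AvgPool2d op must use an f32 accumulator, so IR like the following would be rejected with "accumulator type for bf16 tensor is not f32":

  %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xbf16>) -> tensor<1x7x7x9xbf16>
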
Contributor

eric-k256 left a comment


Thanks for fixing this, looks correct to me.

@fabrizio-indirli
Contributor Author

@eric-k256 Thanks for the review! :)
I can't merge this myself since I don't have write access to this repository.
If the change is approved, could it be merged on my behalf?

eric-k256 merged commit 2c9ddfc into llvm:main on Oct 13, 2023
ttjost pushed a commit to Xilinx/llvm-project that referenced this pull request Feb 20, 2024