[mlir][Tosa] fix fp16/bf16 support for AvgPool2d #68718
Conversation
In the TOSA MLIR dialect, fix the AvgPool2d verifier to accept the fp16 and bf16 data types for input/output tensors and the accumulator. Add related test cases in Tosa/ops.mlir.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: None (fabrizio-indirli) Changes: Currently, the AvgPool2d operation in the TOSA MLIR dialect does not accept half-precision fp16 and bf16 tensors, contrary to what is stated in the TOSA specification. This patch fixes the AvgPool2d verifier to accept the fp16 and bf16 data types for input/output tensors and the accumulator, and adds related LIT test cases in Tosa/ops.mlir. Full diff: https://github.com/llvm/llvm-project/pull/68718.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a719171b2b359d2..6db04fe38bcd356 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -247,18 +247,20 @@ LogicalResult tosa::AvgPool2dOp::verify() {
if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
return emitOpError("accumulator type for integer tensor is not i32");
- if ((inputETy.isBF16() || inputETy.isF16()) &&
- !(accType.isF16() || accType.isF32()))
- return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
+ if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
+ return emitOpError("accumulator type for f16 tensor is not f16/f32");
+
+ if (inputETy.isBF16() && !accType.isF32())
+ return emitOpError("accumulator type for bf16 tensor is not f32");
if (inputETy.isF32() && !accType.isF32())
return emitOpError("accumulator type for f32 tensor is not f32");
- if (inputETy.isF32() && resultETy.isF32())
- return success();
- if (inputETy.isInteger(8) && resultETy.isInteger(8))
- return success();
- if (inputETy.isInteger(16) && resultETy.isInteger(16))
+ if ((inputETy.isF32() && resultETy.isF32()) ||
+ (inputETy.isF16() && resultETy.isF16()) ||
+ (inputETy.isBF16() && resultETy.isBF16()) ||
+ (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
+ (inputETy.isInteger(16) && resultETy.isInteger(16)))
return success();
return emitOpError("input/output element types are incompatible.");
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 7d7f2d31a4244cd..e62bea515d06baa 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -16,6 +16,20 @@ func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32
return %0 : tensor<1x7x7x9xf32>
}
+// -----
+// CHECK-LABEL: avg_pool2d_f16
+func.func @test_avg_pool2d_f16(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
+ %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
+ return %0 : tensor<1x7x7x9xf16>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f16_accumf32
+func.func @test_avg_pool2d_f16_accumf32(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
+ %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
+ return %0 : tensor<1x7x7x9xf16>
+}
+
// -----
// CHECK-LABEL: avg_pool2d_i8
func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
Thanks for fixing this, looks correct to me.
@eric-k256 Thanks for the review! :)
Currently, the AvgPool2d operation in the TOSA MLIR dialect does not accept half-precision fp16 and bf16 tensors, contrary to what is stated in the [TOSA specification](https://www.mlplatform.org/tosa/tosa_spec.html#_avg_pool2d). This issue was previously raised in llvm#63424 on GitHub and is due to a bug in the AvgPool2d verifier. This patch fixes the AvgPool2d verifier to accept the fp16 and bf16 data types for input/output tensors and the accumulator, and adds related LIT test cases in Tosa/ops.mlir.