-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[mlir][TosaToLinalg] Fix TosaToLinalg to restrict tosa.cast
types to integer or float
#128859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis PR fixes a bug where Full diff: https://github.com/llvm/llvm-project/pull/128859.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 607667fcc6945..e5994cdc777b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -524,6 +524,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::CastOp>(op)) {
Type srcTy = elementTy;
Type dstTy = resultTypes.front();
+ if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
+ (void)rewriter.notifyMatchFailure(op,"unsupported type");
+ return nullptr;
+ }
+
bool bitExtend =
srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 460e207d62de6..5db3f56cf459e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -54,3 +54,11 @@ func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
+
+// -----
+
+func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>> {
+ // expected-error@+1 {{failed to legalize operation 'tosa.cast'}}
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+ return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+}
|
@llvm/pr-subscribers-mlir-linalg Author: Longsheng Mou (CoTinker) ChangesThis PR fixes a bug where Full diff: https://github.com/llvm/llvm-project/pull/128859.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 607667fcc6945..e5994cdc777b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -524,6 +524,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::CastOp>(op)) {
Type srcTy = elementTy;
Type dstTy = resultTypes.front();
+ if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
+ (void)rewriter.notifyMatchFailure(op,"unsupported type");
+ return nullptr;
+ }
+
bool bitExtend =
srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 460e207d62de6..5db3f56cf459e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -54,3 +54,11 @@ func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
+
+// -----
+
+func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>> {
+ // expected-error@+1 {{failed to legalize operation 'tosa.cast'}}
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+ return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
…o integer or float This PR fixes a bug where `TosaToLinalg` incorrectly allows `tosa.cast` to accept types other than integer or float.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Just saw comments here regarding this issue (https://discourse.llvm.org/t/tosa-cast-valid-supported-data-types/84808). I think we should abandon this PR now. What do you think? @CoTinker |
This relates to the TOSA legalizations and casting of types like MLIR quantized ones that linalg will probably fail/misbehave. |
Quote from @sjarus 's comment here: https://discourse.llvm.org/t/tosa-cast-valid-supported-data-types/84808
|
Sure but the legalization cannot handle it properly so is failing. If the legalization supports it then yes we can enable and leave the validation pass for compliance checking. |
I agree. Let's merge it. |
This PR fixes a bug where
TosaToLinalg
incorrectly allowstosa.cast
to accept types other than integer or float.Fixes #116342.