-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][complex] Allow integer element types in complex.constant
ops
#74564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][complex] Allow integer element types in complex.constant
ops
#74564
Conversation
The op used to support only float element types. This was inconsistent with `ConstantOp::isBuildableWith`, which allows integer element types. The complex type allows any float/integer element type. Note: The other complex dialect ops do not support non-float element types yet. The purpose of this change to fix `Tensor/canonicalize.mlir`, which is currently failing when verifying the IR after each pattern application (llvm#74270). ``` within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: error: 'complex.constant' op result #0 must be complex type with floating-point elements, but got 'complex<i32>' %complex1 = tensor.extract %c1[] : tensor<complex<i32>> ^ within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: note: see current operation: %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32> "func.func"() <{function_type = () -> tensor<3xcomplex<i32>>, sym_name = "extract_from_elements_complex_i"}> ({ %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32> %1 = "arith.constant"() <{value = dense<(3,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>> %2 = "arith.constant"() <{value = dense<(1,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>> %3 = "tensor.extract"(%1) : (tensor<complex<i32>>) -> complex<i32> %4 = "tensor.from_elements"(%0, %3, %0) : (complex<i32>, complex<i32>, complex<i32>) -> tensor<3xcomplex<i32>> "func.return"(%4) : (tensor<3xcomplex<i32>>) -> () }) : () -> () ```
@llvm/pr-subscribers-mlir-complex @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThe op used to support only float element types. This was inconsistent with Note: The other complex dialect ops do not support non-float element types yet. The purpose of this change to fix
Full diff: https://github.com/llvm/llvm-project/pull/74564.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index ada6c14b5b713..e19d714cadf8a 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -145,7 +145,7 @@ def ConstantOp : Complex_Op<"constant", [
}];
let arguments = (ins ArrayAttr:$value);
- let results = (outs Complex<AnyFloat>:$complex);
+ let results = (outs AnyComplex:$complex);
let assemblyFormat = "$value attr-dict `:` type($complex)";
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..0557de65ff43c 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -58,10 +58,12 @@ LogicalResult ConstantOp::verify() {
}
auto complexEltTy = getType().getElementType();
- auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]);
- auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]);
- if (!re || !im)
- return emitOpError("requires attribute's elements to be float attributes");
+ if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
+ !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
+ return emitOpError(
+ "requires attribute's elements to be float or integer attributes");
+ auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
+ auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
return emitOpError()
<< "requires attribute's element types (" << re.getType() << ", "
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 1050ad0dcd530..96f17b2898c83 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -11,6 +11,9 @@ func.func @ops(%f: f32) {
// CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
%cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>
+ // CHECK: complex.constant [true, false] : complex<i1>
+ %cst_i1 = complex.constant [1 : i1, 0 : i1] : complex<i1>
+
// CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
%complex = complex.create %f, %f : complex<f32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks. We can progressively support the integer type for other complex ops.
The op used to support only float element types. This was inconsistent with
ConstantOp::isBuildableWith
, which allows integer element types. The complex type allows any float/integer element type.Note: The other complex dialect ops do not support non-float element types yet. The main purpose of this change to fix
Tensor/canonicalize.mlir
, which is currently failing when verifying the IR after each pattern application (#74270).