Skip to content

[mlir][complex] Allow integer element types in complex.constant ops #74564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Dec 6, 2023

The op used to support only float element types. This was inconsistent with ConstantOp::isBuildableWith, which allows integer element types. The complex type allows any float/integer element type.

Note: The other complex dialect ops do not support non-float element types yet. The main purpose of this change to fix Tensor/canonicalize.mlir, which is currently failing when verifying the IR after each pattern application (#74270).

within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: error: 'complex.constant' op result #0 must be complex type with floating-point elements, but got 'complex<i32>'
  %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
              ^
within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: note: see current operation: %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32>
"func.func"() <{function_type = () -> tensor<3xcomplex<i32>>, sym_name = "extract_from_elements_complex_i"}> ({
  %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32>
  %1 = "arith.constant"() <{value = dense<(3,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>>
  %2 = "arith.constant"() <{value = dense<(1,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>>
  %3 = "tensor.extract"(%1) : (tensor<complex<i32>>) -> complex<i32>
  %4 = "tensor.from_elements"(%0, %3, %0) : (complex<i32>, complex<i32>, complex<i32>) -> tensor<3xcomplex<i32>>
  "func.return"(%4) : (tensor<3xcomplex<i32>>) -> ()
}) : () -> ()

The op used to support only float element types. This was inconsistent with `ConstantOp::isBuildableWith`, which allows integer element types. The complex type allows any float/integer element type.

Note: The other complex dialect ops do not support non-float element types yet. The purpose of this change to fix `Tensor/canonicalize.mlir`, which is currently failing when verifying the IR after each pattern application (llvm#74270).

```
within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: error: 'complex.constant' op result #0 must be complex type with floating-point elements, but got 'complex<i32>'
  %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
              ^
within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: note: see current operation: %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32>
"func.func"() <{function_type = () -> tensor<3xcomplex<i32>>, sym_name = "extract_from_elements_complex_i"}> ({
  %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32>
  %1 = "arith.constant"() <{value = dense<(3,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>>
  %2 = "arith.constant"() <{value = dense<(1,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>>
  %3 = "tensor.extract"(%1) : (tensor<complex<i32>>) -> complex<i32>
  %4 = "tensor.from_elements"(%0, %3, %0) : (complex<i32>, complex<i32>, complex<i32>) -> tensor<3xcomplex<i32>>
  "func.return"(%4) : (tensor<3xcomplex<i32>>) -> ()
}) : () -> ()
```
@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2023

@llvm/pr-subscribers-mlir-complex

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The op used to support only float element types. This was inconsistent with ConstantOp::isBuildableWith, which allows integer element types. The complex type allows any float/integer element type.

Note: The other complex dialect ops do not support non-float element types yet. The purpose of this change to fix Tensor/canonicalize.mlir, which is currently failing when verifying the IR after each pattern application (#74270).

within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: error: 'complex.constant' op result #<!-- -->0 must be complex type with floating-point elements, but got 'complex&lt;i32&gt;'
  %complex1 = tensor.extract %c1[] : tensor&lt;complex&lt;i32&gt;&gt;
              ^
within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: note: see current operation: %0 = "complex.constant"() &lt;{value = [1 : i32, 2 : i32]}&gt; : () -&gt; complex&lt;i32&gt;
"func.func"() &lt;{function_type = () -&gt; tensor&lt;3xcomplex&lt;i32&gt;&gt;, sym_name = "extract_from_elements_complex_i"}&gt; ({
  %0 = "complex.constant"() &lt;{value = [1 : i32, 2 : i32]}&gt; : () -&gt; complex&lt;i32&gt;
  %1 = "arith.constant"() &lt;{value = dense&lt;(3,2)&gt; : tensor&lt;complex&lt;i32&gt;&gt;}&gt; : () -&gt; tensor&lt;complex&lt;i32&gt;&gt;
  %2 = "arith.constant"() &lt;{value = dense&lt;(1,2)&gt; : tensor&lt;complex&lt;i32&gt;&gt;}&gt; : () -&gt; tensor&lt;complex&lt;i32&gt;&gt;
  %3 = "tensor.extract"(%1) : (tensor&lt;complex&lt;i32&gt;&gt;) -&gt; complex&lt;i32&gt;
  %4 = "tensor.from_elements"(%0, %3, %0) : (complex&lt;i32&gt;, complex&lt;i32&gt;, complex&lt;i32&gt;) -&gt; tensor&lt;3xcomplex&lt;i32&gt;&gt;
  "func.return"(%4) : (tensor&lt;3xcomplex&lt;i32&gt;&gt;) -&gt; ()
}) : () -&gt; ()

Full diff: https://github.com/llvm/llvm-project/pull/74564.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td (+1-1)
  • (modified) mlir/lib/Dialect/Complex/IR/ComplexOps.cpp (+6-4)
  • (modified) mlir/test/Dialect/Complex/ops.mlir (+3)
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index ada6c14b5b713..e19d714cadf8a 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -145,7 +145,7 @@ def ConstantOp : Complex_Op<"constant", [
   }];
 
   let arguments = (ins ArrayAttr:$value);
-  let results = (outs Complex<AnyFloat>:$complex);
+  let results = (outs AnyComplex:$complex);
 
   let assemblyFormat = "$value attr-dict `:` type($complex)";
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..0557de65ff43c 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -58,10 +58,12 @@ LogicalResult ConstantOp::verify() {
   }
 
   auto complexEltTy = getType().getElementType();
-  auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]);
-  auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]);
-  if (!re || !im)
-    return emitOpError("requires attribute's elements to be float attributes");
+  if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
+      !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
+    return emitOpError(
+        "requires attribute's elements to be float or integer attributes");
+  auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
+  auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
   if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
     return emitOpError()
            << "requires attribute's element types (" << re.getType() << ", "
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 1050ad0dcd530..96f17b2898c83 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -11,6 +11,9 @@ func.func @ops(%f: f32) {
   // CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
   %cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>
 
+  // CHECK: complex.constant [true, false] : complex<i1>
+  %cst_i1 = complex.constant [1 : i1, 0 : i1] : complex<i1>
+
   // CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
   %complex = complex.create %f, %f : complex<f32>
 

Copy link
Member

@Lewuathe Lewuathe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks. We can progressively support the integer type for other complex ops.

@matthias-springer matthias-springer merged commit 1612993 into llvm:main Dec 7, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:complex MLIR complex dialect mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants