Skip to content

[mlir][VectorOps] Don't fold extract chains that include dynamic indices #68333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 6, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Oct 5, 2023

This is not yet supported and previously led to a confusing crash where an extract op with a kDynamic marker, but no dynamic positions was created. The verifier has also been updated to check for this, and hint at where the problem is likely to be.

This is not yet supported and previously led to a confusing crash where
an extract op with a kDynamic marker, but no dynamic positions was
created. The verifier has also been updated to check for this, and hint
at where the problem is likely to be.
@llvmbot
Copy link
Member

llvmbot commented Oct 5, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Changes

This is not yet supported and previously led to a confusing crash where an extract op with a kDynamic marker, but no dynamic positions was created. The verifier has also been updated to check for this, and hint at where the problem is likely to be.


Full diff: https://github.com/llvm/llvm-project/pull/68333.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+11)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 027ef3605aeba46..f84a574c4634fc3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1244,6 +1244,14 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 }
 
 LogicalResult vector::ExtractOp::verify() {
+  // Note: This check must come before getMixedPosition() to prevent a crash.
+  auto dynamicMarkersCount =
+      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
+  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
+    return emitOpError(
+        "mismatch between dynamic and static positions (kDynamic marker but no "
+        "corresponding dynamic position) -- this can only happen due to an "
+        "incorrect/fold rewrite");
   auto position = getMixedPosition();
   if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
     return emitOpError(
@@ -1285,6 +1293,9 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
   globalPosition.append(extrPos.rbegin(), extrPos.rend());
   while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
     currentOp = nextOp;
+    // TODO: Canonicalization for dynamic position not implemented yet.
+    if (currentOp.hasDynamicPosition())
+      return failure();
     ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
     globalPosition.append(extrPos.rbegin(), extrPos.rend());
   }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 05615b96ae6d69f..924886c50030967 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1693,6 +1693,18 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %
 
 // -----
 
+// CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts
+//  CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index)
+//       CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32>
+//       CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
+func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 {
+  %0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// -----
+
 // CHECK-LABEL: extract_extract_strided2
 //  CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
 //       CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

@MacDue MacDue merged commit 469b9cb into llvm:main Oct 6, 2023
@MacDue MacDue deleted the extract_fold_fix branch October 6, 2023 13:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants