diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 027ef3605aeba..044b6cc07d3d6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1244,6 +1244,14 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { } LogicalResult vector::ExtractOp::verify() { + // Note: This check must come before getMixedPosition() to prevent a crash. + auto dynamicMarkersCount = + llvm::count_if(getStaticPosition(), ShapedType::isDynamic); + if (static_cast(dynamicMarkersCount) != getDynamicPosition().size()) + return emitOpError( + "mismatch between dynamic and static positions (kDynamic marker but no " + "corresponding dynamic position) -- this can only happen due to an " + "incorrect fold/rewrite"); auto position = getMixedPosition(); if (position.size() > static_cast(getSourceVectorType().getRank())) return emitOpError( @@ -1285,6 +1293,9 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { globalPosition.append(extrPos.rbegin(), extrPos.rend()); while (ExtractOp nextOp = currentOp.getVector().getDefiningOp()) { currentOp = nextOp; + // TODO: Canonicalization for dynamic position not implemented yet. + if (currentOp.hasDynamicPosition()) + return failure(); ArrayRef extrPos = currentOp.getStaticPosition(); globalPosition.append(extrPos.rbegin(), extrPos.rend()); } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 05615b96ae6d6..924886c500309 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1693,6 +1693,18 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, % // ----- +// CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts +// CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index) +// CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32> +// CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32> +func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 { + %0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32> + %1 = vector.extract %0[1] : f32 from vector<4xf32> + return %1 : f32 +} + +// ----- + // CHECK-LABEL: extract_extract_strided2 // CHECK-SAME: %[[A:.*]]: vector<2x4xf32> // CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>