Skip to content

Commit e2bb47c

Browse files
[mlir][Arm] Fix invalid rewrite pattern API violations (#78246)
This commit fixes rewrite pattern API violations: * Rewrite pattern must return "failure" if the IR was not modified. * In-place op modifications must be communicated to the rewriter (`updateRootInPlace`). This commit fixes `test/Dialect/ArmSVE/legalize-vector-storage.mlir`, `test/Dialect/ArmSME/vector-ops-to-llvm.mlir`, `test/Dialect/ArmSME/tile-allocation-invalid.mlir`, `test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir`, `test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir`, `test/Conversion/ArmSMEToLLVM/unsupported.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`. --------- Co-authored-by: Benjamin Maxwell <[email protected]>
1 parent b1eaffd commit e2bb47c

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

+14-6
Original file line numberDiff line numberDiff line change
@@ -232,14 +232,11 @@ struct AssignTileIDsPattern
232232
static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
233233
auto tileId = allocateTileId(*tileType, tilesInUse);
234234
bool tileIsInMemory = failed(tileId);
235-
if (!tileIsInMemory)
236-
setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
237-
else {
235+
if (tileIsInMemory) {
238236
// If we could not find a real tile ID, use an in-memory tile ID (ID >=
239237
// 16). A later pass will insert the necessary spills and reloads.
240238
tileId =
241239
getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
242-
setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
243240
tileOp->emitWarning(
244241
"failed to allocate SME virtual tile to operation, all tile "
245242
"operations will go through memory, expect degraded performance");
@@ -263,14 +260,25 @@ struct AssignTileIDsPattern
263260
SetVector<Operation *> dependantOps;
264261
findDependantOps(tileOp->getResult(0), dependantOps);
265262
auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
266-
rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
267263
for (auto *op : dependantOps) {
268264
if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
269265
auto currentTileId = dependantTileOp.getTileId();
270266
if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
271267
return dependantTileOp.emitOpError(
272268
"already assigned different SME virtual tile!");
273-
dependantTileOp.setTileId(tileIDAttr);
269+
}
270+
}
271+
272+
// Rewrite IR.
273+
if (!tileIsInMemory)
274+
setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
275+
else
276+
setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
277+
rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
278+
for (auto *op : dependantOps) {
279+
if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
280+
rewriter.updateRootInPlace(
281+
dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
274282
}
275283
}
276284

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ struct RelaxScalableVectorAllocaAlignment
106106

107107
// Set alignment based on the defaults for SVE vectors and predicates.
108108
unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
109-
allocaOp.setAlignment(aligment);
109+
rewriter.updateRootInPlace(allocaOp,
110+
[&] { allocaOp.setAlignment(aligment); });
110111

111112
return success();
112113
}

0 commit comments

Comments
 (0)