Skip to content

Commit ea565d8

Browse files
[mlir][Transform] Relax the applicability of transform.foreach_match to also take into account the op itself
1 parent 42c25fd commit ea565d8

File tree

3 files changed

+28
-11
lines changed

3 files changed

+28
-11
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,16 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
481481
This operation consumes the operand and produces a new handle associated
482482
with the same payload. This is necessary to trigger invalidation of handles
483483
to any of the payload operations nested in the payload operations associated
484-
with the operand, as those are likely to be modified by actions. Note that
485-
the root payload operation associated with the operand are not matched.
484+
with the operand, as those are likely to be modified by actions.
485+
486+
By default, the root payload operation associated with the operand is not
487+
matched. This is to support the conservative case where applied actions may
488+
invalidate the root payload operation. If the optional `restrict_root`
489+
attribute is set, the root operand is guaranteed to not be invalidated by any
490+
of the applied actions. In such cases, the root payload operation is also
491+
matched. This is useful because matching the root payload operation is a
492+
common idiom, when e.g. matching a func.func directly and operations nested
493+
under it.
486494

487495
The operation succeeds if none of the matchers produced a definite failure
488496
during application and if all of the applied actions produced success. Note
@@ -495,13 +503,19 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
495503
}];
496504

497505
let arguments = (ins TransformHandleTypeInterface:$root,
506+
UnitAttr:$restrict_root,
498507
SymbolRefArrayAttr:$matchers,
499508
SymbolRefArrayAttr:$actions);
500509
let results = (outs TransformHandleTypeInterface:$updated);
501510

502-
let assemblyFormat =
503-
"`in` $root custom<ForeachMatchSymbols>($matchers, $actions) "
504-
"attr-dict `:` functional-type($root, $updated)";
511+
let assemblyFormat = [{
512+
(`restrict_root` $restrict_root^)?
513+
`in`
514+
$root
515+
custom<ForeachMatchSymbols>($matchers, $actions)
516+
attr-dict
517+
`:` functional-type($root, $updated)
518+
}];
505519

506520
let hasVerifier = 1;
507521
}

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -850,8 +850,9 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
850850

851851
for (Operation *root : state.getPayloadOps(getRoot())) {
852852
WalkResult walkResult = root->walk([&](Operation *op) {
853-
// Skip over the root op itself so we don't invalidate it.
854-
if (op == root)
853+
// If getRestrictRoot is not present, skip over the root op itself so we
854+
// don't invalidate it.
855+
if (!getRestrictRoot() && op == root)
855856
return WalkResult::advance();
856857

857858
DEBUG_MATCHER({
@@ -1556,10 +1557,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
15561557
::std::optional<::mlir::Operation *> maybeCurrent,
15571558
transform::TransformResults &results, transform::TransformState &state) {
15581559
if (!maybeCurrent.has_value()) {
1559-
DBGS_MATCHER() << "MatchOperationEmptyOp success\n";
1560+
DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
15601561
return DiagnosedSilenceableFailure::success();
15611562
}
1562-
DBGS_MATCHER() << "MatchOperationEmptyOp failure\n";
1563+
DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
15631564
return emitSilenceableError() << "operation is not empty";
15641565
}
15651566

@@ -1961,7 +1962,8 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
19611962
state.addAttribute(SymbolTable::getSymbolAttrName(),
19621963
builder.getStringAttr(symName));
19631964
state.addAttribute(getFunctionTypeAttrName(state.name),
1964-
TypeAttr::get(FunctionType::get(builder.getContext(), rootType, resultTypes)));
1965+
TypeAttr::get(FunctionType::get(builder.getContext(),
1966+
rootType, resultTypes)));
19651967
state.attributes.append(attrs.begin(), attrs.end());
19661968
state.addRegion();
19671969

mlir/test/Dialect/Linalg/match-ops-interpreter.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,13 @@ module attributes { transform.with_named_sequence } {
100100
}
101101

102102
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
103-
transform.foreach_match in %arg0
103+
transform.foreach_match restrict_root in %arg0
104104
@match_structured_suppress -> @do_nothing
105105
: (!transform.any_op) -> !transform.any_op
106106
transform.yield
107107
}
108108

109+
// expected-remark @below {{other}}
109110
func.func @payload() attributes { transform.target_tag = "start_here" } {
110111
// expected-remark @below {{other}}
111112
%D = arith.constant dense<1.0> : tensor<2x4xf32>

0 commit comments

Comments
 (0)