Skip to content

Commit 21df325

Browse files
authored
[mlir,python] Expose replaceAllUsesExcept to Python bindings (llvm#115850)
Problem originally described in [the forums here](https://discourse.llvm.org/t/mlir-python-expose-replaceallusesexcept/83068/1). Using the MLIR Python bindings, the method [`replaceAllUsesWith`](https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#ac56b0fdb6246bcf7fa1805ba0eb71aa2) for `Value` is exposed, e.g., ```python orig_value.replace_all_uses_with( new_value ) ``` However, in my use-case I am separating a block into multiple blocks, so thus want to exclude certain Operations from having their Values replaced (since I want them to diverge). Within Value, we have [`replaceAllUsesExcept`](https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#a9ec8d5c61f8a6aada4062f609372cce4), where we can pass the Operations which should be skipped. This is not currently exposed in the Python bindings: this PR fixes this. Adds `replace_all_uses_except`, which works with individual Operations, and lists of Operations.
1 parent 581f755 commit 21df325

File tree

4 files changed

+124
-0
lines changed

4 files changed

+124
-0
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,15 @@ MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
956956
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
957957
MlirValue with);
958958

959+
/// Replace all uses of 'of' value with 'with' value, updating anything in the
960+
/// IR that uses 'of' to use 'with' instead, except if the user is listed in
961+
/// 'exceptions'. The 'exceptions' parameter is an array of MlirOperation
962+
/// pointers with a length of 'numExceptions'.
963+
MLIR_CAPI_EXPORTED void
964+
mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with,
965+
intptr_t numExceptions,
966+
MlirOperation *exceptions);
967+
959968
//===----------------------------------------------------------------------===//
960969
// OpOperand API.
961970
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ static const char kValueReplaceAllUsesWithDocstring[] =
178178
the IR that uses 'self' to use the other value instead.
179179
)";
180180

181+
static const char kValueReplaceAllUsesExceptDocstring[] =
182+
R"("Replace all uses of this value with the 'with' value, except for those
183+
in 'exceptions'. 'exceptions' can be either a single operation or a list of
184+
operations.
185+
)";
186+
181187
//------------------------------------------------------------------------------
182188
// Utilities.
183189
//------------------------------------------------------------------------------
@@ -3718,6 +3724,29 @@ void mlir::python::populateIRCore(py::module &m) {
37183724
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
37193725
},
37203726
kValueReplaceAllUsesWithDocstring)
3727+
.def(
3728+
"replace_all_uses_except",
3729+
[](MlirValue self, MlirValue with, PyOperation &exception) {
3730+
MlirOperation exceptedUser = exception.get();
3731+
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
3732+
},
3733+
py::arg("with"), py::arg("exceptions"),
3734+
kValueReplaceAllUsesExceptDocstring)
3735+
.def(
3736+
"replace_all_uses_except",
3737+
[](MlirValue self, MlirValue with, py::list exceptions) {
3738+
// Convert Python list to a SmallVector of MlirOperations
3739+
llvm::SmallVector<MlirOperation> exceptionOps;
3740+
for (py::handle exception : exceptions) {
3741+
exceptionOps.push_back(exception.cast<PyOperation &>().get());
3742+
}
3743+
3744+
mlirValueReplaceAllUsesExcept(
3745+
self, with, static_cast<intptr_t>(exceptionOps.size()),
3746+
exceptionOps.data());
3747+
},
3748+
py::arg("with"), py::arg("exceptions"),
3749+
kValueReplaceAllUsesExceptDocstring)
37213750
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
37223751
[](PyValue &self) { return self.maybeDownCast(); });
37233752
PyBlockArgument::bind(m);

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/IR/Visitors.h"
2929
#include "mlir/Interfaces/InferTypeOpInterface.h"
3030
#include "mlir/Parser/Parser.h"
31+
#include "llvm/ADT/SmallPtrSet.h"
3132
#include "llvm/Support/ThreadPool.h"
3233

3334
#include <cstddef>
@@ -1009,6 +1010,20 @@ void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
10091010
unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
10101011
}
10111012

1013+
void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
1014+
intptr_t numExceptions,
1015+
MlirOperation *exceptions) {
1016+
Value oldValueCpp = unwrap(oldValue);
1017+
Value newValueCpp = unwrap(newValue);
1018+
1019+
llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet;
1020+
for (intptr_t i = 0; i < numExceptions; ++i) {
1021+
exceptionSet.insert(unwrap(exceptions[i]));
1022+
}
1023+
1024+
oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
1025+
}
1026+
10121027
//===----------------------------------------------------------------------===//
10131028
// OpOperand API.
10141029
//===----------------------------------------------------------------------===//

mlir/test/python/ir/value.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,77 @@ def testValueReplaceAllUsesWith():
148148
print(f"Use operand_number: {use.operand_number}")
149149

150150

151+
# CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
152+
@run
153+
def testValueReplaceAllUsesWithExcept():
154+
ctx = Context()
155+
ctx.allow_unregistered_dialects = True
156+
with Location.unknown(ctx):
157+
i32 = IntegerType.get_signless(32)
158+
module = Module.create()
159+
with InsertionPoint(module.body):
160+
value = Operation.create("custom.op1", results=[i32]).results[0]
161+
op1 = Operation.create("custom.op1", operands=[value])
162+
op2 = Operation.create("custom.op2", operands=[value])
163+
value2 = Operation.create("custom.op3", results=[i32]).results[0]
164+
value.replace_all_uses_except(value2, op1)
165+
166+
assert len(list(value.uses)) == 1
167+
168+
# CHECK: Use owner: "custom.op2"
169+
# CHECK: Use operand_number: 0
170+
for use in value2.uses:
171+
assert use.owner in [op2]
172+
print(f"Use owner: {use.owner}")
173+
print(f"Use operand_number: {use.operand_number}")
174+
175+
# CHECK: Use owner: "custom.op1"
176+
# CHECK: Use operand_number: 0
177+
for use in value.uses:
178+
assert use.owner in [op1]
179+
print(f"Use owner: {use.owner}")
180+
print(f"Use operand_number: {use.operand_number}")
181+
182+
183+
# CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
184+
@run
185+
def testValueReplaceAllUsesWithMultipleExceptions():
186+
ctx = Context()
187+
ctx.allow_unregistered_dialects = True
188+
with Location.unknown(ctx):
189+
i32 = IntegerType.get_signless(32)
190+
module = Module.create()
191+
with InsertionPoint(module.body):
192+
value = Operation.create("custom.op1", results=[i32]).results[0]
193+
op1 = Operation.create("custom.op1", operands=[value])
194+
op2 = Operation.create("custom.op2", operands=[value])
195+
op3 = Operation.create("custom.op3", operands=[value])
196+
value2 = Operation.create("custom.op4", results=[i32]).results[0]
197+
198+
# Replace all uses of `value` with `value2`, except for `op1` and `op2`.
199+
value.replace_all_uses_except(value2, [op1, op2])
200+
201+
# After replacement, only `op3` should use `value2`, while `op1` and `op2` should still use `value`.
202+
assert len(list(value.uses)) == 2
203+
assert len(list(value2.uses)) == 1
204+
205+
# CHECK: Use owner: "custom.op3"
206+
# CHECK: Use operand_number: 0
207+
for use in value2.uses:
208+
assert use.owner in [op3]
209+
print(f"Use owner: {use.owner}")
210+
print(f"Use operand_number: {use.operand_number}")
211+
212+
# CHECK: Use owner: "custom.op2"
213+
# CHECK: Use operand_number: 0
214+
# CHECK: Use owner: "custom.op1"
215+
# CHECK: Use operand_number: 0
216+
for use in value.uses:
217+
assert use.owner in [op1, op2]
218+
print(f"Use owner: {use.owner}")
219+
print(f"Use operand_number: {use.operand_number}")
220+
221+
151222
# CHECK-LABEL: TEST: testValuePrintAsOperand
152223
@run
153224
def testValuePrintAsOperand():

0 commit comments

Comments
 (0)