Skip to content

Commit b0e00ca

Browse files
authored
[mlir][python] fix replace=True for register_operation and register_type_caster (#70264)
<img src="https://github.com/llvm/llvm-project/assets/5657668/443852b6-ac25-45bb-a38b-5dfbda09d5a7" height="400" /> <p></p> So turns out that none of the `replace=True` things actually work because of the map caches (except for `register_attribute_builder(replace=True)`, which doesn't use such a cache). This was hidden by a series of unfortunate events: 1. `register_type_caster` failure was hidden because it was the same `TestIntegerRankedTensorType` being replaced with itself (d'oh). 2. `register_operation` failure was hidden behind the "order of events" in the lifecycle of typical extension import/use. Since extensions are loaded/registered almost immediately after generated builders are registered, there is no opportunity for the `operationClassMapCache` to be populated (through e.g., `module.body.operations[2]` or `module.body.operations[2].opview` or something). Of course as soon as you as actually do "late-bind/late-register" the extension, you see it's not successfully replacing the stale one in `operationClassMapCache`. I'll take this opportunity to propose we ditch the caches all together. I've been cargo-culting them but I really don't understand how they work. There's this comment above `operationClassMapCache` ```cpp /// Cache of operation name to external operation class object. This is /// maintained on lookup as a shadow of operationClassMap in order for repeat /// lookups of the classes to only incur the cost of one hashtable lookup. llvm::StringMap<pybind11::object> operationClassMapCache; ``` But I don't understand how that's true given that the canonical thing `operationClassMap` is already a map: ```cpp /// Map of full operation name to external operation class object. llvm::StringMap<pybind11::object> operationClassMap; ``` Maybe it wasn't always the case? Anyway things work now but it seems like an unnecessary layer of complexity for not much gain? But maybe I'm wrong.
1 parent 3343bd9 commit b0e00ca

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
8282
if (found && !found.is_none() && !replace)
8383
throw std::runtime_error("Type caster is already registered");
8484
found = std::move(typeCaster);
85+
const auto foundIt = typeCasterMapCache.find(mlirTypeID);
86+
if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
87+
typeCasterMapCache[mlirTypeID] = found;
88+
}
8589
}
8690

8791
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -104,6 +108,10 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
104108
.str());
105109
}
106110
found = std::move(pyClass);
111+
auto foundIt = operationClassMapCache.find(operationName);
112+
if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
113+
operationClassMapCache[operationName] = found;
114+
}
107115
}
108116

109117
std::optional<py::function>

mlir/test/python/dialects/python_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,19 @@ def type_caster(pytype):
510510
except RuntimeError as e:
511511
print(e)
512512

513+
def type_caster(pytype):
514+
return RankedTensorType(pytype)
515+
516+
# python_test dialect registers a caster for RankedTensorType in its extension (pybind) module.
517+
# So this one replaces that one (successfully). And then just to be sure we restore the original caster below.
518+
register_type_caster(c.typeid, type_caster, replace=True)
519+
520+
d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
521+
# CHECK: tensor<10x10xi5>
522+
print(d.type)
523+
# CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
524+
print("ranked tensor type", repr(d.type))
525+
513526
def type_caster(pytype):
514527
return test.TestIntegerRankedTensorType(pytype)
515528

mlir/test/python/ir/operation.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import itertools
66
from mlir.ir import *
77
from mlir.dialects.builtin import ModuleOp
8+
from mlir.dialects import arith
9+
from mlir.dialects._ods_common import _cext
810

911

1012
def run(f):
@@ -646,6 +648,7 @@ def testKnownOpView():
646648
%1 = "custom.f32"() : () -> f32
647649
%2 = "custom.f32"() : () -> f32
648650
%3 = arith.addf %1, %2 : f32
651+
%4 = arith.constant 0 : i32
649652
"""
650653
)
651654
print(module)
@@ -668,6 +671,31 @@ def testKnownOpView():
668671
# CHECK: OpView object
669672
print(repr(custom))
670673

674+
# constant should map to an extension OpView class in the arithmetic dialect.
675+
constant = module.body.operations[3]
676+
# CHECK: <mlir.dialects.arith.ConstantOp object
677+
print(repr(constant))
678+
# Checks that the arith extension is being registered successfully
679+
# (literal_value is a property on the extension class but not on the default OpView).
680+
# CHECK: literal value 0
681+
print("literal value", constant.literal_value)
682+
683+
# Checks that "late" registration/replacement (i.e., post all module loading/initialization)
684+
# is working correctly.
685+
@_cext.register_operation(arith._Dialect, replace=True)
686+
class ConstantOp(arith.ConstantOp):
687+
def __init__(self, result, value, *, loc=None, ip=None):
688+
if isinstance(value, int):
689+
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
690+
elif isinstance(value, float):
691+
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
692+
else:
693+
super().__init__(value, loc=loc, ip=ip)
694+
695+
constant = module.body.operations[3]
696+
# CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
697+
print(repr(constant))
698+
671699

672700
# CHECK-LABEL: TEST: testSingleResultProperty
673701
@run

0 commit comments

Comments
 (0)