Skip to content

Commit c481b2d

Browse files
authored
Minor quick fix for RewriterContext (#2314)
For a reported import error issue. Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 2ae13be commit c481b2d

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import onnx.reference.ops
1717

1818
import onnxscript.ir as ir
19-
import onnxscript.rewriter.pattern as orp
19+
import onnxscript.ir._tape as _tape
2020
import onnxscript.utils.utils as utils
2121

2222
DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024
@@ -202,10 +202,9 @@ def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
202202
# the ir.Value or ir.Values to replace the output values of the node, when the new nodes
203203
# can be inferred from the RewriterContext used to build the new nodes.
204204

205+
RewriterContext = _tape.Builder
205206
ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None]
206-
PartialEvaluatorFunction = Callable[
207-
[ir.Node, orp.RewriterContext, OptimizerState], ReturnValue
208-
]
207+
PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue]
209208

210209

211210
@dataclasses.dataclass
@@ -991,7 +990,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
991990
op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
992991
for optimizer in op_optimizers:
993992
assert optimizer
994-
context = orp.RewriterContext()
993+
context = RewriterContext()
995994
output = optimizer(node, context, self._state)
996995
if output is not None:
997996
if isinstance(output, Replacement):

onnxscript/version_converter/_version_converter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import logging
1010
from typing import Callable, Sequence, Union
1111

12+
import onnxscript.ir._tape as _tape
1213
import onnxscript.ir.convenience as ir_convenience
13-
import onnxscript.rewriter.pattern as orp
1414
from onnxscript import ir
1515

1616
logger = logging.getLogger(__name__)
@@ -35,8 +35,9 @@ class Replacement:
3535
# A version-adapter function takes a node, a RewriterContext and returns
3636
# a Replacement for the node or None (if no replacement is needed).
3737

38+
RewriterContext = _tape.Builder
3839
ReturnValue = Union[Sequence[ir.Value], ir.Value, None]
39-
AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue]
40+
AdapterFunction = Callable[[ir.Node, RewriterContext], ReturnValue]
4041

4142

4243
def version_supported(model: ir.Model, target_version: int) -> bool:
@@ -236,7 +237,7 @@ def process_node(
236237
)
237238
if adapter is None:
238239
return None
239-
context = orp.RewriterContext()
240+
context = RewriterContext()
240241
output = adapter(node, context)
241242
if output is not None:
242243
if isinstance(output, ir.Value):

0 commit comments

Comments
 (0)