Skip to content

Commit fed67b9

Browse files
gramalingamCopilot
andauthored
Refactor: Introduce OpBuilderBase/TapeBuilder as unified op-building interface (#2905)
## Summary Introduces `OpBuilderBase` (ABC) and `TapeBuilder` (concrete implementation) as the unified interface for building IR nodes across the rewriter, optimizer, and version converter. This is a step towards unifying these builders with the recently introduced GraphBuilder. As of now, there is still some differences internally (as these are oriented towards incrementally modifying an existing graph), but for end-users (eg., writing rewrite-rules) we should be able to move towards same API. ## Changes ### New: `onnxscript/tape_builder.py` - **`OpBuilderBase`** — Abstract base class providing the op-building API: - Dynamic dispatch: `op.Relu(x)`, `op.MatMul(a, b, _domain=...)` - Explicit creation: `op.op("Conv", inputs, attributes, domain=...)` - Initializer creation: `op.initializer(tensor, name=...)` - Subclasses implement `_add_node`, `_add_initializer`, `_record_opset` - **`TapeBuilder`** — Concrete subclass with list-based storage and harvesting properties (`nodes`, `initializers`, `used_opsets`) - Both are exported as public names from `onnxscript` ### Rewriter - `rewriter/_context.py` slimmed to re-exports + local alias `RewriterContext = OpBuilderBase` - `_rewrite_rule.py` uses `TapeBuilder()` directly; harvests from the context - Deleted `_node_sink.py` (the intermediate NodeSink/TapeSink layer) ### Optimizer - `optimizer/_constant_folding.py` defines `OptimizerContext = OpBuilderBase` locally - Uses `TapeBuilder()` instead of the old `_tape.Builder()` ### Version Converter - `version_converter/_version_converter.py` defines `VCContext = OpBuilderBase` locally - Uses `TapeBuilder()` instead of the old `_tape.Builder()` ### Cleanup - Deleted `onnxscript/ir/_tape.py` and `_tape_test.py` (fully superseded) - Removed incorrect `ir.Value` type annotations from `pattern()` method signatures in rule files (pattern methods work with pattern-value objects, not `ir.Value`) ## Design ``` OpBuilderBase (ABC) <- shared interface _add_node() <- abstract _add_initializer() <- abstract _record_opset() <- abstract __getattr__() <- dynamic dispatch (concrete) op() <- explicit node creation (concrete) initializer() <- initializer creation (concrete) TapeBuilder(OpBuilderBase) <- list-backed implementation nodes <- harvesting property initializers <- harvesting property used_opsets <- harvesting property Aliases: RewriterContext = OpBuilderBase (in rewriter) OptimizerContext = OpBuilderBase (in optimizer) VCContext = OpBuilderBase (in version converter) ``` This design allows future alternative implementations (e.g., graph-backed builder) by subclassing `OpBuilderBase` without changing rule/evaluator code. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 38bec07 commit fed67b9

16 files changed

Lines changed: 428 additions & 232 deletions

onnxscript/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
"TracedOnnxFunction",
1515
"GraphBuilder",
1616
"OpBuilder",
17+
"OpBuilderBase",
18+
"TapeBuilder",
1719
"proto2python",
1820
"external_tensor",
1921
"BFLOAT16",
@@ -131,6 +133,7 @@
131133

132134
from . import ir, nn, optimizer, rewriter, version_converter
133135
from ._internal.builder import GraphBuilder, OpBuilder
136+
from ._internal.tape_builder import OpBuilderBase, TapeBuilder
134137
from ._internal.utils import external_tensor
135138
from ._internal.values import OnnxFunction, TracedOnnxFunction
136139

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Op builder base class and tape-backed implementation.
4+
5+
This module defines:
6+
7+
- ``OpBuilderBase``: Abstract base class for building ONNX IR nodes via a
8+
dynamic dispatch interface (``op.Relu(x)``, ``op.op(...)``, ``op.initializer(...)``).
9+
Subclasses implement the storage strategy by overriding ``_add_node``,
10+
``_add_initializer``, and ``_record_opset``.
11+
12+
- ``TapeBuilder``: Concrete subclass backed by simple lists. Engines
13+
(rewriter, optimizer, version converter) create an instance, pass it to a
14+
rule or evaluator, and harvest the accumulated nodes / initializers / opsets
15+
after it returns.
16+
"""
17+
18+
from __future__ import annotations
19+
20+
import abc
21+
from typing import Any, Mapping, Optional, Sequence
22+
23+
import onnx_ir as ir
24+
from onnx_ir import _convenience
25+
26+
UsedOpsets = set[tuple[str, Optional[int]]]
27+
28+
29+
class OpBuilderBase(abc.ABC):
30+
"""Abstract base class for building ONNX IR nodes.
31+
32+
Supports three creation operations:
33+
34+
1. **Dynamic op dispatch** — ``op.Relu(x)``, ``op.MatMul(a, b, _domain=...)``, etc.
35+
2. **Explicit op creation** — ``op.op("Conv", inputs, attrs, domain=...)``.
36+
3. **Initializer creation** — ``op.initializer(tensor, name=...)``.
37+
38+
Subclasses must implement the three protected methods that define where
39+
created nodes and initializers are stored:
40+
41+
- :meth:`_add_node`
42+
- :meth:`_add_initializer`
43+
- :meth:`_record_opset`
44+
"""
45+
46+
# ------------------------------------------------------------------
47+
# Abstract storage interface (to be implemented by subclasses)
48+
# ------------------------------------------------------------------
49+
50+
@abc.abstractmethod
51+
def _add_node(self, node: ir.Node) -> None:
52+
"""Record a newly created node."""
53+
raise NotImplementedError
54+
55+
@abc.abstractmethod
56+
def _add_initializer(self, value: ir.Value) -> None:
57+
"""Record a newly created initializer."""
58+
raise NotImplementedError
59+
60+
@abc.abstractmethod
61+
def _record_opset(self, domain: str, version: int | None) -> None:
62+
"""Record that an opset domain/version was referenced."""
63+
raise NotImplementedError
64+
65+
# ------------------------------------------------------------------
66+
# Public API (concrete)
67+
# ------------------------------------------------------------------
68+
69+
def __getattr__(self, op_type: str) -> Any:
70+
"""Dynamic op dispatch: ``op.Relu(x)``, ``op.MatMul(a, b)``, etc.
71+
72+
Returns a callable that creates a node of the given ``op_type``
73+
and records it via the subclass storage implementation.
74+
75+
Supported keyword arguments on the returned callable:
76+
_domain (str): Op domain (default ``""``).
77+
_version (int | None): Opset version.
78+
_outputs (int | list[str]): Number of outputs or explicit output names.
79+
_name (str | None): Optional node name (must be unique).
80+
"""
81+
return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)
82+
83+
def _make_node(
84+
self, op_type: str, inputs: Sequence[ir.Value | None], kwargs: dict[str, Any]
85+
) -> ir.Value | Sequence[ir.Value]:
86+
"""Create one or more output values by building an ``ir.Node``."""
87+
domain = kwargs.pop("_domain", "")
88+
version = kwargs.pop("_version", None)
89+
outputs = kwargs.pop("_outputs", 1)
90+
name = kwargs.pop("_name", None)
91+
92+
if isinstance(outputs, Sequence):
93+
num_outputs = len(outputs)
94+
else:
95+
assert isinstance(outputs, int)
96+
num_outputs = outputs
97+
98+
attrs: Sequence[ir.Attr] = _convenience.convert_attributes(kwargs) if kwargs else ()
99+
node = ir.Node(
100+
domain,
101+
op_type,
102+
inputs,
103+
attributes=attrs,
104+
num_outputs=num_outputs,
105+
version=version,
106+
name=name,
107+
)
108+
self._add_node(node)
109+
self._record_opset(domain, version)
110+
111+
if num_outputs == 1:
112+
if isinstance(outputs, Sequence):
113+
node.outputs[0].name = outputs[0]
114+
return node.outputs[0]
115+
116+
if isinstance(outputs, Sequence):
117+
for value, output_name in zip(node.outputs, outputs):
118+
value.name = output_name
119+
return node.outputs
120+
121+
def op(
122+
self,
123+
op_type: str,
124+
inputs: Sequence[ir.Value | None],
125+
attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
126+
*,
127+
domain: str = "",
128+
version: int | None = None,
129+
name: str | None = None,
130+
) -> ir.Value:
131+
"""Create a single-output node with an explicit op type.
132+
133+
This is useful when the op type is determined dynamically or when
134+
forwarding attributes from a matched node.
135+
"""
136+
attrs: Sequence[ir.Attr] = (
137+
_convenience.convert_attributes(attributes) if attributes else ()
138+
)
139+
node = ir.Node(
140+
domain,
141+
op_type,
142+
inputs,
143+
attributes=attrs,
144+
num_outputs=1,
145+
version=version,
146+
name=name,
147+
)
148+
self._add_node(node)
149+
self._record_opset(domain, version)
150+
return node.outputs[0]
151+
152+
def initializer(
153+
self,
154+
tensor: ir.TensorProtocol,
155+
name: str | None = None,
156+
) -> ir.Value:
157+
"""Create a new constant initializer and return its ``ir.Value``."""
158+
name = name or tensor.name
159+
if name is None:
160+
raise ValueError("Name must be provided for initializer.")
161+
shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
162+
value = ir.Value(
163+
name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
164+
)
165+
self._add_initializer(value)
166+
return value
167+
168+
169+
class TapeBuilder(OpBuilderBase):
170+
"""Concrete builder backed by simple lists (tape-like storage).
171+
172+
Engines (rewriter, optimizer, version converter) create an instance,
173+
pass it to a rule or evaluator, and after it returns, harvest the
174+
accumulated results via the ``nodes``, ``initializers``, and
175+
``used_opsets`` properties.
176+
"""
177+
178+
def __init__(self) -> None:
179+
self._nodes: list[ir.Node] = []
180+
self._initializers: list[ir.Value] = []
181+
self._used_opsets: UsedOpsets = set()
182+
183+
def _add_node(self, node: ir.Node) -> None:
184+
self._nodes.append(node)
185+
186+
def _add_initializer(self, value: ir.Value) -> None:
187+
self._initializers.append(value)
188+
189+
def _record_opset(self, domain: str, version: int | None) -> None:
190+
self._used_opsets.add((domain, version))
191+
192+
# --- Harvesting properties ---
193+
194+
@property
195+
def nodes(self) -> Sequence[ir.Node]:
196+
"""All nodes created during this context's lifetime."""
197+
return tuple(self._nodes)
198+
199+
@property
200+
def initializers(self) -> Sequence[ir.Value]:
201+
"""All initializers created during this context's lifetime."""
202+
return tuple(self._initializers)
203+
204+
@property
205+
def used_opsets(self) -> UsedOpsets:
206+
"""Opset domains/versions referenced by created nodes."""
207+
return self._used_opsets

onnxscript/ir/_tape.py

Lines changed: 0 additions & 71 deletions
This file was deleted.

0 commit comments

Comments
 (0)