Skip to content

Commit 0db0320

Browse files
gramalingamCopilot
andcommitted
Refactor: move op() to TapeBuilder and address PR review comments
- Move op() method and __getattr__ from BuilderBase to TapeBuilder to eliminate confusing shadowing with GraphBuilder.op property. - Fix OpBuilder._call_op to pop and forward _name kwarg. - Remove duplicated _dtype_suffix/_constant_name helpers from builder.py (import from tape_builder.py instead). - Inline BuilderFeature.any_schema_feature into SCHEMA_AWARE flag check. - Add _get_default_opset_version() hook so CastLike gets proper version context from GraphBuilder. - Document underscore-prefix convention in op() docstring. - Document name-determinism behavior in call_op() docstring. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent d3b52dd commit 0db0320

2 files changed

Lines changed: 102 additions & 102 deletions

File tree

onnxscript/_internal/builder.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from onnxscript._internal.tape_builder import (
2121
BuilderBase,
2222
BuilderFeature,
23+
_constant_name,
24+
_dtype_suffix,
2325
)
2426

2527
# A permissible value for an op input, which can be converted to an ir.Value.
@@ -51,33 +53,6 @@ def _type_suffix(element_type: type) -> str:
5153
return dtype.short_name() if dtype is not None else ""
5254

5355

54-
def _dtype_suffix(dtype: ir.DataType) -> str:
55-
"""Return a short type suffix for naming constants based on ir.DataType."""
56-
return dtype.short_name()
57-
58-
59-
def _constant_name(
60-
value: int | float | bool | str | Sequence, type_suffix: str, num: int = 0
61-
) -> str:
62-
"""Generate a descriptive name for a constant value.
63-
64-
Args:
65-
value: The constant value
66-
type_suffix: Type suffix (e.g., 'F', 'I64')
67-
num: A number used for generating unique names for str/sequences
68-
69-
Returns:
70-
A name string for the constant
71-
"""
72-
if isinstance(value, str):
73-
# For strings, use a generic name with cache size as unique identifier
74-
return f"const_str_{num}"
75-
if isinstance(value, (int, float, bool)):
76-
return f"const_{value}_{type_suffix}" if type_suffix else f"const_{value}"
77-
# Sequence: use generic name with cache size as unique identifier
78-
return f"const_1d_{num}"
79-
80-
8156
def lift_initializers_to_constants(graph: ir.Graph) -> None:
8257
"""Replace every initializer in *graph* with a ``Constant`` node.
8358
@@ -498,6 +473,10 @@ def _record_opset(self, domain: str, version: int | None) -> None:
498473
# BuilderBase hook overrides
499474
# ------------------------------------------------------------------
500475

476+
def _get_default_opset_version(self, domain: str = "") -> int | None:
477+
"""Return the graph's ambient opset version for the given domain."""
478+
return self._graph.opset_imports.get(domain)
479+
501480
def _promote_constant(self, value: Any, dtype: ir.DataType | None) -> ir.Value:
502481
"""Cache-based constant promotion.
503482
@@ -924,8 +903,9 @@ def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
924903
domain = kwargs.pop("_domain", self._domain)
925904
version = kwargs.pop("_version", self._version)
926905
outputs = kwargs.pop("_outputs", 1)
906+
name = kwargs.pop("_name", None)
927907
return self._builder.call_op(
928-
op_type, inputs, kwargs, domain=domain, version=version, outputs=outputs
908+
op_type, inputs, kwargs, domain=domain, version=version, outputs=outputs, name=name
929909
)
930910

931911
def __getattr__(self, op_type: str) -> Callable:

onnxscript/_internal/tape_builder.py

Lines changed: 94 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
55
This module defines:
66
7-
- ``BuilderBase``: Abstract base class for building ONNX IR nodes via a
8-
dynamic dispatch interface (``op.Relu(x)``, ``op.op("Relu", x)``,
9-
``op.initializer(...)``).
7+
- ``BuilderBase``: Abstract base class for building ONNX IR nodes.
8+
Provides the core node-creation pipeline (``call_op``) and an
9+
initializer creation API.
1010
Subclasses implement the storage strategy by overriding ``_add_node``,
1111
``_add_initializer``, and ``_record_opset``.
1212
1313
- ``TapeBuilder``: Concrete subclass backed by simple lists. Engines
1414
(rewriter, optimizer, version converter) create an instance, pass it to a
1515
rule or evaluator, and harvest the accumulated nodes / initializers / opsets
16-
after it returns.
16+
after it returns. Provides the user-facing ``op()`` method and
17+
``__getattr__``-based dynamic dispatch (``op.Relu(x)``).
1718
1819
- ``BuilderFeature``: Flag enum controlling optional processing steps
1920
(schema partitioning, input casting, shape inference, etc.).
@@ -64,26 +65,12 @@ class BuilderFeature(enum.Flag):
6465
SCHEMA_AWARE = SCHEMA_PARTITION | CAST_INPUTS | CAST_ATTRIBUTES
6566
FULL = SCHEMA_AWARE | INFER_SHAPES | CONSTANT_PROPAGATION
6667

67-
@property
68-
def any_schema_feature(self) -> bool:
69-
"""True if any schema-dependent feature is enabled."""
70-
return bool(
71-
self
72-
& (
73-
BuilderFeature.SCHEMA_PARTITION
74-
| BuilderFeature.CAST_INPUTS
75-
| BuilderFeature.CAST_ATTRIBUTES
76-
)
77-
)
78-
7968

8069
class BuilderBase(abc.ABC):
8170
"""Abstract base class for building ONNX IR nodes.
8271
83-
Supports two creation operations:
84-
85-
1. **Op creation** — ``op.op("Relu", x)`` or ``op.Relu(x)`` (syntactic sugar).
86-
2. **Initializer creation** — ``op.initializer(tensor, name=...)``.
72+
Provides the core node-creation pipeline via :meth:`call_op` and an
73+
initializer creation API via :meth:`initializer`.
8774
8875
Subclasses must implement the three protected methods that define where
8976
created nodes and initializers are stored:
@@ -123,6 +110,15 @@ def _record_opset(self, domain: str, version: int | None) -> None:
123110
# Overridable hook methods
124111
# ------------------------------------------------------------------
125112

113+
def _get_default_opset_version(self, domain: str = "") -> int | None:
114+
"""Return the default opset version for internal ops (e.g. CastLike).
115+
116+
Default: None (no version). Override in GraphBuilder to return the
117+
graph's ambient opset version.
118+
"""
119+
del domain # Unused in base implementation
120+
return None
121+
126122
def _get_schema(
127123
self, op_type: str, domain: str, version: int | None
128124
) -> onnx.defs.OpSchema | None:
@@ -247,7 +243,12 @@ def _input_to_ir_value(
247243
needs_dynamic_cast = like_type is not None and dtype is None
248244
ir_value = self._promote_constant(value, dtype)
249245
if needs_dynamic_cast:
250-
ir_value = self.call_op("CastLike", [ir_value, like_type], {})
246+
ir_value = self.call_op(
247+
"CastLike",
248+
[ir_value, like_type],
249+
{},
250+
version=self._get_default_opset_version(""),
251+
)
251252
return ir_value
252253

253254
def _promote_constant(self, value: Any, dtype: ir.DataType | None) -> ir.Value:
@@ -326,57 +327,6 @@ def _annotate_node(self, node: ir.Node) -> None: # noqa: B027
326327
# Public API (concrete)
327328
# ------------------------------------------------------------------
328329

329-
def __getattr__(self, op_type: str) -> Any:
330-
"""Dynamic op dispatch: ``op.Relu(x)``, ``op.MatMul(a, b)``, etc.
331-
332-
Syntactic sugar for ``op.op(op_type, ...)``.
333-
334-
Returns a callable that creates a node of the given ``op_type``
335-
and records it via the subclass storage implementation.
336-
"""
337-
return lambda *args, **kwargs: self.op(op_type, *args, **kwargs)
338-
339-
def op(
340-
self,
341-
op_type: str,
342-
/,
343-
*args: ir.Value | None,
344-
_domain: str = "",
345-
_version: int | None = None,
346-
_outputs: int | Sequence[str] = 1,
347-
_name: str | None = None,
348-
**kwargs: Any,
349-
) -> ir.Value | Sequence[ir.Value]:
350-
"""Create an ONNX node.
351-
352-
This is the single entry point for all node creation.
353-
``op.Relu(x)`` is equivalent to ``op.op("Relu", x)``.
354-
355-
Args:
356-
op_type: The operator type (e.g., ``"Relu"``, ``"Conv"``).
357-
*args: Positional arguments — the node's input values.
358-
_domain: Op domain (default ``""``).
359-
_version: Opset version.
360-
_outputs: Number of outputs or list of explicit output names.
361-
_name: Optional node name (must be unique).
362-
**kwargs: Keyword arguments — node attributes.
363-
Values can be Python scalars/lists (auto-converted) or
364-
``ir.Attr`` instances (passed through).
365-
366-
Returns:
367-
A single ``ir.Value`` if the node has one output, otherwise
368-
a sequence of ``ir.Value``.
369-
"""
370-
return self.call_op(
371-
op_type,
372-
args,
373-
kwargs,
374-
domain=_domain,
375-
version=_version,
376-
outputs=_outputs,
377-
name=_name,
378-
)
379-
380330
def call_op(
381331
self,
382332
op_type: str,
@@ -390,15 +340,23 @@ def call_op(
390340
) -> ir.Value | Sequence[ir.Value]:
391341
"""Create an ONNX node and add it to the graph, returning its output value(s).
392342
393-
This is the core node-creation method. Both ``BuilderBase.op()`` and
343+
This is the core node-creation method. Both ``TapeBuilder.op()`` and
394344
``OpBuilder.__getattr__`` delegate here. The processing steps are
395345
controlled by :attr:`features` flags and overridable hook methods.
346+
347+
Note:
348+
When input casting is enabled, helper nodes (e.g. Constant,
349+
CastLike) may be created before the requested node. In builders
350+
that generate names from a node counter (like GraphBuilder), this
351+
means the outer node's auto-generated name reflects the *total*
352+
node count including helpers — not a sequential index of
353+
user-requested ops.
396354
"""
397355
features = self._features
398356

399357
# 1. Schema lookup (if any schema-dependent feature is enabled)
400358
schema = None
401-
if features.any_schema_feature:
359+
if features & BuilderFeature.SCHEMA_AWARE:
402360
schema = self._get_schema(op_type, domain, version)
403361

404362
# 2. Partition args into inputs and attributes using schema
@@ -495,6 +453,68 @@ def __init__(self, *, features: BuilderFeature = BuilderFeature.NONE) -> None:
495453
self._initializers: list[ir.Value] = []
496454
self._used_opsets: UsedOpsets = set()
497455

456+
# ------------------------------------------------------------------
457+
# Public op-creation API
458+
# ------------------------------------------------------------------
459+
460+
def __getattr__(self, op_type: str) -> Any:
461+
"""Dynamic op dispatch: ``op.Relu(x)``, ``op.MatMul(a, b)``, etc.
462+
463+
Syntactic sugar for ``op.op(op_type, ...)``.
464+
465+
Returns a callable that creates a node of the given ``op_type``
466+
and records it via the subclass storage implementation.
467+
"""
468+
return lambda *args, **kwargs: self.op(op_type, *args, **kwargs)
469+
470+
def op(
471+
self,
472+
op_type: str,
473+
/,
474+
*args: ir.Value | None,
475+
_domain: str = "",
476+
_version: int | None = None,
477+
_outputs: int | Sequence[str] = 1,
478+
_name: str | None = None,
479+
**kwargs: Any,
480+
) -> ir.Value | Sequence[ir.Value]:
481+
"""Create an ONNX node.
482+
483+
This is the single entry point for all node creation.
484+
``op.Relu(x)`` is equivalent to ``op.op("Relu", x)``.
485+
486+
Reserved keyword arguments are prefixed with an underscore to avoid
487+
clashing with ONNX attribute names passed via ``**kwargs``.
488+
489+
Args:
490+
op_type: The operator type (e.g., ``"Relu"``, ``"Conv"``).
491+
*args: Positional arguments — the node's input values.
492+
_domain: Op domain (default ``""``).
493+
_version: Opset version.
494+
_outputs: Number of outputs or list of explicit output names.
495+
_name: Optional node name (must be unique).
496+
**kwargs: Keyword arguments — node attributes.
497+
Values can be Python scalars/lists (auto-converted) or
498+
``ir.Attr`` instances (passed through).
499+
500+
Returns:
501+
A single ``ir.Value`` if the node has one output, otherwise
502+
a sequence of ``ir.Value``.
503+
"""
504+
return self.call_op(
505+
op_type,
506+
args,
507+
kwargs,
508+
domain=_domain,
509+
version=_version,
510+
outputs=_outputs,
511+
name=_name,
512+
)
513+
514+
# ------------------------------------------------------------------
515+
# Storage interface
516+
# ------------------------------------------------------------------
517+
498518
def _add_node(self, node: ir.Node) -> None:
499519
self._nodes.append(node)
500520

0 commit comments

Comments
 (0)