44
55This module defines:
66
7- - ``BuilderBase``: Abstract base class for building ONNX IR nodes via a
8- dynamic dispatch interface (``op.Relu(x)``, ``op.op("Relu", x)``,
9- ``op. initializer(...)``) .
7+ - ``BuilderBase``: Abstract base class for building ONNX IR nodes.
8+ Provides the core node-creation pipeline (``call_op``) and an
9+ initializer creation API .
1010 Subclasses implement the storage strategy by overriding ``_add_node``,
1111 ``_add_initializer``, and ``_record_opset``.
1212
1313- ``TapeBuilder``: Concrete subclass backed by simple lists. Engines
1414 (rewriter, optimizer, version converter) create an instance, pass it to a
1515 rule or evaluator, and harvest the accumulated nodes / initializers / opsets
16- after it returns.
16+ after it returns. Provides the user-facing ``op()`` method and
17+ ``__getattr__``-based dynamic dispatch (``op.Relu(x)``).
1718
1819- ``BuilderFeature``: Flag enum controlling optional processing steps
1920 (schema partitioning, input casting, shape inference, etc.).
@@ -64,26 +65,12 @@ class BuilderFeature(enum.Flag):
6465 SCHEMA_AWARE = SCHEMA_PARTITION | CAST_INPUTS | CAST_ATTRIBUTES
6566 FULL = SCHEMA_AWARE | INFER_SHAPES | CONSTANT_PROPAGATION
6667
67- @property
68- def any_schema_feature (self ) -> bool :
69- """True if any schema-dependent feature is enabled."""
70- return bool (
71- self
72- & (
73- BuilderFeature .SCHEMA_PARTITION
74- | BuilderFeature .CAST_INPUTS
75- | BuilderFeature .CAST_ATTRIBUTES
76- )
77- )
78-
7968
8069class BuilderBase (abc .ABC ):
8170 """Abstract base class for building ONNX IR nodes.
8271
83- Supports two creation operations:
84-
85- 1. **Op creation** — ``op.op("Relu", x)`` or ``op.Relu(x)`` (syntactic sugar).
86- 2. **Initializer creation** — ``op.initializer(tensor, name=...)``.
72+ Provides the core node-creation pipeline via :meth:`call_op` and an
73+ initializer creation API via :meth:`initializer`.
8774
8875 Subclasses must implement the three protected methods that define where
8976 created nodes and initializers are stored:
@@ -123,6 +110,15 @@ def _record_opset(self, domain: str, version: int | None) -> None:
123110 # Overridable hook methods
124111 # ------------------------------------------------------------------
125112
113+ def _get_default_opset_version (self , domain : str = "" ) -> int | None :
114+ """Return the default opset version for internal ops (e.g. CastLike).
115+
116+ Default: None (no version). Override in GraphBuilder to return the
117+ graph's ambient opset version.
118+ """
119+ del domain # Unused in base implementation
120+ return None
121+
126122 def _get_schema (
127123 self , op_type : str , domain : str , version : int | None
128124 ) -> onnx .defs .OpSchema | None :
@@ -247,7 +243,12 @@ def _input_to_ir_value(
247243 needs_dynamic_cast = like_type is not None and dtype is None
248244 ir_value = self ._promote_constant (value , dtype )
249245 if needs_dynamic_cast :
250- ir_value = self .call_op ("CastLike" , [ir_value , like_type ], {})
246+ ir_value = self .call_op (
247+ "CastLike" ,
248+ [ir_value , like_type ],
249+ {},
250+ version = self ._get_default_opset_version ("" ),
251+ )
251252 return ir_value
252253
253254 def _promote_constant (self , value : Any , dtype : ir .DataType | None ) -> ir .Value :
@@ -326,57 +327,6 @@ def _annotate_node(self, node: ir.Node) -> None: # noqa: B027
326327 # Public API (concrete)
327328 # ------------------------------------------------------------------
328329
329- def __getattr__ (self , op_type : str ) -> Any :
330- """Dynamic op dispatch: ``op.Relu(x)``, ``op.MatMul(a, b)``, etc.
331-
332- Syntactic sugar for ``op.op(op_type, ...)``.
333-
334- Returns a callable that creates a node of the given ``op_type``
335- and records it via the subclass storage implementation.
336- """
337- return lambda * args , ** kwargs : self .op (op_type , * args , ** kwargs )
338-
339- def op (
340- self ,
341- op_type : str ,
342- / ,
343- * args : ir .Value | None ,
344- _domain : str = "" ,
345- _version : int | None = None ,
346- _outputs : int | Sequence [str ] = 1 ,
347- _name : str | None = None ,
348- ** kwargs : Any ,
349- ) -> ir .Value | Sequence [ir .Value ]:
350- """Create an ONNX node.
351-
352- This is the single entry point for all node creation.
353- ``op.Relu(x)`` is equivalent to ``op.op("Relu", x)``.
354-
355- Args:
356- op_type: The operator type (e.g., ``"Relu"``, ``"Conv"``).
357- *args: Positional arguments — the node's input values.
358- _domain: Op domain (default ``""``).
359- _version: Opset version.
360- _outputs: Number of outputs or list of explicit output names.
361- _name: Optional node name (must be unique).
362- **kwargs: Keyword arguments — node attributes.
363- Values can be Python scalars/lists (auto-converted) or
364- ``ir.Attr`` instances (passed through).
365-
366- Returns:
367- A single ``ir.Value`` if the node has one output, otherwise
368- a sequence of ``ir.Value``.
369- """
370- return self .call_op (
371- op_type ,
372- args ,
373- kwargs ,
374- domain = _domain ,
375- version = _version ,
376- outputs = _outputs ,
377- name = _name ,
378- )
379-
380330 def call_op (
381331 self ,
382332 op_type : str ,
@@ -390,15 +340,23 @@ def call_op(
390340 ) -> ir .Value | Sequence [ir .Value ]:
391341 """Create an ONNX node and add it to the graph, returning its output value(s).
392342
393- This is the core node-creation method. Both ``BuilderBase .op()`` and
343+ This is the core node-creation method. Both ``TapeBuilder .op()`` and
394344 ``OpBuilder.__getattr__`` delegate here. The processing steps are
395345 controlled by :attr:`features` flags and overridable hook methods.
346+
347+ Note:
348+ When input casting is enabled, helper nodes (e.g. Constant,
349+ CastLike) may be created before the requested node. In builders
350+ that generate names from a node counter (like GraphBuilder), this
351+ means the outer node's auto-generated name reflects the *total*
352+ node count including helpers — not a sequential index of
353+ user-requested ops.
396354 """
397355 features = self ._features
398356
399357 # 1. Schema lookup (if any schema-dependent feature is enabled)
400358 schema = None
401- if features . any_schema_feature :
359+ if features & BuilderFeature . SCHEMA_AWARE :
402360 schema = self ._get_schema (op_type , domain , version )
403361
404362 # 2. Partition args into inputs and attributes using schema
@@ -495,6 +453,68 @@ def __init__(self, *, features: BuilderFeature = BuilderFeature.NONE) -> None:
495453 self ._initializers : list [ir .Value ] = []
496454 self ._used_opsets : UsedOpsets = set ()
497455
456+ # ------------------------------------------------------------------
457+ # Public op-creation API
458+ # ------------------------------------------------------------------
459+
460+ def __getattr__ (self , op_type : str ) -> Any :
461+ """Dynamic op dispatch: ``op.Relu(x)``, ``op.MatMul(a, b)``, etc.
462+
463+ Syntactic sugar for ``op.op(op_type, ...)``.
464+
465+ Returns a callable that creates a node of the given ``op_type``
466+ and records it via the subclass storage implementation.
467+ """
468+ return lambda * args , ** kwargs : self .op (op_type , * args , ** kwargs )
469+
470+ def op (
471+ self ,
472+ op_type : str ,
473+ / ,
474+ * args : ir .Value | None ,
475+ _domain : str = "" ,
476+ _version : int | None = None ,
477+ _outputs : int | Sequence [str ] = 1 ,
478+ _name : str | None = None ,
479+ ** kwargs : Any ,
480+ ) -> ir .Value | Sequence [ir .Value ]:
481+ """Create an ONNX node.
482+
483+ This is the single entry point for all node creation.
484+ ``op.Relu(x)`` is equivalent to ``op.op("Relu", x)``.
485+
486+ Reserved keyword arguments are prefixed with an underscore to avoid
487+ clashing with ONNX attribute names passed via ``**kwargs``.
488+
489+ Args:
490+ op_type: The operator type (e.g., ``"Relu"``, ``"Conv"``).
491+ *args: Positional arguments — the node's input values.
492+ _domain: Op domain (default ``""``).
493+ _version: Opset version.
494+ _outputs: Number of outputs or list of explicit output names.
495+ _name: Optional node name (must be unique).
496+ **kwargs: Keyword arguments — node attributes.
497+ Values can be Python scalars/lists (auto-converted) or
498+ ``ir.Attr`` instances (passed through).
499+
500+ Returns:
501+ A single ``ir.Value`` if the node has one output, otherwise
502+ a sequence of ``ir.Value``.
503+ """
504+ return self .call_op (
505+ op_type ,
506+ args ,
507+ kwargs ,
508+ domain = _domain ,
509+ version = _version ,
510+ outputs = _outputs ,
511+ name = _name ,
512+ )
513+
514+ # ------------------------------------------------------------------
515+ # Storage interface
516+ # ------------------------------------------------------------------
517+
498518 def _add_node (self , node : ir .Node ) -> None :
499519 self ._nodes .append (node )
500520
0 commit comments