Skip to content

Commit 081ae34

Browse files
committed
typing
1 parent acfa81c commit 081ae34

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

onnxscript/ir/_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2718,11 +2718,11 @@ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
27182718
"""
27192719
self._graph.remove(nodes, safe=safe)
27202720

2721-
def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
2721+
def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
27222722
"""Insert new nodes after the given node in O(#new_nodes) time."""
27232723
self._graph.insert_after(node, new_nodes)
27242724

2725-
def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
2725+
def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
27262726
"""Insert new nodes before the given node in O(#new_nodes) time."""
27272727
self._graph.insert_before(node, new_nodes)
27282728

onnxscript/ir/_protocols.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,11 +320,11 @@ def remove(self, node: NodeProtocol, /) -> None:
320320
"""Remove a node from the graph."""
321321
...
322322

323-
def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
323+
def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None:
324324
"""Insert new nodes after the given node."""
325325
...
326326

327-
def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
327+
def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None:
328328
"""Insert new nodes before the given node."""
329329
...
330330

@@ -589,11 +589,11 @@ def remove(self, node: NodeProtocol, /) -> None:
589589
"""Remove a node from the function."""
590590
...
591591

592-
def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
592+
def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None:
593593
"""Insert new nodes after the given node."""
594594
...
595595

596-
def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
596+
def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None:
597597
"""Insert new nodes before the given node."""
598598
...
599599

0 commit comments

Comments
 (0)