Skip to content

Commit e53ec9f

Browse files
committed
Introduce slice support (#24)
1 parent 1013c8b commit e53ec9f

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

src/onnx_ir/_core.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2285,7 +2285,12 @@ def doc_string(self, value: str | None) -> None:
22852285
def opset_imports(self) -> dict[str, int]:
22862286
return self._opset_imports
22872287

2288-
def __getitem__(self, index: int) -> Node:
2288+
@typing.overload
2289+
def __getitem__(self, index: int) -> Node: ...
2290+
@typing.overload
2291+
def __getitem__(self, index: slice) -> Sequence[Node]: ...
2292+
2293+
def __getitem__(self, index):
22892294
return self._nodes[index]
22902295

22912296
def __len__(self) -> int:
@@ -2715,7 +2720,12 @@ def __init__(
27152720
self._metadata_props: dict[str, str] | None = metadata_props
27162721
self._nodes: tuple[Node, ...] = tuple(nodes)
27172722

2718-
def __getitem__(self, index: int) -> Node:
2723+
@typing.overload
2724+
def __getitem__(self, index: int) -> Node: ...
2725+
@typing.overload
2726+
def __getitem__(self, index: slice) -> Sequence[Node]: ...
2727+
2728+
def __getitem__(self, index):
27192729
return self._nodes[index]
27202730

27212731
def __len__(self) -> int:
@@ -2964,7 +2974,12 @@ def outputs(self) -> MutableSequence[Value]:
29642974
def attributes(self) -> OrderedDict[str, Attr]:
29652975
return self._attributes
29662976

2967-
def __getitem__(self, index: int) -> Node:
2977+
@typing.overload
2978+
def __getitem__(self, index: int) -> Node: ...
2979+
@typing.overload
2980+
def __getitem__(self, index: slice) -> Sequence[Node]: ...
2981+
2982+
def __getitem__(self, index):
29682983
return self._graph.__getitem__(index)
29692984

29702985
def __len__(self) -> int:

src/onnx_ir/_linked_list.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from __future__ import annotations
66

77
from collections.abc import Iterable, Iterator, Sequence
8-
from typing import Generic, TypeVar
8+
from typing import Generic, TypeVar, overload
99

1010
T = TypeVar("T")
1111

@@ -137,11 +137,18 @@ def __len__(self) -> int:
137137
)
138138
return self._length
139139

140-
def __getitem__(self, index: int) -> T:
140+
@overload
141+
def __getitem__(self, index: int) -> T: ...
142+
@overload
143+
def __getitem__(self, index: slice) -> Sequence[T]: ...
144+
145+
def __getitem__(self, index):
141146
"""Get the node at the given index.
142147
143148
Complexity is O(n).
144149
"""
150+
if isinstance(index, slice):
151+
return tuple(self)[index]
145152
if index >= self._length or index < -self._length:
146153
raise IndexError(
147154
f"Index out of range: {index} not in range [-{self._length}, {self._length})"

src/onnx_ir/_linked_list_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,15 @@ def test_insert_after_supports_taking_elements_from_another_doubly_linked_list(
373373
self.assertEqual(len(other_linked_list), 1)
374374
self.assertEqual([elem.value for elem in other_linked_list], [42])
375375

376+
@parameterized.parameterized.expand(
377+
[(s, t, p) for s in [-2, 0, 2, 3] for t in [2, -1, -2] for p in [-3, -1, 1, 2]]
378+
)
379+
def test_get_item_slice(self, start, stop, step):
380+
elems = [_TestElement(i) for i in range(5)]
381+
linked_list = _linked_list.DoublyLinkedSet(elems)
382+
self.assertEqual(len(linked_list), 5)
383+
self.assertEqual(list(linked_list[start:stop:step]), elems[start:stop:step])
384+
376385

377386
if __name__ == "__main__":
378387
unittest.main()

0 commit comments

Comments
 (0)