Skip to content

Commit 862de43

Browse files
committed
[IR] introduce slice support (#2302)
1 parent b617ad5 commit 862de43

File tree

2 files changed

+8
-17
lines changed

2 files changed

+8
-17
lines changed

onnxscript/ir/_linked_list.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -136,27 +136,12 @@ def __len__(self) -> int:
136136
)
137137
return self._length
138138

139-
def __getitem__(self, index: int) -> T:
139+
def __getitem__(self, index: int | slice) -> T:
140140
"""Get the node at the given index.
141141
142142
Complexity is O(n).
143143
"""
144-
if index >= self._length or index < -self._length:
145-
raise IndexError(
146-
f"Index out of range: {index} not in range [-{self._length}, {self._length})"
147-
)
148-
if index < 0:
149-
# Look up from the end of the list
150-
iterator = reversed(self)
151-
item = next(iterator)
152-
for _ in range(-index - 1):
153-
item = next(iterator)
154-
else:
155-
iterator = iter(self) # type: ignore[assignment]
156-
item = next(iterator)
157-
for _ in range(index):
158-
item = next(iterator)
159-
return item
144+
return tuple(self)[index]
160145

161146
def _insert_one_after(
162147
self,

onnxscript/ir/_linked_list_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def test_append_multiple_elements(self):
6060
self.assertEqual(linked_list[-3], elems[0])
6161
self.assertEqual(list(linked_list), elems)
6262
self.assertEqual(list(reversed(linked_list)), list(reversed(elems)))
63+
self.assertEqual(list(linked_list[1:2]), elems[1:2])
64+
self.assertEqual(list(linked_list[:2]), elems[:2])
65+
self.assertEqual(list(linked_list[-2:]), elems[-2:])
6366

6467
def test_extend(self):
6568
elems = [_TestElement(i) for i in range(3)]
@@ -73,6 +76,9 @@ def test_extend(self):
7376
self.assertEqual(linked_list[-3], elems[0])
7477
self.assertEqual(list(linked_list), elems)
7578
self.assertEqual(list(reversed(linked_list)), list(reversed(elems)))
79+
self.assertEqual(list(linked_list[1:2]), elems[1:2])
80+
self.assertEqual(list(linked_list[:2]), elems[:2])
81+
self.assertEqual(list(linked_list[-2:]), elems[-2:])
7682

7783
@parameterized.parameterized.expand(
7884
[

0 commit comments

Comments
 (0)