diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 2af10646de..6f81feb7a6 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -1314,6 +1314,16 @@ def test_take_inputs(self): self.assertIs(self.value2.graph, self.graph) self.assertIsNone(self.value3.graph) + def test_inputs_copy(self): + self.graph.inputs.extend([self.value1, self.value2]) + inputs_copy = self.graph.inputs.copy() + self.assertEqual(inputs_copy, [self.value1, self.value2]) + self.assertIsNot(inputs_copy, self.graph.inputs) + # Modifying the copy does not affect the original + inputs_copy.append(self.value3) + self.assertNotIn(self.value3, self.graph.inputs) + self.assertIn(self.value3, inputs_copy) + def test_append_to_outputs(self): self.graph.outputs.append(self.value2) self.assertIn(self.value2, self.graph.outputs) @@ -1423,6 +1433,16 @@ def test_take_outputs(self): self.assertIs(self.value2.graph, self.graph) self.assertIsNone(self.value3.graph) + def test_outputs_copy(self): + self.graph.outputs.extend([self.value1, self.value2]) + outputs_copy = self.graph.outputs.copy() + self.assertEqual(outputs_copy, [self.value1, self.value2]) + self.assertIsNot(outputs_copy, self.graph.outputs) + # Modifying the copy does not affect the original + outputs_copy.append(self.value3) + self.assertNotIn(self.value3, self.graph.outputs) + self.assertIn(self.value3, outputs_copy) + def test_set_initializers(self): self.graph.initializers["initializer1"] = self.value3 self.assertIn("initializer1", self.graph.initializers) diff --git a/onnxscript/ir/_graph_containers.py b/onnxscript/ir/_graph_containers.py index 620e73e86b..9aab17d006 100644 --- a/onnxscript/ir/_graph_containers.py +++ b/onnxscript/ir/_graph_containers.py @@ -90,6 +90,11 @@ def clear(self) -> None: self._maybe_unset_graph(value) super().clear() + def copy(self) -> list[_core.Value]: + """Return a shallow copy of the list.""" + # This is a shallow copy, so the values are not copied, just the references + return self.data.copy() + def __setitem__(self, i, item) -> None: """Replace an input/output to the node.""" if isinstance(item, Iterable) and isinstance(i, slice): @@ -124,7 +129,6 @@ def _unimplemented(self, *_args, **_kwargs): __iadd__ = _unimplemented __mul__ = _unimplemented __rmul__ = _unimplemented - copy = _unimplemented class GraphInputs(_GraphIO):