From 8bf0a28cdeb6fb48a8f0c15ef1da6de38fb86d4b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 16:16:57 -0700 Subject: [PATCH 1/8] [IR] Allow to copy an unfrozen version of the Shape When a shape is frozen, the dims of the shape cannot be modified. Users can call ``` new_shape = shape.copy() new_shape[0] = 1 ``` to assign to the new shape. --- onnxscript/ir/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 51c6d83502..e46fa25b9b 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1090,9 +1090,9 @@ def __init__( ) self._frozen: bool = frozen - def copy(self): + def copy(self, frozen: bool = False): """Return a copy of the shape.""" - return Shape(self._dims, self._denotations, self._frozen) + return Shape(self._dims, self._denotations, frozen=frozen) @property def dims(self) -> tuple[int | SymbolicDim, ...]: From 07fe2200c8ec14d8d65955f0b27fd47e8615f145 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 16:18:12 -0700 Subject: [PATCH 2/8] Update _core.py --- onnxscript/ir/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index e46fa25b9b..eff4e3c599 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1138,7 +1138,7 @@ def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None TypeError: If the value is not an int or SymbolicDim. """ if self._frozen: - raise TypeError("The shape is frozen and cannot be modified.") + raise TypeError("The shape is frozen and cannot be modified. You can call .copy() to get a new mutable shape") self._dims[index] = _maybe_convert_to_symbolic_dim(value) From 8abad155d6db299d50932103b30946c70074152c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 16:18:32 -0700 Subject: [PATCH 3/8] format --- onnxscript/ir/_core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index eff4e3c599..07b102cd3d 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1138,7 +1138,9 @@ def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None TypeError: If the value is not an int or SymbolicDim. """ if self._frozen: - raise TypeError("The shape is frozen and cannot be modified. You can call .copy() to get a new mutable shape") + raise TypeError( + "The shape is frozen and cannot be modified. You can call .copy() to get a new mutable shape" + ) self._dims[index] = _maybe_convert_to_symbolic_dim(value) From 8b7f015acb74ad299c2bc52cb937801da3cb431f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 16:30:09 -0700 Subject: [PATCH 4/8] Allow changing symbolic dims to other symbolic dims --- onnxscript/ir/_core.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 07b102cd3d..9e547d9420 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -994,6 +994,7 @@ def meta(self) -> _metadata.MetadataStore: class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): + """Immutable symbolic dimension that can be shared across multiple shapes.""" __slots__ = ("_value",) def __init__(self, value: str | None) -> None: @@ -1134,13 +1135,19 @@ def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None value: The value of the dimension. Raises: - TypeError: If the shape is frozen and cannot be modified. - TypeError: If the value is not an int or SymbolicDim. + TypeError: If the shape is frozen and cannot be modified, unless the + existing dim and the new dim are both symbolic. + TypeError: If the value is not an int, str, SymbolicDim or None. """ if self._frozen: - raise TypeError( - "The shape is frozen and cannot be modified. You can call .copy() to get a new mutable shape" - ) + maybe_symbol = _maybe_convert_to_symbolic_dim(value) + if isinstance(self._dims[index], SymbolicDim) and isinstance( + maybe_symbol, SymbolicDim + ): + # Allow changing symbolic dims to other symbolic dims + self._dims[index] = maybe_symbol + return + raise TypeError("The shape is frozen and cannot be modified.") self._dims[index] = _maybe_convert_to_symbolic_dim(value) From e20126d5cbfed0ecafad60b5047e979213110fc8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 16:32:29 -0700 Subject: [PATCH 5/8] test --- onnxscript/ir/_core.py | 1 + onnxscript/ir/_core_test.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 9e547d9420..99f57bcdad 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -995,6 +995,7 @@ def meta(self) -> _metadata.MetadataStore: class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): """Immutable symbolic dimension that can be shared across multiple shapes.""" + __slots__ = ("_value",) def __init__(self, value: str | None) -> None: diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 7068a8da8f..ff0fff6566 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -622,6 +622,20 @@ def test_setitem_raises_when_shape_is_frozen(self): with self.assertRaisesRegex(TypeError, "frozen"): shape[0] = 1 + with self.assertRaisesRegex(TypeError, "frozen"): + shape[0] = "some_string" + + def test_setitem_allowed_swapping_sym_dim_when_shape_is_frozen(self): + shape = _core.Shape(["some_dim"], denotations=("DATA_CHANNEL",), frozen=True) + with self.assertRaisesRegex(TypeError, "frozen"): + shape[0] = 1 + + # These are ok + shape[0] = "some_string" + self.assertEqual(shape[0], "some_string") + shape[0] = _core.SymbolicDim("some_other_string") + self.assertEqual(shape[0], "some_other_string") + def test_getitem(self): shape = _core.Shape([42], denotations=("DATA_CHANNEL",)) self.assertEqual(shape[0], 42) From 0072cd70e3a4526d42801c4c1a027b38a6bd855f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Apr 2025 07:40:42 -0700 Subject: [PATCH 6/8] wip --- onnxscript/ir/_core.py | 44 ++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 99f57bcdad..0954167306 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1056,6 +1056,10 @@ def _maybe_convert_to_symbolic_dim( class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable): + """Shape of a tensor. + + + """ __slots__ = ("_dims", "_frozen") def __init__( @@ -1078,7 +1082,8 @@ def __init__( Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition for pre-defined dimension denotations. frozen: If True, the shape is immutable and cannot be modified. This - is useful when the shape is initialized by a Tensor. + is useful when the shape is initialized by a Tensor or when the shape + is shared across multiple tensors. The default is False. """ self._dims: list[int | SymbolicDim] = [ _maybe_convert_to_symbolic_dim(dim) for dim in dims @@ -1092,10 +1097,6 @@ def __init__( ) self._frozen: bool = frozen - def copy(self, frozen: bool = False): - """Return a copy of the shape.""" - return Shape(self._dims, self._denotations, frozen=frozen) - @property def dims(self) -> tuple[int | SymbolicDim, ...]: """All dimensions in the shape. @@ -1104,6 +1105,27 @@ def dims(self) -> tuple[int | SymbolicDim, ...]: """ return tuple(self._dims) + @property + def frozen(self) -> bool: + """Whether the shape is frozen. + + When the shape is frozen, it cannot be unfrozen, making it suitable to be shared. + Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a + new shape with the same dimensions that can be modified. + """ + return self._frozen + + def freeze(self) -> None: + """Freeze the shape. + + When the shape is frozen, it cannot be unfrozen, making it suitable to be shared. + """ + self._frozen = True + + def copy(self, frozen: bool = False): + """Return a copy of the shape.""" + return Shape(self._dims, self._denotations, frozen=frozen) + def rank(self) -> int: """The rank of the shape.""" return len(self._dims) @@ -1136,18 +1158,10 @@ def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None value: The value of the dimension. Raises: - TypeError: If the shape is frozen and cannot be modified, unless the - existing dim and the new dim are both symbolic. - TypeError: If the value is not an int, str, SymbolicDim or None. + TypeError: If the shape is frozen and cannot be modified. + TypeError: If the value is not an int or SymbolicDim. """ if self._frozen: - maybe_symbol = _maybe_convert_to_symbolic_dim(value) - if isinstance(self._dims[index], SymbolicDim) and isinstance( - maybe_symbol, SymbolicDim - ): - # Allow changing symbolic dims to other symbolic dims - self._dims[index] = maybe_symbol - return raise TypeError("The shape is frozen and cannot be modified.") self._dims[index] = _maybe_convert_to_symbolic_dim(value) From 10a127f865f1b00f0363d5ee6221a6d09ddfcee8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Apr 2025 07:53:10 -0700 Subject: [PATCH 7/8] example --- onnxscript/ir/_core.py | 46 +++++++++++++++++++++++++++++++++++-- onnxscript/ir/_core_test.py | 11 --------- 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 0954167306..5c0684e22b 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1056,9 +1056,51 @@ def _maybe_convert_to_symbolic_dim( class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable): - """Shape of a tensor. + """The shape of a tensor, including its dimensions and optional denotations. + The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or + symbolic dimensions. + A shape can be compared to another shape or plain Python list. + + A shape can be frozen (made immutable). When the shape is frozen, it cannot be + unfrozen, making it suitable to be shared across tensors or values. + Call :method:`freeze` to freeze the shape. + + To update the dimension of a frozen shape, call :method:`copy` to create a + new shape with the same dimensions that can be modified. + + Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations. + + Example:: + + >>> from onnxscript import ir + >>> shape = ir.Shape(["B", None, 3]) + >>> shape.rank() + 3 + >>> shape.is_static() + False + >>> shape.is_dynamic() + True + >>> shape.is_static(dim=2) + True + >>> shape[0] = 1 + >>> shape[1] = 2 + >>> shape.dims + (1, 2, 3) + >>> shape == [1, 2, 3] + True + >>> shape.frozen + False + >>> shape.freeze() + >>> shape.frozen + True + + Attributes: + dims: A tuple of dimensions representing the shape. + Each dimension can be an integer, None or a :class:`SymbolicDim`. + frozen: Indicates whether the shape is immutable. When frozen, the shape + cannot be modified or unfrozen. """ __slots__ = ("_dims", "_frozen") @@ -1127,7 +1169,7 @@ def copy(self, frozen: bool = False): return Shape(self._dims, self._denotations, frozen=frozen) def rank(self) -> int: - """The rank of the shape.""" + """The rank of the tensor this shape represents.""" return len(self._dims) def numpy(self) -> tuple[int, ...]: diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index ff0fff6566..ee2b0f389c 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -625,17 +625,6 @@ def test_setitem_raises_when_shape_is_frozen(self): with self.assertRaisesRegex(TypeError, "frozen"): shape[0] = "some_string" - def test_setitem_allowed_swapping_sym_dim_when_shape_is_frozen(self): - shape = _core.Shape(["some_dim"], denotations=("DATA_CHANNEL",), frozen=True) - with self.assertRaisesRegex(TypeError, "frozen"): - shape[0] = 1 - - # These are ok - shape[0] = "some_string" - self.assertEqual(shape[0], "some_string") - shape[0] = _core.SymbolicDim("some_other_string") - self.assertEqual(shape[0], "some_other_string") - def test_getitem(self): shape = _core.Shape([42], denotations=("DATA_CHANNEL",)) self.assertEqual(shape[0], 42) From 028bb86d5d13ffeadec05d21a82c95e79b330421 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Apr 2025 07:53:53 -0700 Subject: [PATCH 8/8] lint --- onnxscript/ir/_core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 5c0684e22b..ae2cfee95d 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1102,6 +1102,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable): frozen: Indicates whether the shape is immutable. When frozen, the shape cannot be modified or unfrozen. """ + __slots__ = ("_dims", "_frozen") def __init__(