From d47b3cfdfe83706f29a356fe4019ca4f79960121 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Apr 2025 18:33:44 -0700 Subject: [PATCH 1/3] [IR] Introduce short name for dtypes --- onnxscript/ir/_enums.py | 52 ++++++++++++++++++++++++++++++++++++ onnxscript/ir/_enums_test.py | 28 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py index 95cfff8682..31d8e50603 100644 --- a/onnxscript/ir/_enums.py +++ b/onnxscript/ir/_enums.py @@ -100,6 +100,17 @@ def from_numpy(cls, dtype: np.dtype) -> DataType: return DataType.FLOAT4E2M1 raise TypeError(f"Unsupported numpy data type: {dtype}") + @classmethod + def from_short_name(cls, short_name: str) -> DataType: + """Returns the ONNX data type for the short name. + + Raises: + TypeError: If the short name is not available for the data type. + """ + if short_name not in _SHORT_NAME_TO_DATA_TYPE: + raise TypeError(f"Unknown short name: {short_name}") + return cls(_SHORT_NAME_TO_DATA_TYPE[short_name]) + @property def itemsize(self) -> float: """Returns the size of the data type in bytes.""" @@ -115,6 +126,22 @@ def numpy(self) -> np.dtype: raise TypeError(f"Numpy does not support ONNX data type: {self}") return _DATA_TYPE_TO_NP_TYPE[self] + def short_name(self) -> str: + """Returns the short name of the data type. + + The short name is a string that is used to represent the data type in a more + compact form. For example, the short name for `DataType.FLOAT` is "f32". + To get the corresponding data type back, call ``from_short_name`` on a string. + + Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py + + Raises: + TypeError: If the short name is not available for the data type. + """ + if self not in _DATA_TYPE_TO_SHORT_NAME: + raise TypeError(f"Short name not available for ONNX data type: {self}") + return _DATA_TYPE_TO_SHORT_NAME[self] + def __repr__(self) -> str: return self.name @@ -184,3 +211,28 @@ def __str__(self) -> str: # ONNX DataType to Numpy dtype. _DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()} + +_DATA_TYPE_TO_SHORT_NAME = { + DataType.BFLOAT16: "bf16", + DataType.DOUBLE: "f64", + DataType.FLOAT: "f32", + DataType.FLOAT16: "f16", + DataType.FLOAT8E4M3FN: "f8e4m3fn", + DataType.FLOAT8E5M2: "f8e5m2", + DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz", + DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz", + DataType.FLOAT4E2M1: "f4e2m1", + DataType.COMPLEX64: "c64", + DataType.COMPLEX128: "c128", + DataType.INT8: "i8", + DataType.INT16: "i16", + DataType.INT32: "i32", + DataType.INT64: "i64", + DataType.BOOL: "b8", + DataType.UINT8: "u8", + DataType.UINT16: "u16", + DataType.UINT32: "u32", + DataType.UINT64: "u64", +} + +_SHORT_NAME_TO_DATA_TYPE = {v: k for k, v in _DATA_TYPE_TO_SHORT_NAME.items()} diff --git a/onnxscript/ir/_enums_test.py b/onnxscript/ir/_enums_test.py index 1b22f2cdb6..e79eec26bb 100644 --- a/onnxscript/ir/_enums_test.py +++ b/onnxscript/ir/_enums_test.py @@ -122,6 +122,34 @@ def test_repr_and_str_return_name(self): self.assertEqual(str(_enums.DataType.DOUBLE), "DOUBLE") self.assertEqual(repr(_enums.DataType.DOUBLE), "DOUBLE") + @parameterized.parameterized.expand( + [ + ("bf16",), + ("f64",), + ("f32",), + ("f16",), + ("f8e4m3fn",), + ("f8e5m2",), + ("f8e4m3fnuz",), + ("f8e5m2fnuz",), + ("f4e2m1",), + ("c64",), + ("c128",), + ("i8",), + ("i16",), + ("i32",), + ("i64",), + ("b8",), + ("u8",), + ("u16",), + ("u32",), + ("u64",), + ] + ) + def test_short_name_conversion(self, short_name: str): + dtype = _enums.DataType.from_short_name(short_name) + self.assertEqual(dtype.short_name(), short_name) + class AttributeTypeTest(unittest.TestCase): def test_enums_are_the_same_as_spec(self): From 1d55cda05c87488cd1ee81a24503acb16d529cf6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Apr 2025 18:37:49 -0700 Subject: [PATCH 2/3] test --- onnxscript/ir/_enums_test.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/onnxscript/ir/_enums_test.py b/onnxscript/ir/_enums_test.py index e79eec26bb..54cc7f530a 100644 --- a/onnxscript/ir/_enums_test.py +++ b/onnxscript/ir/_enums_test.py @@ -150,6 +150,32 @@ def test_short_name_conversion(self, short_name: str): dtype = _enums.DataType.from_short_name(short_name) self.assertEqual(dtype.short_name(), short_name) + def test_access_by_name(self): + self.assertEqual(_enums.DataType["FLOAT"], _enums.DataType.FLOAT) + self.assertEqual(_enums.DataType["UINT8"], _enums.DataType.UINT8) + self.assertEqual(_enums.DataType["INT8"], _enums.DataType.INT8) + self.assertEqual(_enums.DataType["UINT16"], _enums.DataType.UINT16) + self.assertEqual(_enums.DataType["INT16"], _enums.DataType.INT16) + self.assertEqual(_enums.DataType["INT32"], _enums.DataType.INT32) + self.assertEqual(_enums.DataType["INT64"], _enums.DataType.INT64) + self.assertEqual(_enums.DataType["STRING"], _enums.DataType.STRING) + self.assertEqual(_enums.DataType["BOOL"], _enums.DataType.BOOL) + self.assertEqual(_enums.DataType["FLOAT16"], _enums.DataType.FLOAT16) + self.assertEqual(_enums.DataType["DOUBLE"], _enums.DataType.DOUBLE) + self.assertEqual(_enums.DataType["UINT32"], _enums.DataType.UINT32) + self.assertEqual(_enums.DataType["UINT64"], _enums.DataType.UINT64) + self.assertEqual(_enums.DataType["COMPLEX64"], _enums.DataType.COMPLEX64) + self.assertEqual(_enums.DataType["COMPLEX128"], _enums.DataType.COMPLEX128) + self.assertEqual(_enums.DataType["BFLOAT16"], _enums.DataType.BFLOAT16) + self.assertEqual(_enums.DataType["FLOAT8E4M3FN"], _enums.DataType.FLOAT8E4M3FN) + self.assertEqual(_enums.DataType["FLOAT8E4M3FNUZ"], _enums.DataType.FLOAT8E4M3FNUZ) + self.assertEqual(_enums.DataType["FLOAT8E5M2"], _enums.DataType.FLOAT8E5M2) + self.assertEqual(_enums.DataType["FLOAT8E5M2FNUZ"], _enums.DataType.FLOAT8E5M2FNUZ) + self.assertEqual(_enums.DataType["UINT4"], _enums.DataType.UINT4) + self.assertEqual(_enums.DataType["INT4"], _enums.DataType.INT4) + self.assertEqual(_enums.DataType["FLOAT4E2M1"], _enums.DataType.FLOAT4E2M1) + self.assertEqual(_enums.DataType["UNDEFINED"], _enums.DataType.UNDEFINED) + class AttributeTypeTest(unittest.TestCase): def test_enums_are_the_same_as_spec(self): From a207f2217db2f05afab0a9bd77b6f67e2f3ee3cb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Apr 2025 18:46:39 -0700 Subject: [PATCH 3/3] fix short --- onnxscript/ir/_enums.py | 4 ++++ onnxscript/ir/_enums_test.py | 31 ++++--------------------------- 2 files changed, 8 insertions(+), 27 deletions(-) diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py index 31d8e50603..9ecce9fed3 100644 --- a/onnxscript/ir/_enums.py +++ b/onnxscript/ir/_enums.py @@ -213,6 +213,7 @@ def __str__(self) -> str: _DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()} _DATA_TYPE_TO_SHORT_NAME = { + DataType.UNDEFINED: "undefined", DataType.BFLOAT16: "bf16", DataType.DOUBLE: "f64", DataType.FLOAT: "f32", @@ -224,15 +225,18 @@ def __str__(self) -> str: DataType.FLOAT4E2M1: "f4e2m1", DataType.COMPLEX64: "c64", DataType.COMPLEX128: "c128", + DataType.INT4: "i4", DataType.INT8: "i8", DataType.INT16: "i16", DataType.INT32: "i32", DataType.INT64: "i64", DataType.BOOL: "b8", + DataType.UINT4: "u4", DataType.UINT8: "u8", DataType.UINT16: "u16", DataType.UINT32: "u32", DataType.UINT64: "u64", + DataType.STRING: "s", } _SHORT_NAME_TO_DATA_TYPE = {v: k for k, v in _DATA_TYPE_TO_SHORT_NAME.items()} diff --git a/onnxscript/ir/_enums_test.py b/onnxscript/ir/_enums_test.py index 54cc7f530a..906bf7b572 100644 --- a/onnxscript/ir/_enums_test.py +++ b/onnxscript/ir/_enums_test.py @@ -122,33 +122,10 @@ def test_repr_and_str_return_name(self): self.assertEqual(str(_enums.DataType.DOUBLE), "DOUBLE") self.assertEqual(repr(_enums.DataType.DOUBLE), "DOUBLE") - @parameterized.parameterized.expand( - [ - ("bf16",), - ("f64",), - ("f32",), - ("f16",), - ("f8e4m3fn",), - ("f8e5m2",), - ("f8e4m3fnuz",), - ("f8e5m2fnuz",), - ("f4e2m1",), - ("c64",), - ("c128",), - ("i8",), - ("i16",), - ("i32",), - ("i64",), - ("b8",), - ("u8",), - ("u16",), - ("u32",), - ("u64",), - ] - ) - def test_short_name_conversion(self, short_name: str): - dtype = _enums.DataType.from_short_name(short_name) - self.assertEqual(dtype.short_name(), short_name) + def test_short_name_conversion(self): + for dtype in _enums.DataType: + short_name = dtype.short_name() + self.assertEqual(_enums.DataType.from_short_name(short_name), dtype) def test_access_by_name(self): self.assertEqual(_enums.DataType["FLOAT"], _enums.DataType.FLOAT)