Skip to content

Commit 12bc1f6

Browse files
justinchubybmehta001
authored andcommitted
[IR] Handle ONNX custom types in DataType.from_numpy (microsoft#2131)
Fixes microsoft#1893 where the IR was confused about ONNX custom types. In the long run we should update onnx to use ml_dtypes.
1 parent 6f58641 commit 12bc1f6

File tree

2 files changed

+100
-6
lines changed

2 files changed

+100
-6
lines changed

onnxscript/ir/_enums.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,32 @@ def from_numpy(cls, dtype: np.dtype) -> DataType:
7373
Raises:
7474
TypeError: If the data type is not supported by ONNX.
7575
"""
76-
if dtype not in _NP_TYPE_TO_DATA_TYPE:
77-
raise TypeError(f"Unsupported numpy data type: {dtype}")
78-
return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
76+
if dtype in _NP_TYPE_TO_DATA_TYPE:
77+
return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
78+
79+
if np.issubdtype(dtype, np.str_):
80+
return DataType.STRING
81+
82+
# Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
83+
# Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
84+
if hasattr(dtype, "names"):
85+
if dtype.names == ("bfloat16",):
86+
return DataType.BFLOAT16
87+
if dtype.names == ("e4m3fn",):
88+
return DataType.FLOAT8E4M3FN
89+
if dtype.names == ("e4m3fnuz",):
90+
return DataType.FLOAT8E4M3FNUZ
91+
if dtype.names == ("e5m2",):
92+
return DataType.FLOAT8E5M2
93+
if dtype.names == ("e5m2fnuz",):
94+
return DataType.FLOAT8E5M2FNUZ
95+
if dtype.names == ("uint4",):
96+
return DataType.UINT4
97+
if dtype.names == ("int4",):
98+
return DataType.INT4
99+
if dtype.names == ("float4e2m1",):
100+
return DataType.FLOAT4E2M1
101+
raise TypeError(f"Unsupported numpy data type: {dtype}")
79102

80103
@property
81104
def itemsize(self) -> float:

onnxscript/ir/_enums_test.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3+
# pylint: disable=protected-access
34
import unittest
45

6+
import ml_dtypes
57
import numpy as np
68
import onnx
9+
import onnx._custom_element_types
10+
import parameterized
711

812
from onnxscript.ir import _enums
913

@@ -36,9 +40,76 @@ def test_enums_are_the_same_as_spec(self):
3640
self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1)
3741
self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED)
3842

39-
def test_from_numpy_takes_np_dtype_and_returns_data_type(self):
40-
array = np.array([], dtype=np.float64)
41-
self.assertEqual(_enums.DataType.from_numpy(array.dtype), _enums.DataType.DOUBLE)
43+
@parameterized.parameterized.expand(
44+
[
45+
("string", np.array("some_string").dtype, _enums.DataType.STRING),
46+
("float64", np.dtype(np.float64), _enums.DataType.DOUBLE),
47+
("float32", np.dtype(np.float32), _enums.DataType.FLOAT),
48+
("float16", np.dtype(np.float16), _enums.DataType.FLOAT16),
49+
("int32", np.dtype(np.int32), _enums.DataType.INT32),
50+
("int16", np.dtype(np.int16), _enums.DataType.INT16),
51+
("int8", np.dtype(np.int8), _enums.DataType.INT8),
52+
("int64", np.dtype(np.int64), _enums.DataType.INT64),
53+
("uint8", np.dtype(np.uint8), _enums.DataType.UINT8),
54+
("uint16", np.dtype(np.uint16), _enums.DataType.UINT16),
55+
("uint32", np.dtype(np.uint32), _enums.DataType.UINT32),
56+
("uint64", np.dtype(np.uint64), _enums.DataType.UINT64),
57+
("bool", np.dtype(np.bool_), _enums.DataType.BOOL),
58+
("complex64", np.dtype(np.complex64), _enums.DataType.COMPLEX64),
59+
("complex128", np.dtype(np.complex128), _enums.DataType.COMPLEX128),
60+
("bfloat16", np.dtype(ml_dtypes.bfloat16), _enums.DataType.BFLOAT16),
61+
("float8e4m3fn", np.dtype(ml_dtypes.float8_e4m3fn), _enums.DataType.FLOAT8E4M3FN),
62+
(
63+
"float8e4m3fnuz",
64+
np.dtype(ml_dtypes.float8_e4m3fnuz),
65+
_enums.DataType.FLOAT8E4M3FNUZ,
66+
),
67+
("float8e5m2", np.dtype(ml_dtypes.float8_e5m2), _enums.DataType.FLOAT8E5M2),
68+
(
69+
"float8e5m2fnuz",
70+
np.dtype(ml_dtypes.float8_e5m2fnuz),
71+
_enums.DataType.FLOAT8E5M2FNUZ,
72+
),
73+
("uint4", np.dtype(ml_dtypes.uint4), _enums.DataType.UINT4),
74+
("int4", np.dtype(ml_dtypes.int4), _enums.DataType.INT4),
75+
("float4e2m1", np.dtype(ml_dtypes.float4_e2m1fn), _enums.DataType.FLOAT4E2M1),
76+
(
77+
"onnx_ref_bfloat16",
78+
onnx._custom_element_types.bfloat16,
79+
_enums.DataType.BFLOAT16,
80+
),
81+
(
82+
"onnx_ref_float8e4m3fn",
83+
onnx._custom_element_types.float8e4m3fn,
84+
_enums.DataType.FLOAT8E4M3FN,
85+
),
86+
(
87+
"onnx_ref_float8e4m3fnuz",
88+
onnx._custom_element_types.float8e4m3fnuz,
89+
_enums.DataType.FLOAT8E4M3FNUZ,
90+
),
91+
(
92+
"onnx_ref_float8e5m2",
93+
onnx._custom_element_types.float8e5m2,
94+
_enums.DataType.FLOAT8E5M2,
95+
),
96+
(
97+
"onnx_ref_float8e5m2fnuz",
98+
onnx._custom_element_types.float8e5m2fnuz,
99+
_enums.DataType.FLOAT8E5M2FNUZ,
100+
),
101+
(
102+
"onnx_ref_uint4",
103+
onnx._custom_element_types.uint4,
104+
_enums.DataType.UINT4,
105+
),
106+
("onnx_ref_int4", onnx._custom_element_types.int4, _enums.DataType.INT4),
107+
]
108+
)
109+
def test_from_numpy_takes_np_dtype_and_returns_data_type(
110+
self, _: str, np_dtype: np.dtype, onnx_type: _enums.DataType
111+
):
112+
self.assertEqual(_enums.DataType.from_numpy(np_dtype), onnx_type)
42113

43114
def test_numpy_returns_np_dtype(self):
44115
self.assertEqual(_enums.DataType.DOUBLE.numpy(), np.dtype(np.float64))

0 commit comments

Comments
 (0)