Skip to content

[IR] Handle ONNX custom types in DataType.from_numpy #2131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions onnxscript/ir/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,25 @@
TypeError: If the data type is not supported by ONNX.
"""
if dtype not in _NP_TYPE_TO_DATA_TYPE:
# Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
# Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
if hasattr(dtype, "names"):
if dtype.names == ("bfloat16",):
return DataType.BFLOAT16

Check warning on line 81 in onnxscript/ir/_enums.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums.py#L81

Added line #L81 was not covered by tests
if dtype.names == ("e4m3fn",):
return DataType.FLOAT8E4M3FN

Check warning on line 83 in onnxscript/ir/_enums.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums.py#L83

Added line #L83 was not covered by tests
if dtype.names == ("e4m3fnuz",):
return DataType.FLOAT8E4M3FNUZ

Check warning on line 85 in onnxscript/ir/_enums.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums.py#L85

Added line #L85 was not covered by tests
if dtype.names == ("e5m2",):
return DataType.FLOAT8E5M2

Check warning on line 87 in onnxscript/ir/_enums.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums.py#L87

Added line #L87 was not covered by tests
if dtype.names == ("e5m2fnuz",):
return DataType.FLOAT8E5M2FNUZ

Check warning on line 89 in onnxscript/ir/_enums.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums.py#L89

Added line #L89 was not covered by tests
if dtype.names == ("uint4",):
return DataType.UINT4

Check warning on line 91 in onnxscript/ir/_enums.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums.py#L91

Added line #L91 was not covered by tests
if dtype.names == ("int4",):
return DataType.INT4

Check warning on line 93 in onnxscript/ir/_enums.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums.py#L93

Added line #L93 was not covered by tests
if dtype.names == ("float4e2m1",):
return DataType.FLOAT4E2M1

Check warning on line 95 in onnxscript/ir/_enums.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums.py#L95

Added line #L95 was not covered by tests
raise TypeError(f"Unsupported numpy data type: {dtype}")
return cls(_NP_TYPE_TO_DATA_TYPE[dtype])

Expand Down
81 changes: 78 additions & 3 deletions onnxscript/ir/_enums_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=protected-access
import unittest

import ml_dtypes
import numpy as np
import onnx
import onnx._custom_element_types
import parameterized

from onnxscript.ir import _enums

Expand Down Expand Up @@ -36,9 +40,80 @@
self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1)
self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED)

def test_from_numpy_takes_np_dtype_and_returns_data_type(self):
array = np.array([], dtype=np.float64)
self.assertEqual(_enums.DataType.from_numpy(array.dtype), _enums.DataType.DOUBLE)
@parameterized.parameterized.expand(
[
("float64", np.dtype(np.float64), _enums.DataType.DOUBLE),
("float32", np.dtype(np.float32), _enums.DataType.FLOAT),
("float16", np.dtype(np.float16), _enums.DataType.FLOAT16),
("int32", np.dtype(np.int32), _enums.DataType.INT32),
("int16", np.dtype(np.int16), _enums.DataType.INT16),
("int8", np.dtype(np.int8), _enums.DataType.INT8),
("int64", np.dtype(np.int64), _enums.DataType.INT64),
("uint8", np.dtype(np.uint8), _enums.DataType.UINT8),
("uint16", np.dtype(np.uint16), _enums.DataType.UINT16),
("uint32", np.dtype(np.uint32), _enums.DataType.UINT32),
("uint64", np.dtype(np.uint64), _enums.DataType.UINT64),
("bool", np.dtype(np.bool_), _enums.DataType.BOOL),
("complex64", np.dtype(np.complex64), _enums.DataType.COMPLEX64),
("complex128", np.dtype(np.complex128), _enums.DataType.COMPLEX128),
("bfloat16", np.dtype(ml_dtypes.bfloat16), _enums.DataType.BFLOAT16),
("float8e4m3fn", np.dtype(ml_dtypes.float8_e4m3fn), _enums.DataType.FLOAT8E4M3FN),
(
"float8e4m3fnuz",
np.dtype(ml_dtypes.float8_e4m3fnuz),
_enums.DataType.FLOAT8E4M3FNUZ,
),
("float8e5m2", np.dtype(ml_dtypes.float8_e5m2), _enums.DataType.FLOAT8E5M2),
(
"float8e5m2fnuz",
np.dtype(ml_dtypes.float8_e5m2fnuz),
_enums.DataType.FLOAT8E5M2FNUZ,
),
("uint4", np.dtype(ml_dtypes.uint4), _enums.DataType.UINT4),
("int4", np.dtype(ml_dtypes.int4), _enums.DataType.INT4),
("float4e2m1", np.dtype(ml_dtypes.float4_e2m1fn), _enums.DataType.FLOAT4E2M1),
(
"onnx_ref_bfloat16",
onnx._custom_element_types.bfloat16,
_enums.DataType.BFLOAT16,
),
(
"onnx_ref_float8e4m3fn",
onnx._custom_element_types.float8e4m3fn,
_enums.DataType.FLOAT8E4M3FN,
),
(
"onnx_ref_float8e4m3fnuz",
onnx._custom_element_types.float8e4m3fnuz,
_enums.DataType.FLOAT8E4M3FNUZ,
),
(
"onnx_ref_float8e5m2",
onnx._custom_element_types.float8e5m2,
_enums.DataType.FLOAT8E5M2,
),
(
"onnx_ref_float8e5m2fnuz",
onnx._custom_element_types.float8e5m2fnuz,
_enums.DataType.FLOAT8E5M2FNUZ,
),
(
"onnx_ref_uint4",
onnx._custom_element_types.uint4,
_enums.DataType.UINT4,
),
("onnx_ref_int4", onnx._custom_element_types.int4, _enums.DataType.INT4),
(
"onnx_ref_float4e2m1",
onnx._custom_element_types.float4e2m1,
_enums.DataType.FLOAT4E2M1,
),
]
)
def test_from_numpy_takes_np_dtype_and_returns_data_type(

Check warning on line 113 in onnxscript/ir/_enums_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums_test.py#L113

Added line #L113 was not covered by tests
self, _: str, np_dtype: np.dtype, onnx_type: _enums.DataType
):
self.assertEqual(_enums.DataType.from_numpy(np_dtype), onnx_type)

Check warning on line 116 in onnxscript/ir/_enums_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums_test.py#L116

Added line #L116 was not covered by tests

def test_numpy_returns_np_dtype(self):
self.assertEqual(_enums.DataType.DOUBLE.numpy(), np.dtype(np.float64))
Expand Down
Loading