|
1 | 1 | # Copyright (c) Microsoft Corporation.
|
2 | 2 | # Licensed under the MIT License.
|
| 3 | +# pylint: disable=protected-access |
3 | 4 | import unittest
|
4 | 5 |
|
| 6 | +import ml_dtypes |
5 | 7 | import numpy as np
|
6 | 8 | import onnx
|
| 9 | +import onnx._custom_element_types |
| 10 | +import parameterized |
7 | 11 |
|
8 | 12 | from onnxscript.ir import _enums
|
9 | 13 |
|
@@ -36,9 +40,76 @@ def test_enums_are_the_same_as_spec(self):
|
36 | 40 | self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1)
|
37 | 41 | self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED)
|
38 | 42 |
|
39 |
| - def test_from_numpy_takes_np_dtype_and_returns_data_type(self): |
40 |
| - array = np.array([], dtype=np.float64) |
41 |
| - self.assertEqual(_enums.DataType.from_numpy(array.dtype), _enums.DataType.DOUBLE) |
| 43 | + @parameterized.parameterized.expand( |
| 44 | + [ |
| 45 | + ("string", np.array("some_string").dtype, _enums.DataType.STRING), |
| 46 | + ("float64", np.dtype(np.float64), _enums.DataType.DOUBLE), |
| 47 | + ("float32", np.dtype(np.float32), _enums.DataType.FLOAT), |
| 48 | + ("float16", np.dtype(np.float16), _enums.DataType.FLOAT16), |
| 49 | + ("int32", np.dtype(np.int32), _enums.DataType.INT32), |
| 50 | + ("int16", np.dtype(np.int16), _enums.DataType.INT16), |
| 51 | + ("int8", np.dtype(np.int8), _enums.DataType.INT8), |
| 52 | + ("int64", np.dtype(np.int64), _enums.DataType.INT64), |
| 53 | + ("uint8", np.dtype(np.uint8), _enums.DataType.UINT8), |
| 54 | + ("uint16", np.dtype(np.uint16), _enums.DataType.UINT16), |
| 55 | + ("uint32", np.dtype(np.uint32), _enums.DataType.UINT32), |
| 56 | + ("uint64", np.dtype(np.uint64), _enums.DataType.UINT64), |
| 57 | + ("bool", np.dtype(np.bool_), _enums.DataType.BOOL), |
| 58 | + ("complex64", np.dtype(np.complex64), _enums.DataType.COMPLEX64), |
| 59 | + ("complex128", np.dtype(np.complex128), _enums.DataType.COMPLEX128), |
| 60 | + ("bfloat16", np.dtype(ml_dtypes.bfloat16), _enums.DataType.BFLOAT16), |
| 61 | + ("float8e4m3fn", np.dtype(ml_dtypes.float8_e4m3fn), _enums.DataType.FLOAT8E4M3FN), |
| 62 | + ( |
| 63 | + "float8e4m3fnuz", |
| 64 | + np.dtype(ml_dtypes.float8_e4m3fnuz), |
| 65 | + _enums.DataType.FLOAT8E4M3FNUZ, |
| 66 | + ), |
| 67 | + ("float8e5m2", np.dtype(ml_dtypes.float8_e5m2), _enums.DataType.FLOAT8E5M2), |
| 68 | + ( |
| 69 | + "float8e5m2fnuz", |
| 70 | + np.dtype(ml_dtypes.float8_e5m2fnuz), |
| 71 | + _enums.DataType.FLOAT8E5M2FNUZ, |
| 72 | + ), |
| 73 | + ("uint4", np.dtype(ml_dtypes.uint4), _enums.DataType.UINT4), |
| 74 | + ("int4", np.dtype(ml_dtypes.int4), _enums.DataType.INT4), |
| 75 | + ("float4e2m1", np.dtype(ml_dtypes.float4_e2m1fn), _enums.DataType.FLOAT4E2M1), |
| 76 | + ( |
| 77 | + "onnx_ref_bfloat16", |
| 78 | + onnx._custom_element_types.bfloat16, |
| 79 | + _enums.DataType.BFLOAT16, |
| 80 | + ), |
| 81 | + ( |
| 82 | + "onnx_ref_float8e4m3fn", |
| 83 | + onnx._custom_element_types.float8e4m3fn, |
| 84 | + _enums.DataType.FLOAT8E4M3FN, |
| 85 | + ), |
| 86 | + ( |
| 87 | + "onnx_ref_float8e4m3fnuz", |
| 88 | + onnx._custom_element_types.float8e4m3fnuz, |
| 89 | + _enums.DataType.FLOAT8E4M3FNUZ, |
| 90 | + ), |
| 91 | + ( |
| 92 | + "onnx_ref_float8e5m2", |
| 93 | + onnx._custom_element_types.float8e5m2, |
| 94 | + _enums.DataType.FLOAT8E5M2, |
| 95 | + ), |
| 96 | + ( |
| 97 | + "onnx_ref_float8e5m2fnuz", |
| 98 | + onnx._custom_element_types.float8e5m2fnuz, |
| 99 | + _enums.DataType.FLOAT8E5M2FNUZ, |
| 100 | + ), |
| 101 | + ( |
| 102 | + "onnx_ref_uint4", |
| 103 | + onnx._custom_element_types.uint4, |
| 104 | + _enums.DataType.UINT4, |
| 105 | + ), |
| 106 | + ("onnx_ref_int4", onnx._custom_element_types.int4, _enums.DataType.INT4), |
| 107 | + ] |
| 108 | + ) |
| 109 | + def test_from_numpy_takes_np_dtype_and_returns_data_type( |
| 110 | + self, _: str, np_dtype: np.dtype, onnx_type: _enums.DataType |
| 111 | + ): |
| 112 | + self.assertEqual(_enums.DataType.from_numpy(np_dtype), onnx_type) |
42 | 113 |
|
43 | 114 | def test_numpy_returns_np_dtype(self):
|
44 | 115 | self.assertEqual(_enums.DataType.DOUBLE.numpy(), np.dtype(np.float64))
|
|
0 commit comments