Skip to content

Commit 89dd454

Browse files
authored
[IR] Fix an error when checking for float8_e4m3fnuz type in ir.Tensor (#2078)
The float8_e4m3fnuz type was mistaken with float8_e4m3b11fnuz, which is a different type: https://github.com/jax-ml/ml_dtypes#float8_e4m3b11fnuz
1 parent 1a8dbd7 commit 89dd454

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

onnxscript/ir/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
199199
)
200200
if dtype.itemsize == 1 and array.dtype not in (
201201
np.uint8,
202-
ml_dtypes.float8_e4m3b11fnuz,
202+
ml_dtypes.float8_e4m3fnuz,
203203
ml_dtypes.float8_e4m3fn,
204204
ml_dtypes.float8_e5m2fnuz,
205205
ml_dtypes.float8_e5m2,

0 commit comments

Comments
 (0)