Commit 4905bfd

Update constant fold to use correct numpy type (#2204)
In PyTorch<=2.7, the NumPy arrays produced for bfloat16 and float8 tensors have dtypes uint16 and uint8, so constant folding operates on raw bit patterns and produces incorrect constant-folded graphs. This PR updates the NumPy helper to reinterpret the arrays as the correct dtypes. Fixes #2187
1 parent 0deb51b commit 4905bfd
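
To illustrate why a uint16-typed array breaks folding, here is a minimal sketch using only NumPy. It relies on the fact that a bfloat16 value is the upper 16 bits of an IEEE-754 float32. The helper `bf16_bits_to_float32` and the sample values are hypothetical and are not part of onnxscript or PyTorch.

```python
import numpy as np


def bf16_bits_to_float32(bits: np.ndarray) -> np.ndarray:
    """Decode raw bfloat16 bit patterns (stored as uint16) into float32.

    Hypothetical helper for illustration only.
    """
    # Shift the 16 stored bits into the upper half of a 32-bit word,
    # then reinterpret those 32 bits as a float32.
    return (bits.astype(np.uint32) << 16).view(np.float32)


# Bit patterns of the bfloat16 values 1.0 (0x3F80) and 2.0 (0x4000).
raw = np.array([0x3F80, 0x4000], dtype=np.uint16)

# Folding an Add over the uint16 array sums the bit patterns as integers.
wrong_sum = raw.sum()

# Folding over correctly interpreted values sums the numbers themselves.
right_sum = bf16_bits_to_float32(raw).sum()

print(hex(int(wrong_sum)), float(right_sum))
```

The integer sum is 0x7F80, which is the bfloat16 bit pattern for +inf rather than 3.0, so a folded constant computed this way would be wrong.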

1 file changed: +3 -1 lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 3 additions & 1 deletion
@@ -297,7 +297,9 @@ def _get_numpy_value(
 if size_limit is not None and const_value.size > size_limit:
     return None
 try:
-    array = const_value.numpy()
+    # Reinterpret the array with `.view()` because some implementations of
+    # ir.TensorProtocol (e.g. PyTorch<=2.7) do not use ml_dtypes for bfloat16 etc.
+    array = const_value.numpy().view(const_value.dtype.numpy())
 except FileNotFoundError:
     # External data is not available.
     return None
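
The `.view()` call in the change above reinterprets the existing bytes instead of converting values. The following sketch shows that behavior, using float16 as a stand-in because plain NumPy has no native bfloat16. The helper `reinterpret` is hypothetical and only mirrors the shape of the call in the diff.

```python
import numpy as np


def reinterpret(array: np.ndarray, target: np.dtype) -> np.ndarray:
    """Reinterpret the bytes of `array` as `target` without copying.

    Hypothetical helper; it mirrors `array.view(target)` from the diff.
    """
    return array.view(target)


# float16 bit patterns for 1.0 (0x3C00) and -2.0 (0xC000), stored as uint16.
raw = np.array([0x3C00, 0xC000], dtype=np.uint16)

# Same bytes, now read as float16 values.
as_f16 = reinterpret(raw, np.dtype(np.float16))

# Viewing an array that already has the target dtype changes nothing,
# so the call is harmless when the producer returns the right dtype.
already = reinterpret(as_f16, np.dtype(np.float16))

print(as_f16, np.shares_memory(raw, as_f16))
```

Because a view shares memory with the original array, the reinterpretation adds no copy, and it works only when the source and target dtypes have the same item size (as uint16 and a 16-bit float do).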
