We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
dtype_byte_size
1 parent 8a953f0 commit bee04f1Copy full SHA for bee04f1
src/accelerate/utils/modeling.py
@@ -169,7 +169,7 @@ def dtype_byte_size(dtype: torch.dtype):
169
return 1 / 2
170
elif dtype == CustomDtype.FP8:
171
return 1
172
- elif is_torch_version(">=", "2.1.0") and dtype == torch.float8_e4m3fn:
+ elif is_torch_version(">=", "2.1.0") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
173
174
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
175
if bit_search is None:
0 commit comments