
Commit bee04f1

Add fp8_e5m2 support in dtype_byte_size (#3625)
* float8_e5m2 device_map
* remove prints
1 parent 8a953f0 commit bee04f1

File tree: 1 file changed (+1, -1)


src/accelerate/utils/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def dtype_byte_size(dtype: torch.dtype):
         return 1 / 2
     elif dtype == CustomDtype.FP8:
         return 1
-    elif is_torch_version(">=", "2.1.0") and dtype == torch.float8_e4m3fn:
+    elif is_torch_version(">=", "2.1.0") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
         return 1
     bit_search = re.search(r"[^\d](\d+)$", str(dtype))
     if bit_search is None:
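The change extends the 1-byte branch so that `torch.float8_e5m2` is handled alongside `torch.float8_e4m3fn`, instead of falling through to the regex fallback that parses a trailing bit width from the dtype name. As a rough illustration of that control flow, here is a stdlib-only sketch keyed on dtype name strings rather than real `torch.dtype` objects (the actual function in `src/accelerate/utils/modeling.py` takes a `torch.dtype` and also checks the torch version):

```python
import re

def dtype_byte_size(dtype_name: str) -> float:
    """Simplified sketch of accelerate's dtype_byte_size, keyed on the
    dtype's string name instead of an actual torch.dtype object."""
    if dtype_name == "bool":
        return 1 / 8
    if dtype_name == "int4":
        return 1 / 2
    # After this commit, both fp8 variants take the explicit 1-byte branch.
    if dtype_name in ("float8_e4m3fn", "float8_e5m2"):
        return 1
    # Generic fallback: parse the trailing bit width, e.g. "float32" -> 32 bits.
    bit_search = re.search(r"[^\d](\d+)$", dtype_name)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype_name}.")
    return int(bit_search.groups()[0]) // 8

print(dtype_byte_size("float8_e5m2"))  # -> 1
print(dtype_byte_size("float32"))      # -> 4
```

The explicit branch matters because the regex fallback would read `"float8_e5m2"` as ending in the digits `2` rather than a bit width of 8, so `device_map` memory estimates for fp8_e5m2 models would be wrong without it.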
