Skip to content

Commit c4ef1fa

Browse files
Cerebra Catalyst Teamcopybara-github
authored andcommitted
Handle the jax.promote_dtype failure for [int4, fp8_e4m3, fp8_e5m2] when using aqt_einsum
PiperOrigin-RevId: 733390414
1 parent a1e1d11 commit c4ef1fa

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

aqt/jax/v2/flax/aqt_flax.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -603,14 +603,43 @@ def __call__(
603603
# from being rejected by assertions in aqt_dot_general.py, line 522-526 and
604604
# 414.
605605
# TODO: b/322111904 - Handle this in more proper way.
606-
# We hand-hold int4 because promote_dtype(int4, x) fails.
607-
# (To avoid unintended promotion, 4-bit integers do not support
608-
# implicit promotion.)
609-
if lhs_in.dtype == jnp.int4:
610-
lhs_in = jnp.float32(lhs_in)
611-
if rhs_in.dtype == jnp.int4:
612-
rhs_in = jnp.float32(rhs_in)
613-
if lhs_in.dtype != jnp.int4 and rhs_in.dtype != jnp.int4:
606+
# We create a list of dtypes and hand-hold them because promote_dtype fails
607+
# for these dtypes.
608+
manual_promotion_dtypes = [jnp.int4, jnp.float8_e4m3fn, jnp.float8_e5m2]
609+
if (
610+
lhs_in.dtype in manual_promotion_dtypes
611+
and rhs_in.dtype in manual_promotion_dtypes
612+
):
613+
if lhs_in.dtype == rhs_in.dtype:
614+
pass
615+
else:
616+
lhs_in = (
617+
jnp.float32(lhs_in)
618+
if lhs_in.dtype == jnp.int4
619+
else jnp.bfloat16(lhs_in)
620+
)
621+
rhs_in = (
622+
jnp.float32(rhs_in)
623+
if rhs_in.dtype == jnp.int4
624+
else jnp.bfloat16(rhs_in)
625+
)
626+
elif lhs_in.dtype in manual_promotion_dtypes:
627+
lhs_in = (
628+
jnp.float32(lhs_in)
629+
if lhs_in.dtype == jnp.int4
630+
else jnp.bfloat16(lhs_in)
631+
)
632+
elif rhs_in.dtype in manual_promotion_dtypes:
633+
rhs_in = (
634+
jnp.float32(rhs_in)
635+
if rhs_in.dtype == jnp.int4
636+
else jnp.bfloat16(rhs_in)
637+
)
638+
639+
if (
640+
lhs_in.dtype not in manual_promotion_dtypes
641+
and rhs_in.dtype not in manual_promotion_dtypes
642+
):
614643
lhs_in, rhs_in = nn.dtypes.promote_dtype(lhs_in, rhs_in)
615644

616645
# yes_swap = whether einsum swaps [lhs,rhs] when passing them to dot_general

0 commit comments

Comments
 (0)