Skip to content

Commit b8771b9

Browse files
Cerebra Catalyst Team, copybara-github
authored and committed
Avoid jax.promote_dtype(float8_e4m3fn, x) failure when using aqt_einsum with fp8_e4m3fn dtype
PiperOrigin-RevId: 733390414
1 parent bc98e84 commit b8771b9

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

aqt/jax/v2/flax/aqt_flax.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -610,7 +610,18 @@ def __call__(
610 610
lhs_in = jnp.float32(lhs_in)
611 611
if rhs_in.dtype == jnp.int4:
612 612
rhs_in = jnp.float32(rhs_in)
613-
if lhs_in.dtype != jnp.int4 and rhs_in.dtype != jnp.int4:
613+
# Cast float8_e4m3fn to bfloat16 to avoid the failure in
614+
# promote_dtype(float8_e4m3fn, x)
615+
if (
616+
lhs_in.dtype == jnp.float8_e4m3fn
617+
and rhs_in.dtype != jnp.float8_e4m3fn
618+
):
619+
lhs_in = jnp.bfloat16(lhs_in)
620+
if (
621+
rhs_in.dtype == jnp.float8_e4m3fn
622+
and lhs_in.dtype != jnp.float8_e4m3fn
623+
):
624+
rhs_in = jnp.bfloat16(rhs_in)
614 625
lhs_in, rhs_in = nn.dtypes.promote_dtype(lhs_in, rhs_in)
615 626

616 627
# yes_swap = whether einsum swaps [lhs,rhs] when passing them to dot_general

0 commit comments

Comments
 (0)