File tree Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -610,7 +610,18 @@ def __call__(
610
610
lhs_in = jnp .float32 (lhs_in )
611
611
if rhs_in .dtype == jnp .int4 :
612
612
rhs_in = jnp .float32 (rhs_in )
613
- if lhs_in .dtype != jnp .int4 and rhs_in .dtype != jnp .int4 :
613
+ # Cast float8_e4m3fn to bfloat16 to avoid the failure in
614
+ # promote_dtype(float8_e4m3fn, x)
615
+ if (
616
+ lhs_in .dtype == jnp .float8_e4m3fn
617
+ and rhs_in .dtype != jnp .float8_e4m3fn
618
+ ):
619
+ lhs_in = jnp .bfloat16 (lhs_in )
620
+ if (
621
+ rhs_in .dtype == jnp .float8_e4m3fn
622
+ and lhs_in .dtype != jnp .float8_e4m3fn
623
+ ):
624
+ rhs_in = jnp .bfloat16 (rhs_in )
614
625
lhs_in , rhs_in = nn .dtypes .promote_dtype (lhs_in , rhs_in )
615
626
616
627
# yes_swap = whether einsum swaps [lhs,rhs] when passing them to dot_general
You can’t perform that action at this time.
0 commit comments