@@ -603,14 +603,43 @@ def __call__(
603
603
# from being rejected by assertions in aqt_dot_general.py, lines 522-526 and
604
604
# 414.
605
605
# TODO: b/322111904 - Handle this in a more proper way.
606
- # We hand-hold int4 because promote_dtype(int4, x) fails.
607
- # (To avoid unintended promotion, 4-bit integers do not support
608
- # implicit promotion.)
609
- if lhs_in .dtype == jnp .int4 :
610
- lhs_in = jnp .float32 (lhs_in )
611
- if rhs_in .dtype == jnp .int4 :
612
- rhs_in = jnp .float32 (rhs_in )
613
- if lhs_in .dtype != jnp .int4 and rhs_in .dtype != jnp .int4 :
606
+ # List the dtypes that must be cast by hand because promote_dtype fails
607
+ # for them.
608
+ manual_promotion_dtypes = [jnp .int4 , jnp .float8_e4m3fn , jnp .float8_e5m2 ]
609
+ if (
610
+ lhs_in .dtype in manual_promotion_dtypes
611
+ and rhs_in .dtype in manual_promotion_dtypes
612
+ ):
613
+ if lhs_in .dtype == rhs_in .dtype :
614
+ pass
615
+ else :
616
+ lhs_in = (
617
+ jnp .float32 (lhs_in )
618
+ if lhs_in .dtype == jnp .int4
619
+ else jnp .bfloat16 (lhs_in )
620
+ )
621
+ rhs_in = (
622
+ jnp .float32 (rhs_in )
623
+ if rhs_in .dtype == jnp .int4
624
+ else jnp .bfloat16 (rhs_in )
625
+ )
626
+ elif lhs_in .dtype in manual_promotion_dtypes :
627
+ lhs_in = (
628
+ jnp .float32 (lhs_in )
629
+ if lhs_in .dtype == jnp .int4
630
+ else jnp .bfloat16 (lhs_in )
631
+ )
632
+ elif rhs_in .dtype in manual_promotion_dtypes :
633
+ rhs_in = (
634
+ jnp .float32 (rhs_in )
635
+ if rhs_in .dtype == jnp .int4
636
+ else jnp .bfloat16 (rhs_in )
637
+ )
638
+
639
+ if (
640
+ lhs_in .dtype not in manual_promotion_dtypes
641
+ and rhs_in .dtype not in manual_promotion_dtypes
642
+ ):
614
643
lhs_in , rhs_in = nn .dtypes .promote_dtype (lhs_in , rhs_in )
615
644
616
645
# yes_swap = whether einsum swaps [lhs,rhs] when passing them to dot_general
0 commit comments