@@ -99,6 +99,60 @@ def _get_singleton_axes(x: jnp.ndarray) -> list[utils.AxisIdx]:
99
99
return qt
100
100
101
101
102
+ def aqt_promote_dtype (
103
+ lhs_in : jnp .ndarray , rhs_in : jnp .ndarray
104
+ ) -> tuple [jnp .ndarray , jnp .ndarray ]:
105
+ """Promotes the dtype of lhs_in and rhs_in.
106
+
107
+ Args:
108
+ lhs_in: Left-hand-side array.
109
+ rhs_in: Right-hand-side array.
110
+
111
+ Returns:
112
+ A tuple of the promoted lhs_in and rhs_in.
113
+
114
+ We create a list of dtypes and hand-hold them because promote_dtype fails for
115
+ these dtypes.
116
+ """
117
+
118
+ manual_promotion_dtypes = [jnp .int4 , jnp .float8_e4m3fn , jnp .float8_e5m2 ]
119
+ if (
120
+ lhs_in .dtype in manual_promotion_dtypes
121
+ and rhs_in .dtype in manual_promotion_dtypes
122
+ ):
123
+ if lhs_in .dtype == rhs_in .dtype :
124
+ pass
125
+ else :
126
+ lhs_in = (
127
+ jnp .float32 (lhs_in )
128
+ if lhs_in .dtype == jnp .int4
129
+ else jnp .bfloat16 (lhs_in )
130
+ )
131
+ rhs_in = (
132
+ jnp .float32 (rhs_in )
133
+ if rhs_in .dtype == jnp .int4
134
+ else jnp .bfloat16 (rhs_in )
135
+ )
136
+ elif lhs_in .dtype in manual_promotion_dtypes :
137
+ lhs_in = (
138
+ jnp .float32 (lhs_in )
139
+ if lhs_in .dtype == jnp .int4
140
+ else jnp .bfloat16 (lhs_in )
141
+ )
142
+ elif rhs_in .dtype in manual_promotion_dtypes :
143
+ rhs_in = (
144
+ jnp .float32 (rhs_in )
145
+ if rhs_in .dtype == jnp .int4
146
+ else jnp .bfloat16 (rhs_in )
147
+ )
148
+ elif (
149
+ lhs_in .dtype not in manual_promotion_dtypes
150
+ and rhs_in .dtype not in manual_promotion_dtypes
151
+ ):
152
+ lhs_in , rhs_in = nn .dtypes .promote_dtype (lhs_in , rhs_in )
153
+ return lhs_in , rhs_in
154
+
155
+
102
156
class FreezerMode (enum .Enum ):
103
157
NONE = 1
104
158
CALIBRATION = 2
@@ -603,15 +657,7 @@ def __call__(
603
657
# from being rejected by assertions in aqt_dot_general.py, line 522-526 and
604
658
# 414.
605
659
# TODO: b/322111904 - Handle this in more proper way.
606
- # We hand-hold int4 because promote_dtype(int4, x) fails.
607
- # (To avoid unintended promotion, 4-bit integers do not support
608
- # implicit promotion.)
609
- if lhs_in .dtype == jnp .int4 :
610
- lhs_in = jnp .float32 (lhs_in )
611
- if rhs_in .dtype == jnp .int4 :
612
- rhs_in = jnp .float32 (rhs_in )
613
- if lhs_in .dtype != jnp .int4 and rhs_in .dtype != jnp .int4 :
614
- lhs_in , rhs_in = nn .dtypes .promote_dtype (lhs_in , rhs_in )
660
+ lhs_in , rhs_in = aqt_promote_dtype (lhs_in , rhs_in )
615
661
616
662
# yes_swap = whether einsum swaps [lhs,rhs] when passing them to dot_general
617
663
einsum = functools .partial (aqt_dot_general .einsum , eqn = eqn )
0 commit comments