Skip to content

Commit 3e0f966

Browse files
Cerebra Catalyst Team and copybara-github
authored and committed
Handle the jax.promote_dtype failure for [int4, fp8_e4m3, fp8_e5m2] when using aqt_einsum
PiperOrigin-RevId: 733390414
1 parent a1e1d11 commit 3e0f966

File tree

1 file changed

+55
-9
lines changed

1 file changed

+55
-9
lines changed

aqt/jax/v2/flax/aqt_flax.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,60 @@ def _get_singleton_axes(x: jnp.ndarray) -> list[utils.AxisIdx]:
9999
return qt
100100

101101

102+
def aqt_promote_dtype(
    lhs_in: jnp.ndarray, rhs_in: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
  """Promotes lhs_in and rhs_in to a common dtype.

  nn.dtypes.promote_dtype fails for int4, float8_e4m3fn and float8_e5m2, so
  operands of those dtypes are first widened by hand (int4 -> float32,
  fp8 -> bfloat16) and only then passed through the regular promotion, so
  that both returned arrays share one dtype.

  Args:
    lhs_in: Left-hand-side array.
    rhs_in: Right-hand-side array.

  Returns:
    A tuple of the promoted lhs_in and rhs_in. If both inputs already share
    the same manually handled dtype (e.g. both int4), they are returned
    unchanged.
  """
  manual_promotion_dtypes = (jnp.int4, jnp.float8_e4m3fn, jnp.float8_e5m2)

  def _widen(x: jnp.ndarray) -> jnp.ndarray:
    # Cast dtypes that promote_dtype rejects to a type it can handle.
    if x.dtype not in manual_promotion_dtypes:
      return x
    return jnp.float32(x) if x.dtype == jnp.int4 else jnp.bfloat16(x)

  # Matching low-precision operands need no promotion at all; keep them as-is.
  if (
      lhs_in.dtype in manual_promotion_dtypes
      and lhs_in.dtype == rhs_in.dtype
  ):
    return lhs_in, rhs_in

  # After widening, neither operand has a dtype that promote_dtype rejects,
  # so the standard promotion can always run. Without this step a mixed pair
  # such as (int4, fp8) would come back as (float32, bfloat16).
  lhs_in, rhs_in = nn.dtypes.promote_dtype(_widen(lhs_in), _widen(rhs_in))
  return lhs_in, rhs_in
154+
155+
102156
class FreezerMode(enum.Enum):
103157
NONE = 1
104158
CALIBRATION = 2
@@ -603,15 +657,7 @@ def __call__(
603657
# from being rejected by assertions in aqt_dot_general.py, line 522-526 and
604658
# 414.
605659
# TODO: b/322111904 - Handle this in more proper way.
606-
# We hand-hold int4 because promote_dtype(int4, x) fails.
607-
# (To avoid unintended promotion, 4-bit integers do not support
608-
# implicit promotion.)
609-
if lhs_in.dtype == jnp.int4:
610-
lhs_in = jnp.float32(lhs_in)
611-
if rhs_in.dtype == jnp.int4:
612-
rhs_in = jnp.float32(rhs_in)
613-
if lhs_in.dtype != jnp.int4 and rhs_in.dtype != jnp.int4:
614-
lhs_in, rhs_in = nn.dtypes.promote_dtype(lhs_in, rhs_in)
660+
lhs_in, rhs_in = aqt_promote_dtype(lhs_in, rhs_in)
615661

616662
# yes_swap = whether einsum swaps [lhs,rhs] when passing them to dot_general
617663
einsum = functools.partial(aqt_dot_general.einsum, eqn=eqn)

0 commit comments

Comments
 (0)