From caf890584a41f16b3366ddf275efc04e04fca4d8 Mon Sep 17 00:00:00 2001 From: Cerebra Catalyst Team Date: Fri, 7 Jun 2024 15:16:45 -0700 Subject: [PATCH] Change a way to split tiling axes without configuration For the axes without specified configuration, AQT has assumed it will be split into (1, dim). This change changes this assumption to (dim, 1) since it would provide more consistency when it comes to sharding. PiperOrigin-RevId: 641372231 --- aqt/jax/v2/tiled_dot_general.py | 48 +++++++++++++++------------------ 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/aqt/jax/v2/tiled_dot_general.py b/aqt/jax/v2/tiled_dot_general.py index 8b9d52e7..9e07634c 100644 --- a/aqt/jax/v2/tiled_dot_general.py +++ b/aqt/jax/v2/tiled_dot_general.py @@ -140,7 +140,7 @@ def f(axes_cfg, axes, shape): for axis in axes: if axis not in axis_in_cfg: axes_cfg.append( - AxisTiling(axis=axis, tile_count=1, tile_size=shape[axis]) + AxisTiling(axis=axis, tile_count=shape[axis], tile_size=1) ) f(new_cfg.lhs.contraction_axes, lhs_ca, lhs_shape) @@ -358,11 +358,11 @@ def generate_tiling_states_for_dot_general( # DotGeneral reshapeing - xlhs_ra_tile, _ = xlhs.to_tiled_axes_transposed(lhs_ra, 2) - xrhs.broadcast_to_other(xlhs.axes_shape(xlhs_ra_tile)) + _, xlhs_ra = xlhs.to_tiled_axes_transposed(lhs_ra, 2) + xrhs.broadcast_to_other(xlhs.axes_shape(xlhs_ra)) - xrhs_ra_tile, _ = xrhs.to_tiled_axes_transposed(rhs_ra, 2) - xlhs.broadcast_to_other(xrhs.axes_shape(xrhs_ra_tile)) + _, xrhs_ra = xrhs.to_tiled_axes_transposed(rhs_ra, 2) + xlhs.broadcast_to_other(xrhs.axes_shape(xrhs_ra)) return xlhs, xrhs @@ -389,15 +389,15 @@ def tiled_dot_general_with_tiling_states( xrhs_ra_tile, xrhs_ra = xrhs.to_tiled_axes_transposed(rhs_ra, 2) xlhs_bcast = xlhs.get_broadcasted_tile_map_indexes() xrhs_bcast = xrhs.get_broadcasted_tile_map_indexes() - (xlhs_ra_tile_other,) = xlhs.to_tiled_axes_transposed(xlhs_bcast, 1) - (xrhs_ra_tile_other,) = xrhs.to_tiled_axes_transposed(xrhs_bcast, 1) + (xlhs_ra_other,) = xlhs.to_tiled_axes_transposed(xlhs_bcast, 1) + (xrhs_ra_other,) = xrhs.to_tiled_axes_transposed(xrhs_bcast, 1) (xlhs_ba,) = xlhs.to_tiled_axes_transposed(lhs_ba, 1) (xrhs_ba,) = xrhs.to_tiled_axes_transposed(rhs_ba, 1) tiled_ca = (xlhs_ca, xrhs_ca) tiled_ba = ( - xlhs_ca_tile + xlhs_ba + xlhs_ra_tile + xlhs_ra_tile_other, - xrhs_ca_tile + xrhs_ba + xrhs_ra_tile_other + xrhs_ra_tile, + xlhs_ca_tile + xlhs_ba + xlhs_ra + xlhs_ra_other, + xrhs_ca_tile + xrhs_ba + xrhs_ra_other + xrhs_ra, ) tiled_dimension_numbers = (tiled_ca, tiled_ba) @@ -430,28 +430,24 @@ def tiled_dot_general_with_tiling_states( assert xlhs.axes_shape(xlhs_ba) == xrhs.axes_shape(xrhs_ba), g_msg ba_sh = xlhs.axes_shape(xlhs_ba) - assert xlhs.axes_shape(xlhs_ra_tile) == xrhs.axes_shape( - xrhs_ra_tile_other - ), g_msg - lhs_ra_tile_sh = xlhs.axes_shape(xlhs_ra_tile) + assert xlhs.axes_shape(xlhs_ra) == xrhs.axes_shape(xrhs_ra_other), g_msg + lhs_ra_sh = xlhs.axes_shape(xlhs_ra) - assert xlhs.axes_shape(xlhs_ra_tile_other) == xrhs.axes_shape( - xrhs_ra_tile - ), g_msg - rhs_ra_tile_sh = xlhs.axes_shape(xlhs_ra_tile_other) + assert xlhs.axes_shape(xlhs_ra_other) == xrhs.axes_shape(xrhs_ra), g_msg + rhs_ra_sh = xlhs.axes_shape(xlhs_ra_other) - lhs_ra_sh = xlhs.axes_shape(xlhs_ra) - rhs_ra_sh = xrhs.axes_shape(xrhs_ra) + lhs_ra_tile_sh = xlhs.axes_shape(xlhs_ra_tile) + rhs_ra_tile_sh = xrhs.axes_shape(xrhs_ra_tile) g_msg += f'Tiled dg {out.shape=} \n' assert ( out.shape == ca_tile_sh + ba_sh - + lhs_ra_tile_sh - + rhs_ra_tile_sh + lhs_ra_sh + rhs_ra_sh + + lhs_ra_tile_sh + + rhs_ra_tile_sh ), g_msg # Sum over ca_tile now. @@ -462,7 +458,7 @@ def tiled_dot_general_with_tiling_states( g_msg += f'After sum over tiles {out.shape=} \n' assert ( out.shape - == ba_sh + lhs_ra_tile_sh + rhs_ra_tile_sh + lhs_ra_sh + rhs_ra_sh + == ba_sh + lhs_ra_sh + rhs_ra_sh + lhs_ra_tile_sh + rhs_ra_tile_sh ), g_msg # Transpose tile and tile size together @@ -488,10 +484,10 @@ def tiled_dot_general_with_tiling_states( + interleave(new_rhs_ra_tile, new_rhs_ra) ) - lhs_ra_sh_interleaved = tuple(interleave(lhs_ra_tile_sh, lhs_ra_sh)) - rhs_ra_sh_interleaved = tuple(interleave(rhs_ra_tile_sh, rhs_ra_sh)) - lhs_ra_sh_flattened = tuple(zip_product(lhs_ra_tile_sh, lhs_ra_sh)) - rhs_ra_sh_flattened = tuple(zip_product(rhs_ra_tile_sh, rhs_ra_sh)) + lhs_ra_sh_interleaved = tuple(interleave(lhs_ra_sh, lhs_ra_tile_sh)) + rhs_ra_sh_interleaved = tuple(interleave(rhs_ra_sh, rhs_ra_tile_sh)) + lhs_ra_sh_flattened = tuple(zip_product(lhs_ra_sh, lhs_ra_tile_sh)) + rhs_ra_sh_flattened = tuple(zip_product(rhs_ra_sh, rhs_ra_tile_sh)) g_msg += f'After transpose {out.shape=} \n' assert (