Skip to content

Change the way tiling axes are split without configuration #644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 22 additions & 26 deletions aqt/jax/v2/tiled_dot_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def f(axes_cfg, axes, shape):
for axis in axes:
if axis not in axis_in_cfg:
axes_cfg.append(
AxisTiling(axis=axis, tile_count=1, tile_size=shape[axis])
AxisTiling(axis=axis, tile_count=shape[axis], tile_size=1)
)

f(new_cfg.lhs.contraction_axes, lhs_ca, lhs_shape)
Expand Down Expand Up @@ -358,11 +358,11 @@ def generate_tiling_states_for_dot_general(

# DotGeneral reshaping

xlhs_ra_tile, _ = xlhs.to_tiled_axes_transposed(lhs_ra, 2)
xrhs.broadcast_to_other(xlhs.axes_shape(xlhs_ra_tile))
_, xlhs_ra = xlhs.to_tiled_axes_transposed(lhs_ra, 2)
xrhs.broadcast_to_other(xlhs.axes_shape(xlhs_ra))

xrhs_ra_tile, _ = xrhs.to_tiled_axes_transposed(rhs_ra, 2)
xlhs.broadcast_to_other(xrhs.axes_shape(xrhs_ra_tile))
_, xrhs_ra = xrhs.to_tiled_axes_transposed(rhs_ra, 2)
xlhs.broadcast_to_other(xrhs.axes_shape(xrhs_ra))

return xlhs, xrhs

Expand All @@ -389,15 +389,15 @@ def tiled_dot_general_with_tiling_states(
xrhs_ra_tile, xrhs_ra = xrhs.to_tiled_axes_transposed(rhs_ra, 2)
xlhs_bcast = xlhs.get_broadcasted_tile_map_indexes()
xrhs_bcast = xrhs.get_broadcasted_tile_map_indexes()
(xlhs_ra_tile_other,) = xlhs.to_tiled_axes_transposed(xlhs_bcast, 1)
(xrhs_ra_tile_other,) = xrhs.to_tiled_axes_transposed(xrhs_bcast, 1)
(xlhs_ra_other,) = xlhs.to_tiled_axes_transposed(xlhs_bcast, 1)
(xrhs_ra_other,) = xrhs.to_tiled_axes_transposed(xrhs_bcast, 1)
(xlhs_ba,) = xlhs.to_tiled_axes_transposed(lhs_ba, 1)
(xrhs_ba,) = xrhs.to_tiled_axes_transposed(rhs_ba, 1)

tiled_ca = (xlhs_ca, xrhs_ca)
tiled_ba = (
xlhs_ca_tile + xlhs_ba + xlhs_ra_tile + xlhs_ra_tile_other,
xrhs_ca_tile + xrhs_ba + xrhs_ra_tile_other + xrhs_ra_tile,
xlhs_ca_tile + xlhs_ba + xlhs_ra + xlhs_ra_other,
xrhs_ca_tile + xrhs_ba + xrhs_ra_other + xrhs_ra,
)
tiled_dimension_numbers = (tiled_ca, tiled_ba)

Expand Down Expand Up @@ -430,28 +430,24 @@ def tiled_dot_general_with_tiling_states(
assert xlhs.axes_shape(xlhs_ba) == xrhs.axes_shape(xrhs_ba), g_msg
ba_sh = xlhs.axes_shape(xlhs_ba)

assert xlhs.axes_shape(xlhs_ra_tile) == xrhs.axes_shape(
xrhs_ra_tile_other
), g_msg
lhs_ra_tile_sh = xlhs.axes_shape(xlhs_ra_tile)
assert xlhs.axes_shape(xlhs_ra) == xrhs.axes_shape(xrhs_ra_other), g_msg
lhs_ra_sh = xlhs.axes_shape(xlhs_ra)

assert xlhs.axes_shape(xlhs_ra_tile_other) == xrhs.axes_shape(
xrhs_ra_tile
), g_msg
rhs_ra_tile_sh = xlhs.axes_shape(xlhs_ra_tile_other)
assert xlhs.axes_shape(xlhs_ra_other) == xrhs.axes_shape(xrhs_ra), g_msg
rhs_ra_sh = xlhs.axes_shape(xlhs_ra_other)

lhs_ra_sh = xlhs.axes_shape(xlhs_ra)
rhs_ra_sh = xrhs.axes_shape(xrhs_ra)
lhs_ra_tile_sh = xlhs.axes_shape(xlhs_ra_tile)
rhs_ra_tile_sh = xrhs.axes_shape(xrhs_ra_tile)

g_msg += f'Tiled dg {out.shape=} \n'
assert (
out.shape
== ca_tile_sh
+ ba_sh
+ lhs_ra_tile_sh
+ rhs_ra_tile_sh
+ lhs_ra_sh
+ rhs_ra_sh
+ lhs_ra_tile_sh
+ rhs_ra_tile_sh
), g_msg

# Sum over ca_tile now.
Expand All @@ -462,7 +458,7 @@ def tiled_dot_general_with_tiling_states(
g_msg += f'After sum over tiles {out.shape=} \n'
assert (
out.shape
== ba_sh + lhs_ra_tile_sh + rhs_ra_tile_sh + lhs_ra_sh + rhs_ra_sh
== ba_sh + lhs_ra_sh + rhs_ra_sh + lhs_ra_tile_sh + rhs_ra_tile_sh
), g_msg

# Transpose tile and tile size together
Expand All @@ -488,10 +484,10 @@ def tiled_dot_general_with_tiling_states(
+ interleave(new_rhs_ra_tile, new_rhs_ra)
)

lhs_ra_sh_interleaved = tuple(interleave(lhs_ra_tile_sh, lhs_ra_sh))
rhs_ra_sh_interleaved = tuple(interleave(rhs_ra_tile_sh, rhs_ra_sh))
lhs_ra_sh_flattened = tuple(zip_product(lhs_ra_tile_sh, lhs_ra_sh))
rhs_ra_sh_flattened = tuple(zip_product(rhs_ra_tile_sh, rhs_ra_sh))
lhs_ra_sh_interleaved = tuple(interleave(lhs_ra_sh, lhs_ra_tile_sh))
rhs_ra_sh_interleaved = tuple(interleave(rhs_ra_sh, rhs_ra_tile_sh))
lhs_ra_sh_flattened = tuple(zip_product(lhs_ra_sh, lhs_ra_tile_sh))
rhs_ra_sh_flattened = tuple(zip_product(rhs_ra_sh, rhs_ra_tile_sh))

g_msg += f'After transpose {out.shape=} \n'
assert (
Expand Down