feat: general concat reshape to batching pass #983

Open: avik-pal wants to merge 9 commits into main

avik-pal (Collaborator) commented May 22, 2025

Summarizing the transform pass:

  1. Pattern: a concat of insert-dim ops (all at the same dim) whose inputs come from equivalent ops.
  2. For all the parent operands, we insert a dim at dim=0 and concat those.
  3. Insert a batch op and immediately resolve it into a call op using the batching utilities.
  4. Move the first (batched) dim to the concat dim.
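The four steps above can be illustrated on plain vectors. This is only a sketch, not the pass itself: `Vec`, `Mat`, `op`, `concatOfInsertDims`, and `batchedRewrite` are hypothetical names, `op` stands in for whatever equivalent op feeds each insert dim, and the concat dim is assumed to be 1. The real pass rewrites MLIR ops and would emit a single batched op rather than a loop.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;  // shape [m]
using Mat = std::vector<Vec>;     // row-major, Mat[r][c]

// Hypothetical stand-in for the "equivalent op" applied to every operand.
Vec op(const Vec &x) {
  Vec y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = std::sin(x[i]);
  return y;
}

// Step 1 (the matched pattern): op(x_i) : [m], insert a dim at 1 : [m, 1],
// then concat all of them along dim 1 : [m, n].
Mat concatOfInsertDims(const std::vector<Vec> &inputs) {
  assert(!inputs.empty());
  const std::size_t m = inputs[0].size();
  Mat out(m, Vec(inputs.size()));
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    const Vec y = op(inputs[i]);
    for (std::size_t r = 0; r < m; ++r) out[r][i] = y[r];
  }
  return out;
}

// Steps 2 to 4 (the rewritten form).
Mat batchedRewrite(const std::vector<Vec> &inputs) {
  assert(!inputs.empty());
  // Step 2: insert a dim at 0 on each parent operand and concat them: [n, m].
  const Mat stacked(inputs.begin(), inputs.end());
  // Step 3: apply the op over the leading batch dim. A loop is used here
  // for illustration; the pass would produce one batched op instead.
  Mat batched;
  batched.reserve(stacked.size());
  for (const Vec &row : stacked) batched.push_back(op(row));
  // Step 4: move the batch dim (0) to the concat dim (1), i.e. transpose.
  const std::size_t n = batched.size();
  const std::size_t m = batched[0].size();
  Mat out(m, Vec(n));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t r = 0; r < m; ++r) out[r][i] = batched[i][r];
  return out;
}
```

Both functions should return the same `[m, n]` result for any inputs of equal length, which is the equivalence the rewrite relies on.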

Later we can use this same pass for ops like cholesky / triangular_solve / lu, where we need to check that the concat dim is not among the last 2 dims.

For lu it is a bit more complicated since it has 4 returns and we need to check whether all the other returns are either (1) unused or (2) also concatenated.

@avik-pal avik-pal marked this pull request as ready for review May 22, 2025 00:38
@avik-pal avik-pal requested a review from wsmoses May 22, 2025 00:42
wsmoses (Member) commented May 22, 2025

rather than call batch, why not just call the batchopinterface if defined? or another utility function from the implementation of batch lowering

wsmoses (Member) commented May 22, 2025

e.g. we can make https://github.com/EnzymeAD/Enzyme/blob/db0181320d6e425ee963bd496ed0d8dbb615be18/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp#L130 into a batchOperation utility or something, which we call directly here instead of outlining into a new function

avik-pal (Collaborator, Author) commented

yeah that sounds like a better thing to do

@avik-pal avik-pal force-pushed the ap/general_concat_push_up branch from 983ae0e to 0a3007c Compare July 18, 2025 02:56
@avik-pal avik-pal marked this pull request as ready for review July 19, 2025 13:48
avik-pal (Collaborator, Author) commented

somehow I made it segfault only on mac...

wsmoses (Member) commented Jul 19, 2025

That's probably an issue of f(a(), b()) running a before b on Linux and the opposite on Mac
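For context, C++ leaves the evaluation order of function arguments unspecified, so two calls with side effects in one argument list can run in either order depending on the compiler. A minimal sketch of the hazard and the usual fix follows; `a`, `b`, `f`, and `trace` are hypothetical stand-ins, not code from this PR.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Shared state that both calls touch, e.g. two builder calls that each
// create an op and so affect what the other one sees.
static std::vector<std::string> trace;

int a() { trace.push_back("a"); return 1; }
int b() { trace.push_back("b"); return 2; }
int f(int x, int y) { return 10 * x + y; }

// The order of a() and b() is unspecified here: a compiler may run either
// one first, so trace can end up as {"a", "b"} or {"b", "a"}.
int unsequenced() { return f(a(), b()); }

// Naming the intermediates sequences the calls: a() always runs before b().
int sequenced() {
  const int x = a();
  const int y = b();
  return f(x, y);
}
```

If the segfault does come from this, hoisting the argument expressions into named locals as in `sequenced` makes the order the same on every platform.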
