-
Notifications
You must be signed in to change notification settings - Fork 18
feat: general concat reshape to batching pass #983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
rather than call batch, why not just call the batchopinterface if defined? or other utility function from the implementation of batch lowering |
e.g. we can make https://github.com/EnzymeAD/Enzyme/blob/db0181320d6e425ee963bd496ed0d8dbb615be18/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp#L130 into a batchOperation utility or something, which we call directly here instead of outlining into a new function |
yeah that sounds like a better thing to do |
983ae0e
to
0a3007c
Compare
somehow I made it segfault only on mac... |
That's probably an issue of f(a(), b()) running a before b on Linux and opposite on Mac |
Summarizing the transform pass:
Later we can use this same pass for ops like
cholesky
/tringular_solve
/lu
where we need to check that the concat dim is not among the last 2 dims.For
lu
it is a bit more complicated since it has 4 returns and we need to check whether all the other returns are either (1) unused or (2) also concatenated.