Optimize causal mask shape #2325
Conversation
Signed-off-by: Ganesan Ramalingam <[email protected]>
❌ 8 Tests Failed.
The MHA-Bias rules can be simplified using pattern-disjunction. (This _may_ help with Whisper ... that was my original motivation, but I am not sure after fixing another issue in PR #2325, which may be the primary issue. But the cleanup is useful anyway, and it makes fusion more efficient.)
Signed-off-by: Ganesan Ramalingam <[email protected]>
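For readers not familiar with the term, here is a minimal sketch of what pattern-disjunction buys us: a single rule can match several structural variants of the same subgraph (for example, the bias appearing as either operand of an Add) instead of maintaining near-duplicate rules. This is not the actual MHA-Bias code; the toy pattern and the helper names used below (`pattern.OrValue`, `pattern.RewriteRule`) are assumptions about the onnxscript rewriter API, used only for illustration.

```python
# Hedged sketch of pattern-disjunction; not the MHA-Bias rules from this repo.
from onnxscript.rewriter import pattern


def masked_scores_pattern(op, scores, bias, mask):
    # The bias may appear as either operand of the Add; one disjunctive
    # pattern covers both orderings instead of two near-duplicate rules.
    biased = pattern.OrValue([op.Add(scores, bias), op.Add(bias, scores)])  # assumed disjunction helper
    return op.Softmax(op.Add(biased, mask), axis=-1)


def masked_scores_rewrite(op, scores, bias, mask):
    # Illustrative replacement only: fold bias and mask into a single Add.
    return op.Softmax(op.Add(scores, op.Add(bias, mask)), axis=-1)


# One rule now covers both orderings, and the matcher does less work than
# trying two separate near-identical rules.
masked_scores_rule = pattern.RewriteRule(masked_scores_pattern, masked_scores_rewrite)
```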
So constant folding doesn’t get this properly?
Good question (though it is really the "optimizer": we call it constant-folding, but it goes beyond pure constant folding). I think not. It does do the necessary analysis for shape inference, though, so it may be worth checking. I thought we might need a more generic optimization pass, but perhaps not.
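To make the distinction concrete (a toy model assumed for illustration, not code from this repo): when the sequence length is a symbolic dimension, the `Shape` of the intermediate MatMul has no constant value, so pure constant folding cannot eliminate it; recognizing that it equals the query's own sequence-length dimension requires shape analysis rather than folding.

```python
# Toy example: why constant folding alone cannot remove the Shape node.
import onnx
from onnx import TensorProto, helper

seq = "seq_len"  # symbolic dimension: unknown until runtime
q = helper.make_tensor_value_info("Q", TensorProto.FLOAT, ["batch", "heads", seq, 64])
kt = helper.make_tensor_value_info("Kt", TensorProto.FLOAT, ["batch", "heads", 64, seq])
out = helper.make_tensor_value_info("scores_shape", TensorProto.INT64, [4])

graph = helper.make_graph(
    [
        helper.make_node("MatMul", ["Q", "Kt"], ["scores"]),
        # The causal-mask construction reads the shape of this intermediate value:
        helper.make_node("Shape", ["scores"], ["scores_shape"]),
    ],
    "causal_mask_shape_example",
    [q, kt],
    [out],
)
model = helper.make_model(graph)

# Shape inference can tell that "scores" is [batch, heads, seq_len, seq_len],
# but since seq_len is symbolic there is no constant tensor to fold
# "scores_shape" into; only its symbolic equality with the query's own
# sequence-length dimension is known.
inferred = onnx.shape_inference.infer_shapes(model)
```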
The generation of the causal mask's shape (produced by the translation of scaled_dot_product_attention) interferes with the subsequent fusion optimizations, because it makes use of the shape of the intermediate MatMul value.
This PR introduces a very specific fusion/rewrite to eliminate this redundant computation of the "sequence length" dimension.
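As a rough sketch of the kind of rewrite described above (assuming the onnxscript rewriter pattern API and the usual `[batch, heads, seq_len, head_dim]` query layout; the rule below is illustrative, not the one added in this PR): the sequence-length dimension read from the score matrix can instead be read from the query itself, which removes the dependence on the intermediate MatMul and leaves the attention subgraph free for later fusions.

```python
# Hedged sketch only; the actual rule in this PR may differ in its details.
from onnxscript.rewriter import pattern


def seq_len_from_scores(op, query, key_t):
    scores = op.MatMul(query, key_t)
    # Sequence length read from the [B, H, S, S] score matrix (dim -2).
    return op.Shape(scores, start=-2, end=-1)


def seq_len_from_query(op, query, key_t):
    # The same value, computed from the query input directly (dim -2 of
    # [B, H, S, D]), so nothing downstream depends on the intermediate MatMul.
    return op.Shape(query, start=-2, end=-1)


causal_mask_shape_rule = pattern.RewriteRule(seq_len_from_scores, seq_len_from_query)
```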