Fix handling of attention-bias in MHA fusion (#2332)
In models generated from PyTorch, masks may have shapes that are merely
broadcastable to (B, H, S, St): e.g., a 2D mask of shape (S, St), or even
shape (1, 1, 1, St) in one example.
ONNX's opset 23 Attention op allows masks of these shapes. However, ORT's
contrib ops (MHA, Attention) only allow a mask of shape (1 or B, 1 or H, S,
St). That is, they support broadcast only for the first two dimensions.
(Even that is not supported by some earlier versions of ORT, which we
don't consider here.)
So, while doing fusion for MHA, we should expand the mask to ensure it
satisfies the constraints of MHA/Attention, as sketched below.
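
The following is a minimal numpy sketch of the idea, not the actual rewrite-rule code in this PR: it assumes a hypothetical helper `expand_mask_for_mha` and shows how a broadcastable mask could be materialized in its last two dimensions, while leaving the leading dimensions alone since ORT's contrib ops do broadcast those.

```python
import numpy as np


def expand_mask_for_mha(mask: np.ndarray, S: int, St: int) -> np.ndarray:
    """Expand `mask` so its last two dims are exactly (S, St).

    Hypothetical helper for illustration only. Leading dims are padded with
    1s (or kept as B/H if already present), since ORT's MHA/Attention contrib
    ops support broadcasting in the first two dimensions but not the last two.
    """
    # Pad the rank up to 4 with leading singleton dims, e.g. (S, St) -> (1, 1, S, St).
    while mask.ndim < 4:
        mask = np.expand_dims(mask, axis=0)
    b, h = mask.shape[0], mask.shape[1]
    # Materialize the last two dims; keep the leading dims as-is.
    return np.broadcast_to(mask, (b, h, S, St))


# Example: a (1, 1, 1, St) mask is expanded to (1, 1, S, St).
mask = np.zeros((1, 1, 1, 7), dtype=np.float32)
print(expand_mask_for_mha(mask, S=5, St=7).shape)  # (1, 1, 5, 7)
```

In the actual fusion, the analogous expansion would be expressed as ONNX operations (e.g. an Expand to the required (.., .., S, St) shape) inserted on the mask input before it feeds the fused MHA/Attention node.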
---------
Signed-off-by: Ganesan Ramalingam <[email protected]>