
Commit ef7e9e7

Fix handling of attention-bias in MHA fusion (#2332)
In models generated from PyTorch, masks may have shapes that are merely broadcastable to (B, H, S, St): e.g., a 2D mask of shape (S, St), or even shape (1, 1, 1, St) in one example. ONNX's opset 23 Attention op allows masks of this shape. However, ORT's contrib ops (MHA, Attention) allow a mask of shape (1 or B, 1 or H, S, St); that is, they support broadcast only for the first two dimensions. (Even that is not supported by some earlier versions of ORT, which we don't consider here.) So, while doing fusion for MHA, we should expand the mask to ensure it satisfies the constraints of MHA/Attention.

Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent b34cd9c commit ef7e9e7
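
The shape rule the new check enforces can be distilled into a few lines of plain Python. The helper below, classify_mask, is a hypothetical sketch (not part of onnxscript or this commit); it only mirrors the accept/expand decision described above, with B, H, S, St denoting batch size, head count, query length, and total key/value length.

# Hypothetical distillation of the mask-shape rule above; not onnxscript API.
def classify_mask(mask_shape, B, H, S, St):
    """Return (ok, needs_expand) for a candidate attention-bias shape."""
    if len(mask_shape) == 4:
        b, h, s, st = mask_shape
        # ORT's MHA/Attention broadcast only the first two dimensions.
        if b not in (1, B) or h not in (1, H) or st != St:
            return False, False
        if s == S:
            return True, False  # already (1 or B, 1 or H, S, St)
        if s == 1:
            return True, True   # dim 2 must be expanded from 1 to S
        return False, False
    if len(mask_shape) == 2:
        s, st = mask_shape
        # A 2D (1 or S, St) mask always needs expansion to rank 4.
        return (s in (1, S) and st == St), True
    return False, False  # only 2D and 4D masks are handled

# e.g. the (1, 1, 1, St) case from the commit message:
print(classify_mask((1, 1, 1, 7), B=2, H=8, S=5, St=7))  # (True, True)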

File tree

1 file changed: +46 −2 lines changed

  • onnxscript/rewriter/ort_fusions/mha.py


onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 46 additions & 2 deletions
@@ -265,8 +265,46 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
             past_value,
         )

-        # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St)
-        # But this also, unforunately, depends on ORT version.
+        # mask (aka attention_bias) shape check:
+        # ONNX's Attention op (named SDPA here) allows a mask broadcastable to (B, H, S, St)
+        # ORT's contrib ops (MHA, Attention) allow a mask of shape (1 or B, 1 or H, S, St)
+        # That is: broadcast allowed only for the first two dimensions. (Even that is not
+        # supported by some earlier versions of ORT, which are not supported here.)
+        if self._use_mask:
+            if (mask_shape := mask.shape) is None:
+                return check_result.fail(
+                    "Mask shape cannot be determined.",
+                    mask,
+                )
+            if mask_shape.rank() == 4:
+                if no_match(mask, ["B_or_1", "H_or_1", "S_or_1", "St"]):
+                    return check_result.fail(
+                        f"Shape mismatch: {mask} does not match expected dimensions ['1 or B', '1 or H', '1 or S', 'St']",
+                        mask,
+                    )
+                mask_dim_2 = bindings.get("S_or_1")
+                if mask_dim_2 == bindings.get("S"):
+                    self._use_mask_broadcast = False
+                elif mask_dim_2 == 1:
+                    self._use_mask_broadcast = True
+                else:
+                    return check_result.fail(
+                        "Mask dimension 2 cannot be verified to be 1 or S"
+                    )
+            elif mask_shape.rank() == 2:
+                if no_match(mask, ["S_or_1", "St"]):
+                    return check_result.fail(
+                        f"Shape mismatch: {mask} does not match expected dimensions ['1 or S', 'St']",
+                        mask,
+                    )
+                self._use_mask_broadcast = True
+            else:
+                return check_result.fail(
+                    f"Mask shape {mask_shape} is not supported. Expected 2D or 4D.",
+                    mask,
+                )
+        else:
+            self._use_mask_broadcast = False

         # TODO: verify Reshapes:
         # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]:
@@ -315,6 +353,12 @@ def rewrite(
             query_BSD_emb = query_BSD
             key_BSD_emb = key

+        if self._use_mask_broadcast:
+            one = op.Constant(value_ints=[1])
+            S = op.Shape(query_BSD, start=1, end=2)
+            shape_11S1 = op.Concat(one, one, S, one, axis=0)
+            mask = op.Expand(mask, shape_11S1)
+
         num_outputs = 1 + (2 * self._has_past_present)
         return op.MultiHeadAttention(
             query_BSD_emb,
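
For reference, the op.Expand in the rewrite relies on ONNX Expand's numpy-style multidirectional broadcasting: expanding against the target shape (1, 1, S, 1) materializes the S dimension while leaving the batch and head dimensions broadcastable, which is what MHA requires. A minimal numpy sketch of that semantics (the sizes below are arbitrary, chosen only for illustration):

import numpy as np

# Standalone check of the broadcast the rewrite performs. ONNX Expand uses
# numpy-style multidirectional broadcasting, so multiplying by ones of the
# target shape reproduces its semantics.
S, St = 4, 6
mask_2d = np.zeros((S, St), dtype=np.float32)     # a PyTorch-style 2D mask
target = np.ones((1, 1, S, 1), dtype=np.float32)  # shape_11S1 in the rewrite

expanded = mask_2d * target
print(expanded.shape)  # (1, 1, 4, 6): rank 4, S materialized, B/H still broadcastable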
