Skip to content

Add a couple of variants of patterns in ORT fusions #2077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@


def optimize_for_ort(model: ir.Model) -> None:
rewrite(model, ORT_PATTERN_REWRITE_RULES)
fuse_xformers(model)
rewrite(model, ORT_PATTERN_REWRITE_RULES)

Check warning on line 43 in onnxscript/rewriter/ort_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/_core.py#L43

Added line #L43 was not covered by tests
13 changes: 9 additions & 4 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,15 @@ def rewrite(
)


_cast = CosSinCacheFusion.rule("CosSinCache", 2048, cast=True, const_freqs=True)
_no_cast = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False)

cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _no_cast])
_cast_const_freqs = CosSinCacheFusion.rule(
"CosSinCache_cast_const_freqs", 2048, cast=True, const_freqs=True
)
_cast = CosSinCacheFusion.rule(
"CosSinCache_cast_no_const_freqs", 2048, cast=True, const_freqs=False
)
_basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False)

cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _basic])

debug: bool = True

Expand Down
54 changes: 43 additions & 11 deletions onnxscript/rewriter/ort_fusions/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,30 @@


class SDPA(pattern.RewriteRuleClassBase):
def __init__(self, name: str, *, use_mask: bool, pre_scale: bool):
def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool):
super().__init__(name=name)
self._use_mask = use_mask
self._pre_scale = pre_scale
self._use_mul = use_mul

def pattern(
self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale
):
if self._pre_scale:
# Some implementations scale the query and key before computing the dot product
query = op.Mul(query, query_scale)
key_transposed = op.Mul(key_transposed, key_scale)
if self._use_mul:
query = op.Mul(query, query_scale)
key_transposed = op.Mul(key_transposed, key_scale)
else:
query = op.Div(query, query_scale)
key_transposed = op.Div(key_transposed, key_scale)
attn_score = op.MatMul(query, key_transposed)
if not self._pre_scale:
# Some implementations scale the dot product.
attn_score = op.Div(attn_score, qk_scale)
if self._use_mul:
attn_score = op.Mul(attn_score, qk_scale)
else:
attn_score = op.Div(attn_score, qk_scale)
if self._use_mask:
# Some implementations add a mask to the dot product.
attn_score = op.Add(attn_score, mask)
Expand All @@ -42,16 +50,18 @@
if not isinstance(hidden_size, int):
return False
expected_scaling_factor = math.sqrt(hidden_size)
if self._use_mul:
expected_scaling_factor = 1.0 / expected_scaling_factor

if self._pre_scale:
# Check if query_scale and key_scale are scalars == 1/sqrt(sqrt(hidden_size))
sqrt_scaling_factor = 1.0 / math.sqrt(expected_scaling_factor)
# Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3):
return False
if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3):
return False
else:
# Check if qk_scale is a scalar == sqrt(hidden_size)
# Check if qk_scale is a scalar == expected_scaling_factor)
if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3):
return False

Expand All @@ -63,13 +73,35 @@
return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")


masked_pre_mul_sdpa_rule = SDPA.rule("masked_pre_mul_sdpa", use_mask=True, pre_scale=True)
masked_post_div_sdpa_rule = SDPA.rule("masked_post_div_sdpa", use_mask=True, pre_scale=False)
masked_pre_div_sdpa_rule = SDPA.rule(
"masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=False
)
masked_pre_mul_sdpa_rule = SDPA.rule(
"masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=True
)
masked_post_div_sdpa_rule = SDPA.rule(
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=False
)
masked_post_mul_sdpa_rule = SDPA.rule(
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=True
)

sdpa_rules = pattern.RewriteRuleSet([masked_pre_mul_sdpa_rule, masked_post_div_sdpa_rule])
sdpa_rules = pattern.RewriteRuleSet(
[
masked_pre_mul_sdpa_rule,
masked_post_div_sdpa_rule,
masked_post_mul_sdpa_rule,
masked_pre_div_sdpa_rule,
]
)

debug: bool = True


def fuse_sdpa(model: ir.Model) -> int:
count = sdpa_rules.apply_to_model(model)
print(f"SDPA count: {count}")
if count == 0 and debug:
sdpa_rules.apply_to_model(model, debug=True)

Check warning on line 104 in onnxscript/rewriter/ort_fusions/sdpa.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/sdpa.py#L104

Added line #L104 was not covered by tests
else:
print(f"SDPA count: {count}")
return count
2 changes: 1 addition & 1 deletion onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,7 +1830,7 @@ def report(self) -> None:
print(f"Rule: {rule}")
print(f"Best score: {matches[0].score()}")
for match in matches:
print(f"Status: {match.status}")
print(f"Status: {match.status.name}")
if match.status == MatchStatus.NO_MATCH:
print("Graph matching failed: " + match.match_result.reason)
node = match.match_result._failure_node
Expand Down
Loading