From 107ccceecb3e15137b26ac30fb2ca2e5185697df Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Mon, 24 Mar 2025 17:15:46 +0000 Subject: [PATCH 1/2] add sdpa rules and tests for unmasked patterns --- onnxscript/rewriter/ort_fusions/sdpa.py | 19 +++++++ onnxscript/rewriter/ort_fusions/sdpa_test.py | 54 +++++++++++++++++++- 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 3244bc45a8..dfaaf42738 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -73,6 +73,21 @@ def rewrite(self, op, query, key_transposed, value, mask, **_): return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion") +# Rules for SDPA without mask +unmasked_pre_div_sdpa_rule = SDPA.rule( + "unmasked_pre_div_sdpa", use_mask=False, pre_scale=True, use_mul=False +) +unmasked_pre_mul_sdpa_rule = SDPA.rule( + "unmasked_pre_mul_sdpa", use_mask=False, pre_scale=True, use_mul=True +) +unmasked_post_div_sdpa_rule = SDPA.rule( + "unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=False +) +unmasked_post_mul_sdpa_rule = SDPA.rule( + "unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=True +) + +# Rules for SDPA with mask masked_pre_div_sdpa_rule = SDPA.rule( "masked_pre_div_sdpa", use_mask=True, pre_scale=True, use_mul=False ) @@ -88,6 +103,10 @@ def rewrite(self, op, query, key_transposed, value, mask, **_): sdpa_rules = pattern.RewriteRuleSet( [ + unmasked_pre_mul_sdpa_rule, + unmasked_post_div_sdpa_rule, + unmasked_post_mul_sdpa_rule, + unmasked_pre_div_sdpa_rule, masked_pre_mul_sdpa_rule, masked_post_div_sdpa_rule, masked_post_mul_sdpa_rule, diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index 1ffb3fa55c..0c220bdbd5 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -9,7 +9,7 @@ import 
unittest import numpy -from parameterized import parameterized +import parameterized import onnxscript.ir as ir import onnxscript.optimizer @@ -28,6 +28,52 @@ SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR) +@script() +def _unmasked_pre_div_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=SQRT_SCALE_FACTOR) + scaled_query = op.Div(query, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + + +@script() +def _unmasked_pre_mul_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR) + scaled_query = op.Mul(query, multiplier) + scaled_key = op.Mul(key_transposed, multiplier) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + + +@script() +def _unmasked_post_div_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Div(attn_score, divisor) + attn_weight = op.Softmax(scaled_attn_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + + +@script() +def _unmasked_post_mul_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=MUL_SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Mul(attn_score, multiplier) + attn_weight = op.Softmax(scaled_attn_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + + @script() def _masked_pre_div_sdpa_script(query, key, value, mask): 
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) @@ -105,8 +151,12 @@ def get_ort_inputs(self): class TestSDPAFusion(unittest.TestCase): - @parameterized.expand( + @parameterized.parameterized.expand( [ + ("unmasked_pre_div", _unmasked_pre_div_sdpa_script), + ("unmasked_pre_mul", _unmasked_pre_mul_sdpa_script), + ("unmasked_post_div", _unmasked_post_div_sdpa_script), + ("unmasked_post_mul", _unmasked_post_mul_sdpa_script), ("pre_div", _masked_pre_div_sdpa_script), ("pre_mul", _masked_pre_mul_sdpa_script), ("post_div", _masked_post_div_sdpa_script), From f4e352281b29645db7e762c056f8781083813479 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Tue, 25 Mar 2025 00:00:24 +0000 Subject: [PATCH 2/2] typo: fix --- onnxscript/rewriter/ort_fusions/sdpa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index dfaaf42738..6d983b0a6c 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -84,7 +84,7 @@ def rewrite(self, op, query, key_transposed, value, mask, **_): "unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=False ) unmasked_post_mul_sdpa_rule = SDPA.rule( - "unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=True + "unmasked_post_mul_sdpa", use_mask=False, pre_scale=False, use_mul=True ) # Rules for SDPA with mask @@ -98,7 +98,7 @@ def rewrite(self, op, query, key_transposed, value, mask, **_): "masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=False ) masked_post_mul_sdpa_rule = SDPA.rule( - "masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=True + "masked_post_mul_sdpa", use_mask=True, pre_scale=False, use_mul=True ) sdpa_rules = pattern.RewriteRuleSet(