From 7c88db17922c8d6b9dae0fcbd6ae2297554fa15b Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Fri, 21 Mar 2025 20:32:45 +0000 Subject: [PATCH 1/4] Extend sdpa tests --- onnxscript/rewriter/ort_fusions/sdpa.py | 2 +- onnxscript/rewriter/ort_fusions/sdpa_test.py | 78 +++++++++++++++++--- 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 70b208507a..3244bc45a8 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -74,7 +74,7 @@ def rewrite(self, op, query, key_transposed, value, mask, **_): masked_pre_div_sdpa_rule = SDPA.rule( - "masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=False + "masked_pre_div_sdpa", use_mask=True, pre_scale=True, use_mul=False ) masked_pre_mul_sdpa_rule = SDPA.rule( "masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=True diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index b3f551c638..ef687f1c28 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -22,7 +22,9 @@ S = 8 # sequence length H = 128 # head size SCALE_FACTOR = math.sqrt(H) +MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR) +MUL_SQRT_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR) @script() @@ -38,37 +40,77 @@ def _masked_pre_div_sdpa_script(query, key, value, mask): return attn_output -class _MaskedPreDivSDPATestCase: +@script() +def _masked_pre_mul_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=MUL_SQRT_SCALE_FACTOR) + scaled_query = op.Mul(query, multiplier) + scaled_key = op.Mul(key_transposed, multiplier) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + + +@script() +def _masked_post_div_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Div(attn_score, divisor) + masked_attn_score = op.Add(scaled_attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + + +@script() +def _masked_post_mul_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=MUL_SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Mul(attn_score, multiplier) + masked_attn_score = op.Add(scaled_attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + + +class SDPATestCase: + def __init__(self, script_func): + self.script_func = script_func + self._onnx_model = None + self._ort_inputs = None + def get_onnx_model(self): - if not hasattr(self, "_onnx_model"): + if self._onnx_model is None: qkv_type = FLOAT[B, N, S, H] mask_type = FLOAT[B, N, S, S] - model_proto = _masked_pre_div_sdpa_script.to_model_proto( + model_proto = self.script_func.to_model_proto( input_types=[qkv_type, qkv_type, qkv_type, mask_type], output_types=[qkv_type] ) - model = ir.serde.deserialize_model(model_proto) - self._onnx_model = model + self._onnx_model = ir.serde.deserialize_model(model_proto) return self._onnx_model def get_ort_inputs(self): - if not hasattr(self, "_ort_inputs"): - inputs = { + if self._ort_inputs is None: + self._ort_inputs = { "query": numpy.random.rand(B, N, S, H).astype(numpy.float32), "key": numpy.random.rand(B, N, S, H).astype(numpy.float32), "value": numpy.random.rand(B, N, S, H).astype(numpy.float32), "mask": numpy.random.rand(B, N, S, S).astype(numpy.float32), } - self._ort_inputs = inputs return self._ort_inputs class TestSDPAFusion(unittest.TestCase): - def test_sdpa_fusion(self): - test = _MaskedPreDivSDPATestCase() - model = test.get_onnx_model() + def _test_sdpa_fusion(self, script_func): + test_case = SDPATestCase(script_func) + model = test_case.get_onnx_model() onnxscript.optimizer.optimize(model) - # inputs = test.get_ort_inputs() + # inputs = test_case.get_ort_inputs() # original_outputs = ort_run("original", model, inputs) count = fuse_sdpa(model) @@ -80,3 +122,15 @@ def test_sdpa_fusion(self): # new_outputs = ort_run("optimized", model, inputs) # assert_allclose(new_outputs, original_outputs) + + def test_sdpa_fusion_pre_div(self): + self._test_sdpa_fusion(_masked_pre_div_sdpa_script) + + def test_sdpa_fusion_pre_mul(self): + self._test_sdpa_fusion(_masked_pre_mul_sdpa_script) + + def test_sdpa_fusion_post_div(self): + self._test_sdpa_fusion(_masked_post_div_sdpa_script) + + def test_sdpa_fusion_post_mul(self): + self._test_sdpa_fusion(_masked_post_mul_sdpa_script) From e44d39dbd3c547c804761ce034928cb292dad0f8 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Fri, 21 Mar 2025 20:47:22 +0000 Subject: [PATCH 2/4] refactor internal attrs --- onnxscript/rewriter/ort_fusions/sdpa_test.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index ef687f1c28..f2757d2e26 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -80,11 +80,9 @@ def _masked_post_mul_sdpa_script(query, key, value, mask): class SDPATestCase: def __init__(self, script_func): self.script_func = script_func - self._onnx_model = None - self._ort_inputs = None def get_onnx_model(self): - if self._onnx_model is None: + if not hasattr(self, "_onnx_model"): qkv_type = FLOAT[B, N, S, H] mask_type = FLOAT[B, N, S, S] model_proto = self.script_func.to_model_proto( @@ -94,13 +92,14 @@ def get_onnx_model(self): return self._onnx_model def get_ort_inputs(self): - if self._ort_inputs is None: - self._ort_inputs = { + if not hasattr(self, "_ort_inputs"): + inputs = { "query": numpy.random.rand(B, N, S, H).astype(numpy.float32), "key": numpy.random.rand(B, N, S, H).astype(numpy.float32), "value": numpy.random.rand(B, N, S, H).astype(numpy.float32), "mask": numpy.random.rand(B, N, S, S).astype(numpy.float32), } + self._ort_inputs = inputs return self._ort_inputs From b107f519f47f4e70fdd3dd57a9457d40951cd455 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Fri, 21 Mar 2025 22:07:15 +0000 Subject: [PATCH 3/4] use parameterized --- onnxscript/rewriter/ort_fusions/sdpa_test.py | 23 +++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index f2757d2e26..53bf933e53 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -9,6 +9,7 @@ import unittest import numpy +from parameterized import parameterized import onnxscript.ir as ir import onnxscript.optimizer @@ -104,7 +105,15 @@ def get_ort_inputs(self): class TestSDPAFusion(unittest.TestCase): - def _test_sdpa_fusion(self, script_func): + @parameterized.expand( + [ + ("pre_div", _masked_pre_div_sdpa_script), + ("pre_mul", _masked_pre_mul_sdpa_script), + ("post_div", _masked_post_div_sdpa_script), + ("post_mul", _masked_post_mul_sdpa_script), + ] + ) + def test_sdpa_fusion(self, name, script_func): test_case = SDPATestCase(script_func) model = test_case.get_onnx_model() onnxscript.optimizer.optimize(model) @@ -121,15 +130,3 @@ def _test_sdpa_fusion(self, script_func): # new_outputs = ort_run("optimized", model, inputs) # assert_allclose(new_outputs, original_outputs) - - def test_sdpa_fusion_pre_div(self): - self._test_sdpa_fusion(_masked_pre_div_sdpa_script) - - def test_sdpa_fusion_pre_mul(self): - self._test_sdpa_fusion(_masked_pre_mul_sdpa_script) - - def test_sdpa_fusion_post_div(self): - self._test_sdpa_fusion(_masked_post_div_sdpa_script) - - def test_sdpa_fusion_post_mul(self): - self._test_sdpa_fusion(_masked_post_mul_sdpa_script) From 2b7a1783f837d5221499337735fcadaeef935f1e Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Fri, 21 Mar 2025 22:25:33 +0000 Subject: [PATCH 4/4] nit changes --- onnxscript/rewriter/ort_fusions/sdpa_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index 53bf933e53..1ffb3fa55c 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -25,7 +25,7 @@ SCALE_FACTOR = math.sqrt(H) MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR) -MUL_SQRT_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR) +SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR) @script() @@ -44,7 +44,7 @@ def _masked_pre_div_sdpa_script(query, key, value, mask): @script() def _masked_pre_mul_sdpa_script(query, key, value, mask): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - multiplier = op.Constant(value_float=MUL_SQRT_SCALE_FACTOR) + multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR) scaled_query = op.Mul(query, multiplier) scaled_key = op.Mul(key_transposed, multiplier) attn_score = op.MatMul(scaled_query, scaled_key)