diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 6d983b0a6c..788fffe046 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -70,7 +70,10 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, return True def rewrite(self, op, query, key_transposed, value, mask, **_): - return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion") + if self._use_mask: + return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion") + else: + return op.SDPA(query, key_transposed, value, _domain="ai.onnxruntime.fusion") # Rules for SDPA without mask diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index 0c220bdbd5..19329e75f6 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -29,7 +29,7 @@ @script() -def _unmasked_pre_div_sdpa_script(query, key, value, mask): +def _unmasked_pre_div_sdpa_script(query, key, value): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) divisor = op.Constant(value_float=SQRT_SCALE_FACTOR) scaled_query = op.Div(query, divisor) @@ -41,7 +41,7 @@ def _unmasked_pre_div_sdpa_script(query, key, value, mask): @script() -def _unmasked_pre_mul_sdpa_script(query, key, value, mask): +def _unmasked_pre_mul_sdpa_script(query, key, value): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR) scaled_query = op.Mul(query, multiplier) @@ -53,7 +53,7 @@ def _unmasked_pre_mul_sdpa_script(query, key, value, mask): @script() -def _unmasked_post_div_sdpa_script(query, key, value, mask): +def _unmasked_post_div_sdpa_script(query, key, value): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) divisor = op.Constant(value_float=SCALE_FACTOR) attn_score = op.MatMul(query, key_transposed) @@ -64,7 +64,7 @@ def _unmasked_post_div_sdpa_script(query, key, value, mask): @script() -def _unmasked_post_mul_sdpa_script(query, key, value, mask): +def _unmasked_post_mul_sdpa_script(query, key, value): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) multiplier = op.Constant(value_float=MUL_SCALE_FACTOR) attn_score = op.MatMul(query, key_transposed)