Skip to content

Commit b537ff3

Browse files
reorder checks
1 parent e62345a commit b537ff3

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,6 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
6666
"Query scale is not a scalar.",
6767
query_scale,
6868
)
69-
if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3):
70-
self._scale = query_scale_value * query_scale_value
71-
else:
72-
self._scale = expected_scaling_factor
7369
# Ensure the scaling factor for key is the same as for query
7470
if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None:
7571
return check_result.fail(
@@ -81,6 +77,11 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
8177
"Query and key scales are not equal.",
8278
query_scale,
8379
)
80+
if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3):
81+
self._scale = query_scale_value * query_scale_value
82+
else:
83+
# Pass no scaling factor to SDPA, SDPA will use the default scaling factor
84+
self._scale = None
8485
else:
8586
# Check if qk_scale is a scalar == expected_scaling_factor)
8687
# If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
@@ -92,7 +93,8 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
9293
if not math.isclose(qk_scale_value, expected_scaling_factor, rel_tol=1e-3):
9394
self._scale = qk_scale_value
9495
else:
95-
self._scale = expected_scaling_factor
96+
# Pass no scaling factor to SDPA, SDPA will use the default scaling factor
97+
self._scale = None
9698

9799
# check ranks/shapes
98100

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -303,20 +303,20 @@ def test_sdpa_fusion(self, name, script_func):
303303
# Ensure that the scale of the SDPA node is set correctly
304304
sdpa_node = next(n for n in model.graph if n.op_type == "SDPA")
305305
self.assertEqual(sdpa_node.op_type, "SDPA")
306-
self.assertIsNotNone(sdpa_node.attributes.get("scale"))
307306

308-
scale_factor = sdpa_node.attributes["scale"].value
309-
self.assertIsNotNone(scale_factor)
310307
if "custom" in name:
308+
self.assertIsNotNone(sdpa_node.attributes.get("scale"))
309+
scale_factor = sdpa_node.attributes["scale"].value
310+
self.assertIsNotNone(scale_factor)
311311
if "pre" in name:
312312
self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR)
313313
elif "post" in name:
314314
self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR)
315315
else:
316-
if "div" in name:
317-
self.assertEqual(scale_factor, SCALE_FACTOR)
318-
elif "mul" in name:
319-
self.assertEqual(scale_factor, MUL_SCALE_FACTOR)
316+
# These tests are for the default scaling factors, no scale factor is passed to SDPA
317+
# pattern rewriting check functions should be sufficient to check if expected value
318+
# of scale_factor (is =default_scaling_factor)
319+
self.assertIsNone(sdpa_node.attributes.get("scale"))
320320

321321
# new_outputs = ort_run("optimized", model, inputs)
322322
# assert_allclose(new_outputs, original_outputs)

0 commit comments

Comments
 (0)