Skip to content

Commit df48c32

Browse files
check fixes
1 parent fb97d2f commit df48c32

File tree

1 file changed

+11
-15
lines changed
  • onnxscript/rewriter/ort_fusions

1 file changed

+11
-15
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,40 +61,36 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
6161
# If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
6262
sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
6363
# Calculate the scaling factor for query
64-
if _ir_utils.get_singleton_value(query_scale) is None:
64+
if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None:
6565
return check_result.fail(
6666
"Query scale is not a scalar.",
6767
query_scale,
6868
)
69-
if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3):
70-
query_scale_value = _ir_utils.get_singleton_value(query_scale)
69+
if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3):
7170
self._scale = query_scale_value * query_scale_value
7271
else:
7372
self._scale = expected_scaling_factor
7473
# Ensure the scaling factor for key is the same as for query
75-
if _ir_utils.get_singleton_value(key_scale) is None:
74+
if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None:
7675
return check_result.fail(
7776
"Key scale is not a scalar.",
7877
key_scale,
7978
)
80-
if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3):
81-
if _ir_utils.get_singleton_value(query_scale) != _ir_utils.get_singleton_value(
82-
key_scale
83-
):
84-
return check_result.fail(
85-
"Query and key scales are not equal.",
86-
query_scale,
87-
)
79+
if not math.isclose(query_scale_value, key_scale_value, rel_tol=1e-3):
80+
return check_result.fail(
81+
"Query and key scales are not equal.",
82+
query_scale,
83+
)
8884
else:
8985
# Check if qk_scale is a scalar == expected_scaling_factor)
9086
# If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
91-
if _ir_utils.get_singleton_value(qk_scale) is None:
87+
if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None:
9288
return check_result.fail(
9389
"QK scale is not a scalar.",
9490
qk_scale,
9591
)
96-
if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3):
97-
self._scale = _ir_utils.get_singleton_value(qk_scale)
92+
if not math.isclose(qk_scale_value, expected_scaling_factor, rel_tol=1e-3):
93+
self._scale = qk_scale_value
9894
else:
9995
self._scale = expected_scaling_factor
10096

0 commit comments

Comments
 (0)