@@ -66,10 +66,6 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
66
66
"Query scale is not a scalar." ,
67
67
query_scale ,
68
68
)
69
- if not math .isclose (query_scale_value , sqrt_scaling_factor , rel_tol = 1e-3 ):
70
- self ._scale = query_scale_value * query_scale_value
71
- else :
72
- self ._scale = expected_scaling_factor
73
69
# Ensure the scaling factor for key is the same as for query
74
70
if (key_scale_value := _ir_utils .get_singleton_value (key_scale )) is None :
75
71
return check_result .fail (
@@ -81,6 +77,11 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
81
77
"Query and key scales are not equal." ,
82
78
query_scale ,
83
79
)
80
+ if not math .isclose (query_scale_value , sqrt_scaling_factor , rel_tol = 1e-3 ):
81
+ self ._scale = query_scale_value * query_scale_value
82
+ else :
83
+ # Pass no scaling factor to SDPA; SDPA will then use the default scaling factor
84
+ self ._scale = None
84
85
else :
85
86
# Check if qk_scale is a scalar (== expected_scaling_factor)
86
87
# If it is a scalar but != expected_scaling_factor, a custom scale is being used
@@ -92,7 +93,8 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
92
93
if not math .isclose (qk_scale_value , expected_scaling_factor , rel_tol = 1e-3 ):
93
94
self ._scale = qk_scale_value
94
95
else :
95
- self ._scale = expected_scaling_factor
96
+ # Pass no scaling factor to SDPA; SDPA will then use the default scaling factor
97
+ self ._scale = None
96
98
97
99
# check ranks/shapes
98
100
0 commit comments