@@ -61,40 +61,36 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
61
61
# If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
62
62
sqrt_scaling_factor = math .sqrt (expected_scaling_factor )
63
63
# Calculate the scaling factor for query
64
- if _ir_utils .get_singleton_value (query_scale ) is None :
64
+ if ( query_scale_value := _ir_utils .get_singleton_value (query_scale ) ) is None :
65
65
return check_result .fail (
66
66
"Query scale is not a scalar." ,
67
67
query_scale ,
68
68
)
69
- if not _ir_utils .is_singleton_value (query_scale , sqrt_scaling_factor , rtol = 1e-3 ):
70
- query_scale_value = _ir_utils .get_singleton_value (query_scale )
69
+ if not math .isclose (query_scale_value , sqrt_scaling_factor , rel_tol = 1e-3 ):
71
70
self ._scale = query_scale_value * query_scale_value
72
71
else :
73
72
self ._scale = expected_scaling_factor
74
73
# Ensure the scaling factor for key is the same as for query
75
- if _ir_utils .get_singleton_value (key_scale ) is None :
74
+ if ( key_scale_value := _ir_utils .get_singleton_value (key_scale ) ) is None :
76
75
return check_result .fail (
77
76
"Key scale is not a scalar." ,
78
77
key_scale ,
79
78
)
80
- if not _ir_utils .is_singleton_value (key_scale , sqrt_scaling_factor , rtol = 1e-3 ):
81
- if _ir_utils .get_singleton_value (query_scale ) != _ir_utils .get_singleton_value (
82
- key_scale
83
- ):
84
- return check_result .fail (
85
- "Query and key scales are not equal." ,
86
- query_scale ,
87
- )
79
+ if not math .isclose (query_scale_value , key_scale_value , rel_tol = 1e-3 ):
80
+ return check_result .fail (
81
+ "Query and key scales are not equal." ,
82
+ query_scale ,
83
+ )
88
84
else :
89
85
# Check if qk_scale is a scalar == expected_scaling_factor)
90
86
# If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
91
- if _ir_utils .get_singleton_value (qk_scale ) is None :
87
+ if ( qk_scale_value := _ir_utils .get_singleton_value (qk_scale ) ) is None :
92
88
return check_result .fail (
93
89
"QK scale is not a scalar." ,
94
90
qk_scale ,
95
91
)
96
- if not _ir_utils . is_singleton_value ( qk_scale , expected_scaling_factor , rtol = 1e-3 ):
97
- self ._scale = _ir_utils . get_singleton_value ( qk_scale )
92
+ if not math . isclose ( qk_scale_value , expected_scaling_factor , rel_tol = 1e-3 ):
93
+ self ._scale = qk_scale_value
98
94
else :
99
95
self ._scale = expected_scaling_factor
100
96
0 commit comments