@@ -172,36 +172,44 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
172
172
173
173
if no_match (query_BSD , ["B" , "S" , "D" ]):
174
174
return check_result .fail (
175
- f"Shape mismatch: { query_BSD } does not match expected dimensions ['B', 'S', 'D']"
175
+ f"Shape mismatch: { query_BSD } does not match expected dimensions ['B', 'S', 'D']" ,
176
+ query_BSD ,
176
177
)
177
178
if no_match (key_BSD , ["B" , "Skv" , "D" ]):
178
179
return check_result .fail (
179
- f"Shape mismatch: { key_BSD } does not match expected dimensions ['B', 'Skv', 'D']"
180
+ f"Shape mismatch: { key_BSD } does not match expected dimensions ['B', 'Skv', 'D']" ,
181
+ query_BSD ,
180
182
)
181
183
if no_match (value_BSD , ["B" , "Skv" , "D" ]):
182
184
return check_result .fail (
183
- f"Shape mismatch: { value_BSD } does not match expected dimensions ['B', 'Skv', 'D']"
185
+ f"Shape mismatch: { value_BSD } does not match expected dimensions ['B', 'Skv', 'D']" ,
186
+ value_BSD ,
184
187
)
185
188
186
189
if no_match (past_key , ["B" , "H" , "Spast" , "Dh" ]):
187
190
return check_result .fail (
188
- f"Shape mismatch: { past_key } does not match expected dimensions ['B', 'H', 'Spast', 'Dh']"
191
+ f"Shape mismatch: { past_key } does not match expected dimensions ['B', 'H', 'Spast', 'Dh']" ,
192
+ past_key ,
189
193
)
190
194
if no_match (past_value , ["B" , "H" , "Spast" , "Dv" ]):
191
195
return check_result .fail (
192
- f"Shape mismatch: { past_value } does not match expected dimensions ['B', 'H', 'Spast', 'Dv']"
196
+ f"Shape mismatch: { past_value } does not match expected dimensions ['B', 'H', 'Spast', 'Dv']" ,
197
+ past_value ,
193
198
)
194
199
if no_match (query_BSHDh , ["B" , "S" , "H" , "Dh" ]):
195
200
return check_result .fail (
196
- f"Shape mismatch: { query_BSHDh } does not match expected dimensions ['B', 'S', 'H', 'Dh']"
201
+ f"Shape mismatch: { query_BSHDh } does not match expected dimensions ['B', 'S', 'H', 'Dh']" ,
202
+ query_BSHDh ,
197
203
)
198
204
if no_match (key_BSHDh , ["B" , "S" , "H" , "Dh" ]):
199
205
return check_result .fail (
200
- f"Shape mismatch: { key_BSHDh } does not match expected dimensions ['B', 'S', 'H', 'Dh']"
206
+ f"Shape mismatch: { key_BSHDh } does not match expected dimensions ['B', 'S', 'H', 'Dh']" ,
207
+ query_BSHDh ,
201
208
)
202
209
if no_match (value_BSHDh , ["B" , "S" , "H" , "Dh" ]):
203
210
return check_result .fail (
204
- f"Shape mismatch: { value_BSHDh } does not match expected dimensions ['B', 'S', 'H', 'Dh']"
211
+ f"Shape mismatch: { value_BSHDh } does not match expected dimensions ['B', 'S', 'H', 'Dh']" ,
212
+ query_BSHDh ,
205
213
)
206
214
# TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St)
207
215
# But this also, unforunately, depends on ORT version.
0 commit comments