Skip to content

Commit 3010dac

Browse files
add failure nodes to some checks
1 parent d841ea2 commit 3010dac

File tree

4 files changed

+27
-19
lines changed

4 files changed

+27
-19
lines changed

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def check(
103103
# TODO(rama): handle redundant reshape/expand
104104
if self._const_freqs:
105105
if (freqs.const_value is None) or not _ir_utils.has_rank(freqs, 3):
106-
return check_result.fail("freqs is not a constant or not 3D.")
106+
return check_result.fail("freqs is not a constant or not 3D.", freqs)
107107
else:
108108
return check_result
109109
if (
@@ -113,14 +113,14 @@ def check(
113113
):
114114
pass
115115
else:
116-
return check_result.fail("position_ids is not a 1D or 2D tensor.")
116+
return check_result.fail("position_ids is not a 1D or 2D tensor.", position_ids)
117117
if not _ir_utils.has_rank(inv_freq, 3):
118-
return check_result.fail("inv_freq is not 3D.")
118+
return check_result.fail("inv_freq is not 3D.", inv_freq)
119119
inv_freq_shape = inv_freq.shape
120120
if inv_freq.const_value is None: # TODO: should this be inv_freq_shape?
121-
return check_result.fail("inv_freq is not a constant.")
121+
return check_result.fail("inv_freq is not a constant.", inv_freq)
122122
if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1:
123-
return check_result.fail("inv_freq is not of shape [1, ., 1].")
123+
return check_result.fail("inv_freq is not of shape [1, ., 1].", inv_freq)
124124
return check_result
125125

126126
def rewrite(

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,36 +172,44 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
172172

173173
if no_match(query_BSD, ["B", "S", "D"]):
174174
return check_result.fail(
175-
f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']"
175+
f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']",
176+
query_BSD,
176177
)
177178
if no_match(key_BSD, ["B", "Skv", "D"]):
178179
return check_result.fail(
179-
f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']"
180+
f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
181+
query_BSD,
180182
)
181183
if no_match(value_BSD, ["B", "Skv", "D"]):
182184
return check_result.fail(
183-
f"Shape mismatch: {value_BSD} does not match expected dimensions ['B', 'Skv', 'D']"
185+
f"Shape mismatch: {value_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
186+
value_BSD,
184187
)
185188

186189
if no_match(past_key, ["B", "H", "Spast", "Dh"]):
187190
return check_result.fail(
188-
f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']"
191+
f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']",
192+
past_key,
189193
)
190194
if no_match(past_value, ["B", "H", "Spast", "Dv"]):
191195
return check_result.fail(
192-
f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']"
196+
f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']",
197+
past_value,
193198
)
194199
if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
195200
return check_result.fail(
196-
f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']"
201+
f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
202+
query_BSHDh,
197203
)
198204
if no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
199205
return check_result.fail(
200-
f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']"
206+
f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
207+
query_BSHDh,
201208
)
202209
if no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
203210
return check_result.fail(
204-
f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']"
211+
f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
212+
query_BSHDh,
205213
)
206214
# TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St)
207215
# But this also, unforunately, depends on ORT version.

onnxscript/rewriter/ort_fusions/rms_normalization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.M
5858
# epsilon must be a scalar
5959
epsilon_value = _ir_utils.get_singleton_value(epsilon)
6060
if not isinstance(epsilon_value, float): # TODO: support other types
61-
return check_result.fail("Epsilon is not a float value.")
61+
return check_result.fail("Epsilon is not a float value.", epsilon)
6262
# input and output must be same dtype
6363
if x.dtype not in float_types:
64-
return check_result.fail("Input is not a float type.")
64+
return check_result.fail("Input is not a float type.", x)
6565
if scale.dtype not in float_types:
66-
return check_result.fail("Scale is not a float type.")
66+
return check_result.fail("Scale is not a float type.", scale)
6767
stash_dtype = compute_dtype.value if self._cast_input else x.dtype
6868
if stash_dtype not in fp_float_types:
6969
return check_result.fail("Normalization precision is not a float or double type.")

onnxscript/rewriter/ort_fusions/rotary_embedding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:
3434
check_result = pattern.MatchResult()
3535
# x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
3636
if x is None or x.shape is None or len(x.shape) != 4:
37-
return check_result.fail("Input is not a 4D tensor.")
37+
return check_result.fail("Input is not a 4D tensor.", x)
3838
if not isinstance(x.shape[1], int):
39-
return check_result.fail("Input dimension 1 is not an integer.")
39+
return check_result.fail("Input dimension 1 is not an integer.", x)
4040
head_size = x.shape[3]
4141
if not isinstance(head_size, int):
42-
return check_result.fail("Head size is not an integer.")
42+
return check_result.fail("Head size is not an integer.", x)
4343
half_head_size = head_size // 2
4444

4545
# Check that x is being split into two equal halves of size half_head_size

0 commit comments

Comments
 (0)