Skip to content

Commit f93eb58

Browse files
authored
GQA Fusion (#2161)
Introduce GQA Fusion (for Phi models).
1 parent 8b1f814 commit f93eb58

File tree

4 files changed

+573
-110
lines changed

4 files changed

+573
-110
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -405,17 +405,13 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
405405
shape = _get_input(node, 1)
406406
if input is None or shape is None:
407407
return None
408+
408409
input_shape = input.shape
409-
if input_shape is None:
410-
return None
411-
# input_shape_dims = list(input_shape.dims)
412-
# if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in input_shape_dims):
413-
# return None
414410
shape_value = state.get_shape_value(shape)
415-
if shape_value is None:
411+
412+
if shape_value is None or input_shape is None:
416413
return None
417-
# target_shape_dims = list(shape_value.dims)
418-
# if input_shape_dims == target_shape_dims:
414+
419415
# No need to check for special values like -1, 0, etc. here
420416
if _same_shape(input_shape, shape_value):
421417
return op.Identity(input)

onnxscript/rewriter/ort_fusions/_test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,8 @@ def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4):
3939
np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
4040
np.testing.assert_allclose(baseline_output, optimized_output, rtol=rtol, atol=atol)
4141
except AssertionError as e:
42+
diff_mask = ~np.isclose(baseline_output, optimized_output, rtol=rtol, atol=atol)
43+
diff = np.where(diff_mask, "X", " ")
44+
print(diff)
4245
print(f"Failed for output {i} with rtol={rtol} and atol={atol}\n{e}")
4346
raise

0 commit comments

Comments
 (0)