Add phi2/phi4 test cases for mha/gqa fusion #2409

Merged · 14 commits · Jun 23, 2025
4 changes: 3 additions & 1 deletion .lintrunner.toml
@@ -50,7 +50,9 @@ exclude_patterns = [
'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
'onnxscript/rewriter/ort_fusions/_smollm_*.py', # onnxscript code
'onnxscript/rewriter/ort_fusions/models/*.py', # onnxscript code
'onnxscript/rewriter/ort_fusions/models/_phi2lm.py', # onnxscript code
'onnxscript/rewriter/ort_fusions/models/_phi4lm.py', # onnxscript code
'onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py', # onnxscript code
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
25 changes: 20 additions & 5 deletions onnxscript/backend/onnx_export.py
@@ -68,9 +68,10 @@ def _get_const_repr(const_node):
rank = len(tensor_proto.dims)
if rank == 0:
array = onnx.numpy_helper.to_array(tensor_proto).reshape(1) # noqa: TID251
return repr(array[0])
return str(array[0])
if rank == 1 and tensor_proto.dims[0] < 5:
return repr(list(onnx.numpy_helper.to_array(tensor_proto))) # noqa: TID251
nparray = onnx.numpy_helper.to_array(tensor_proto) # noqa: TID251
return repr(nparray.tolist())
return None


@@ -138,6 +139,15 @@ def input_sig(inp: ValueInfoProto | str):
return f"{result}:"


def _translate_value_infos(value_infos: Sequence[ValueInfoProto]) -> str:
def _translate_value_info(value_info: ValueInfoProto) -> str:
return f"{_SINGLE_INDENT}'{_cleanup_variable_name(value_info.name)}': {_translate_type(value_info.type)},"

lines = [_translate_value_info(x) for x in value_infos]
lines_joined = "\n".join(lines)
return "{\n" + lines_joined + "\n}"


def _to_str(s):
if isinstance(s, bytes):
return s.decode("utf-8")
@@ -710,10 +720,13 @@ def add(line: str) -> None:
add(f"{indent}return {return_values}")
script = "\n".join(result)
if self.skipped_initializers:
return self._substitute_initializers(script, function_name)
value_infos = _translate_value_infos(graph.value_info)
return self._substitute_initializers(script, function_name, value_infos)
return script

def _substitute_initializers(self, script: str, script_function_name: str) -> str:
def _substitute_initializers(
self, script: str, script_function_name: str, value_infos: str
) -> str:
init_names = self.skipped_initializers.keys()
# Formal parameters representing initializers (single level indentation)
__ = _SINGLE_INDENT
@@ -733,12 +746,14 @@ def generate_rand(name: str, value: TensorProto) -> str:
# Actual parameter values for initializers (double level indentation)
indented_initializers_as_params = "\n".join(f"{__}{__}{x}," for x in init_names)
return f"""
value_infos = {value_infos}

def make_model(
{initializers_as_params}
):
{script}

{__}model = {script_function_name}.to_model_proto()
{__}model = {script_function_name}.to_model_proto(value_infos=value_infos)
{__}return model

def make_model_with_random_weights():
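With this change, a script exported for a model whose large initializers were skipped now also carries a module-level value_infos dict, threaded through make_model into to_model_proto, so intermediate value types survive the export round trip. A hypothetical rendering of the generated template (the initializer, graph, and value names below are made up for illustration, not taken from the PR):

value_infos = {
    'layer_0_output': FLOAT[1, 32, 256],
}

def make_model(
    layer_0_weight,
):
    # ... generated @script() function body (main_graph) goes here ...

    model = main_graph.to_model_proto(value_infos=value_infos)
    return model

def make_model_with_random_weights():
    layer_0_weight = numpy.random.rand(256, 256).astype(numpy.float32)
    return make_model(layer_0_weight)
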
21 changes: 21 additions & 0 deletions onnxscript/converter_test.py
@@ -191,6 +191,27 @@ def cast_add(x, y):
self.assertEqual(y_value_info.type.tensor_type.elem_type, onnx.TensorProto.INT64)
self.assertEqual(output_value_info.type.tensor_type.elem_type, onnx.TensorProto.FLOAT)

def test_set_value_info(self):
@script()
def double_square(x):
square = op.Mul(x, x)
return op.Add(square, square)

# Converting "double_square" to a ModelProto will generate an incomplete ModelProto,
# with input types undefined (since the script has no type annotations).
model = double_square.to_model_proto()
graph = model.graph
self.assertEqual(len(graph.value_info), 0)
model = double_square.to_model_proto(
io_types=FLOAT["N"], value_infos={"square": FLOAT["N"]}
)
graph = model.graph
self.assertEqual(len(graph.value_info), 1)
value_info = graph.value_info[0]
self.assertEqual(value_info.name, "square")
self.assertEqual(value_info.type.tensor_type.elem_type, onnx.TensorProto.FLOAT)
self.assertEqual(value_info.type.tensor_type.shape.dim[0].dim_param, "N")

def test_onnxfns1(self):
from tests.models import onnxfns1

22 changes: 20 additions & 2 deletions onnxscript/irbuilder.py
@@ -320,6 +320,7 @@ def to_model_proto(
io_types: Optional[ONNXType] = None,
input_types: Optional[Sequence[ONNXType]] = None,
output_types: Optional[Sequence[ONNXType]] = None,
value_infos: dict[str, ONNXType] | None = None,
**kwargs,
) -> onnx.ModelProto:
"""Converts this instance into a `onnx.ModelProto`.
@@ -333,12 +334,24 @@
are set to be of the corresponding type in this list.
output_types: When specified, all the outputs of the model
are set to be of the corresponding type in this list.
value_infos: A dictionary mapping intermediate variable names to ONNX types.
Used to set value_info for intermediate variables.
kwargs: Additional parameters given to function :func:`onnx.helper.make_model`.

Returns:
An instance of :class:`onnx.ModelProto`.
"""
graph, sub_functions = self.to_graph_and_functions(use_default_type=False)
value_infos = (
[
onnx.helper.make_value_info(name, type.to_type_proto())
for name, type in value_infos.items()
]
if value_infos
else None
)
graph, sub_functions = self.to_graph_and_functions(
use_default_type=False, value_infos=value_infos
)
if io_types is not None:
for input in graph.input:
if not input.HasField("type"):
@@ -394,14 +407,18 @@ def to_proto(f):
)

def to_graph_and_functions(
self, use_default_type: bool = True
self,
use_default_type: bool = True,
value_infos: Sequence[ValueInfoProto] | None = None,
) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]:
"""Converts this instance into a `onnx.GraphProto` and a map from
function-name to `onnx.FunctionProto`.

Args:
use_default_type: if True, the function uses a default type
for inputs and outputs that do not have a type
value_infos: a sequence of :class:`onnx.ValueInfoProto` to be added
to the graph.

Returns:
a pair of a :class:`onnx.GraphProto` and list of :class:`onnx.FunctionProto`
@@ -415,6 +432,7 @@ def to_graph_and_functions(
self.name,
[x.to_value_info(use_default_type) for x in self.inputs],
[y.to_value_info(use_default_type) for y in self.outputs],
value_info=value_infos,
)
return graph, called_functions

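Taken together with the converter_test addition above, usage looks like this minimal sketch (same names as the new test; FLOAT["N"] is a float tensor with symbolic dimension N):

from onnxscript import FLOAT, script
from onnxscript import opset18 as op

@script()
def double_square(x):
    square = op.Mul(x, x)
    return op.Add(square, square)

# Attach a type to the intermediate value "square"; to_model_proto converts
# the dict entries to ValueInfoProto and adds them to graph.value_info.
model = double_square.to_model_proto(
    io_types=FLOAT["N"],
    value_infos={"square": FLOAT["N"]},
)
assert model.graph.value_info[0].name == "square"
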
13 changes: 13 additions & 0 deletions onnxscript/rewriter/ort_fusions/gqa_test.py
@@ -16,8 +16,10 @@
import onnxscript.optimizer
from onnxscript import FLOAT, script
from onnxscript import opset18 as op
from onnxscript.rewriter.ort_fusions import optimize_for_ort
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
from onnxscript.rewriter.ort_fusions.models._phi4lm import phi4lm_test
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa

msft_op = onnxscript.values.Opset("com.microsoft", 1)
@@ -359,5 +361,16 @@ def test_fusion(self):
assert_allclose(outputs3, source_model_outputs)


class GQAFusionTest2(unittest.TestCase):
@unittest.skip("Needs too much memory.")
def test_phi4lm(self):
test_case = phi4lm_test()
model = test_case.get_onnx_model()
onnxscript.optimizer.optimize(model)
optimize_for_ort(model, debug=True)
gqa_nodes = [n for n in model.graph if n.op_type == "GQA"]
self.assertEqual(len(gqa_nodes), 2, "Expected 2 GQA nodes after fusion")


if __name__ == "__main__":
unittest.main()
18 changes: 18 additions & 0 deletions onnxscript/rewriter/ort_fusions/mha_test.py
@@ -10,6 +10,7 @@
import onnxscript.optimizer
import onnxscript.rewriter.ort_fusions._core as xformers
from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run
from onnxscript.rewriter.ort_fusions.models._phi2lm import phi2lm_test
from onnxscript.rewriter.ort_fusions.models._smollm_2 import smollm_test_2
from onnxscript.rewriter.ort_fusions.models._whisper_decoder import whisper_decoder_test
from onnxscript.rewriter.ort_fusions.models._whisper_encoder import whisper_encoder_test
@@ -96,6 +97,23 @@ def test_whisper_decoder(self):
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)

def test_phi2lm(self):
test_case = phi2lm_test()
model = test_case.get_onnx_model()
onnxscript.optimizer.optimize(model)
xformers.optimize_for_ort(model)
mha_nodes = [n for n in model.graph if n.op_type == "MultiHeadAttention"]
self.assertEqual(
len(mha_nodes),
1,
"Expected exactly one MultiHeadAttention node after optimization",
)
mha_node = mha_nodes[0]
# Check that the MHA node has past kv cache inputs
self.assertEqual(len(mha_node.inputs), 8, "Expected MHA node to have 8 inputs")
self.assertIsNotNone(mha_node.inputs[6], "Expected MHA node to have past key input")
self.assertIsNotNone(mha_node.inputs[7], "Expected MHA node to have past value input")


if __name__ == "__main__":
unittest.main()
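For context on the two past-KV assertions above: the com.microsoft MultiHeadAttention contrib op takes its inputs positionally, and (to the best of my reading of the ORT contrib-op spec, not something this PR states) the order is:

# 0: query, 1: key, 2: value, 3: bias,
# 4: key_padding_mask, 5: attention_bias, 6: past_key, 7: past_value
mha_node = mha_nodes[0]
past_key, past_value = mha_node.inputs[6], mha_node.inputs[7]

so checking that inputs[6] and inputs[7] are non-None verifies the fused node is actually wired to the KV cache rather than running stateless attention.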