Skip to content

Commit e56c5bb

Browse files
Add phi2/phi4 test cases for mha/gqa fusion (#2409)
Onnxscript extensions: * Extend onnxscript's toModelProto to allow specification of valueinfos in generated model. * Extend the onnx-proto to onnxscript converter to serialize valueinfos in the model, so that it can be used in the generated script. Fusion test cases: * Add Phi2 (1 layer) and Phi4 (2 layer) test cases for MHA and GQA fusion respectively --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Justin Chu <[email protected]>
1 parent 92decb4 commit e56c5bb

File tree

8 files changed

+1350
-8
lines changed

8 files changed

+1350
-8
lines changed

.lintrunner.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ exclude_patterns = [
5050
'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
5151
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
5252
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
53-
'onnxscript/rewriter/ort_fusions/_smollm_*.py', # onnxscript code
53+
'onnxscript/rewriter/ort_fusions/models/*.py', # onnxscript code
54+
'onnxscript/rewriter/ort_fusions/models/_phi2lm.py', # onnxscript code
55+
'onnxscript/rewriter/ort_fusions/models/_phi4lm.py', # onnxscript code
5456
'onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py', # onnxscript code
5557
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
5658
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME

onnxscript/backend/onnx_export.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ def _get_const_repr(const_node):
6868
rank = len(tensor_proto.dims)
6969
if rank == 0:
7070
array = onnx.numpy_helper.to_array(tensor_proto).reshape(1) # noqa: TID251
71-
return repr(array[0])
71+
return str(array[0])
7272
if rank == 1 and tensor_proto.dims[0] < 5:
73-
return repr(list(onnx.numpy_helper.to_array(tensor_proto))) # noqa: TID251
73+
nparray = onnx.numpy_helper.to_array(tensor_proto) # noqa: TID251
74+
return repr(nparray.tolist())
7475
return None
7576

7677

@@ -138,6 +139,15 @@ def input_sig(inp: ValueInfoProto | str):
138139
return f"{result}:"
139140

140141

142+
def _translate_value_infos(value_infos: Sequence[ValueInfoProto]) -> str:
143+
def _translate_value_info(value_info: ValueInfoProto) -> str:
144+
return f"{_SINGLE_INDENT}'{_cleanup_variable_name(value_info.name)}': {_translate_type(value_info.type)},"
145+
146+
lines = [_translate_value_info(x) for x in value_infos]
147+
lines_joined = "\n".join(lines)
148+
return "{\n" + lines_joined + "\n}"
149+
150+
141151
def _to_str(s):
142152
if isinstance(s, bytes):
143153
return s.decode("utf-8")
@@ -710,10 +720,13 @@ def add(line: str) -> None:
710720
add(f"{indent}return {return_values}")
711721
script = "\n".join(result)
712722
if self.skipped_initializers:
713-
return self._substitute_initializers(script, function_name)
723+
value_infos = _translate_value_infos(graph.value_info)
724+
return self._substitute_initializers(script, function_name, value_infos)
714725
return script
715726

716-
def _substitute_initializers(self, script: str, script_function_name: str) -> str:
727+
def _substitute_initializers(
728+
self, script: str, script_function_name: str, value_infos: str
729+
) -> str:
717730
init_names = self.skipped_initializers.keys()
718731
# Formal parameters representing initializers (single level indentation)
719732
__ = _SINGLE_INDENT
@@ -733,12 +746,14 @@ def generate_rand(name: str, value: TensorProto) -> str:
733746
# Actual parameter values for initializers (double level indentation)
734747
indented_initializers_as_params = "\n".join(f"{__}{__}{x}," for x in init_names)
735748
return f"""
749+
value_infos = {value_infos}
750+
736751
def make_model(
737752
{initializers_as_params}
738753
):
739754
{script}
740755
741-
{__}model = {script_function_name}.to_model_proto()
756+
{__}model = {script_function_name}.to_model_proto(value_infos=value_infos)
742757
{__}return model
743758
744759
def make_model_with_random_weights():

onnxscript/converter_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,27 @@ def cast_add(x, y):
191191
self.assertEqual(y_value_info.type.tensor_type.elem_type, onnx.TensorProto.INT64)
192192
self.assertEqual(output_value_info.type.tensor_type.elem_type, onnx.TensorProto.FLOAT)
193193

194+
def test_set_value_info(self):
195+
@script()
196+
def double_square(x):
197+
square = op.Mul(x, x)
198+
return op.Add(square, square)
199+
200+
# Converting "double_square" to a ModelProto will generate an incomplete ModelProto,
201+
# with input-types undefined (since the script has no type-annotation).
202+
model = double_square.to_model_proto()
203+
graph = model.graph
204+
self.assertEqual(len(graph.value_info), 0)
205+
model = double_square.to_model_proto(
206+
io_types=FLOAT["N"], value_infos={"square": FLOAT["N"]}
207+
)
208+
graph = model.graph
209+
self.assertEqual(len(graph.value_info), 1)
210+
value_info = graph.value_info[0]
211+
self.assertEqual(value_info.name, "square")
212+
self.assertEqual(value_info.type.tensor_type.elem_type, onnx.TensorProto.FLOAT)
213+
self.assertEqual(value_info.type.tensor_type.shape.dim[0].dim_param, "N")
214+
194215
def test_onnxfns1(self):
195216
from tests.models import onnxfns1
196217

onnxscript/irbuilder.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def to_model_proto(
320320
io_types: Optional[ONNXType] = None,
321321
input_types: Optional[Sequence[ONNXType]] = None,
322322
output_types: Optional[Sequence[ONNXType]] = None,
323+
value_infos: dict[str, ONNXType] | None = None,
323324
**kwargs,
324325
) -> onnx.ModelProto:
325326
"""Converts this instance into a `onnx.ModelProto`.
@@ -333,12 +334,24 @@ def to_model_proto(
333334
are set to be of the corresponding type in this list.
334335
output_types: When specified, all the outputs of the model
335336
are set to be of the corresponding type in this list.
337+
value_infos: A dictionary mapping intermediate variable names to ONNX types.
338+
Used to set value_info for intermediate variables.
336339
kwargs: Additional parameters given to function :func:`onnx.helper.make_model`.
337340
338341
Returns:
339342
An instance of :class:`onnx.ModelProto`.
340343
"""
341-
graph, sub_functions = self.to_graph_and_functions(use_default_type=False)
344+
value_infos = (
345+
[
346+
onnx.helper.make_value_info(name, type.to_type_proto())
347+
for name, type in value_infos.items()
348+
]
349+
if value_infos
350+
else None
351+
)
352+
graph, sub_functions = self.to_graph_and_functions(
353+
use_default_type=False, value_infos=value_infos
354+
)
342355
if io_types is not None:
343356
for input in graph.input:
344357
if not input.HasField("type"):
@@ -394,14 +407,18 @@ def to_proto(f):
394407
)
395408

396409
def to_graph_and_functions(
397-
self, use_default_type: bool = True
410+
self,
411+
use_default_type: bool = True,
412+
value_infos: Sequence[ValueInfoProto] | None = None,
398413
) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]:
399414
"""Converts this instance into a `onnx.GraphProto` and a map from
400415
function-name to `onnx.FunctionProto`.
401416
402417
Args:
403418
use_default_type: if True, the function uses a default type
404419
for inputs and outputs that do not have a type
420+
value_infos: a sequence of :class:`onnx.ValueInfoProto` to be added
421+
to the graph.
405422
406423
Returns:
407424
a pair of a :class:`onnx.GraphProto` and list of :class:`onnx.FunctionProto`
@@ -415,6 +432,7 @@ def to_graph_and_functions(
415432
self.name,
416433
[x.to_value_info(use_default_type) for x in self.inputs],
417434
[y.to_value_info(use_default_type) for y in self.outputs],
435+
value_info=value_infos,
418436
)
419437
return graph, called_functions
420438

onnxscript/rewriter/ort_fusions/gqa_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
import onnxscript.optimizer
1717
from onnxscript import FLOAT, script
1818
from onnxscript import opset18 as op
19+
from onnxscript.rewriter.ort_fusions import optimize_for_ort
1920
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose
2021
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
22+
from onnxscript.rewriter.ort_fusions.models._phi4lm import phi4lm_test
2123
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
2224

2325
msft_op = onnxscript.values.Opset("com.microsoft", 1)
@@ -359,5 +361,16 @@ def test_fusion(self):
359361
assert_allclose(outputs3, source_model_outputs)
360362

361363

364+
class GQAFusionTest2(unittest.TestCase):
365+
@unittest.skip("Needs too much memory.")
366+
def test_phi4lm(self):
367+
test_case = phi4lm_test()
368+
model = test_case.get_onnx_model()
369+
onnxscript.optimizer.optimize(model)
370+
optimize_for_ort(model, debug=True)
371+
gqa_nodes = [n for n in model.graph if n.op_type == "GQA"]
372+
self.assertEqual(len(gqa_nodes), 2, "Expected 2 GQA nodes after fusion")
373+
374+
362375
if __name__ == "__main__":
363376
unittest.main()

onnxscript/rewriter/ort_fusions/mha_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import onnxscript.optimizer
1111
import onnxscript.rewriter.ort_fusions._core as xformers
1212
from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run
13+
from onnxscript.rewriter.ort_fusions.models._phi2lm import phi2lm_test
1314
from onnxscript.rewriter.ort_fusions.models._smollm_2 import smollm_test_2
1415
from onnxscript.rewriter.ort_fusions.models._whisper_decoder import whisper_decoder_test
1516
from onnxscript.rewriter.ort_fusions.models._whisper_encoder import whisper_encoder_test
@@ -96,6 +97,23 @@ def test_whisper_decoder(self):
9697
new_outputs = ort_run("optimized", model, inputs)
9798
assert_allclose(new_outputs, original_outputs)
9899

100+
def test_phi2lm(self):
101+
test_case = phi2lm_test()
102+
model = test_case.get_onnx_model()
103+
onnxscript.optimizer.optimize(model)
104+
xformers.optimize_for_ort(model)
105+
mha_nodes = [n for n in model.graph if n.op_type == "MultiHeadAttention"]
106+
self.assertEqual(
107+
len(mha_nodes),
108+
1,
109+
"Expected exactly one MultiHeadAttention node after optimization",
110+
)
111+
mha_node = mha_nodes[0]
112+
# Check that the MHA node has past kv cache inputs
113+
self.assertEqual(len(mha_node.inputs), 8, "Expected MHA node to have 8 inputs")
114+
self.assertIsNotNone(mha_node.inputs[6], "Expected MHA node to have past key input")
115+
self.assertIsNotNone(mha_node.inputs[7], "Expected MHA node to have past value input")
116+
99117

100118
if __name__ == "__main__":
101119
unittest.main()

0 commit comments

Comments
 (0)