# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Does the following transformation:
- Add(MatMul(X, W), B) -> Gemm
- Add(MatMul(Transpose(X), W), B) -> Gemm
- Add(MatMul(X, Transpose(W)), B) -> Gemm
- Add(MatMul(Transpose(X), Transpose(W)), B) -> Gemm
"""

import abc
from typing import ClassVar

from onnxscript.rewriter import pattern as orp


class _MatMulAddToGemmBase(orp.RewriteRuleClassBase, abc.ABC):
    """Base class for the MatMul + Add -> Gemm fusion rules.

    Subclasses define the matched pattern and set ``trans_a``/``trans_b``
    to mirror which MatMul input is wrapped in a Transpose in that pattern.
    """

    trans_a: ClassVar[bool] = False
    trans_b: ClassVar[bool] = False

    def rewrite(self, op, input_a, input_b, input_c):
        """Emit the fused ``Gemm(input_a, input_b, input_c)`` node.

        The ``transA``/``transB`` attributes replace the explicit Transpose
        nodes that were matched by the pattern (if any).
        """
        attributes = {}
        if self.trans_a:
            attributes["transA"] = 1
        if self.trans_b:
            attributes["transB"] = 1
        return op.Gemm(input_a, input_b, input_c, **attributes)

    def check(self, context, input_a, input_b, **_):
        """Accept the match only when both MatMul inputs are rank-2.

        Gemm is only defined for 2D inputs, so batched (rank > 2) matmuls
        must not be fused.
        """
        del context  # Not used
        check_result = orp.MatchResult()
        # Shapes may be absent when shape inference has not been run;
        # len(None) would raise TypeError, so treat unknown shapes as a
        # failed (non-fusable) match instead of crashing.
        if input_a.shape is None or input_b.shape is None:
            return check_result.fail("Shapes of input_a and input_b must be known")
        # Rank of input_a and input_b must be 2
        if len(input_a.shape) != 2 or len(input_b.shape) != 2:
            return check_result.fail("Rank of input_a and input_b must be 2")
        return check_result


class MatMulAddToGemm(_MatMulAddToGemmBase):
    """Replaces ``Add(MatMul(a, b), c)`` with ``Gemm(a, b, c)``."""

    def pattern(self, op, input_a, input_b, input_c):
        matmul = op.MatMul(input_a, input_b)
        return op.Add(matmul, input_c)


class TransAMatMulAddToGemm(_MatMulAddToGemmBase):
    """Replaces ``Add(MatMul(Transpose(a), b), c)`` with ``Gemm(a, b, c)``."""

    trans_a: ClassVar[bool] = True

    def pattern(self, op, input_a, input_b, input_c):
        matmul = op.MatMul(op.Transpose(input_a, perm=[1, 0]), input_b)
        return op.Add(matmul, input_c)


class TransBMatMulAddToGemm(_MatMulAddToGemmBase):
    """Replaces ``Add(MatMul(a, Transpose(b)), c)`` with ``Gemm(a, b, c)``."""

    trans_b: ClassVar[bool] = True

    def pattern(self, op, input_a, input_b, input_c):
        matmul = op.MatMul(input_a, op.Transpose(input_b, perm=[1, 0]))
        return op.Add(matmul, input_c)


class TransABMatMulAddToGemm(_MatMulAddToGemmBase):
    """Replaces ``Add(MatMul(Transpose(a), Transpose(b)), c)`` with ``Gemm(a, b, c)``."""

    trans_a: ClassVar[bool] = True
    trans_b: ClassVar[bool] = True

    def pattern(self, op, input_a, input_b, input_c):
        matmul = op.MatMul(
            op.Transpose(input_a, perm=[1, 0]),
            op.Transpose(input_b, perm=[1, 0]),
        )
        return op.Add(matmul, input_c)


matmul_add_to_gemm_rule = MatMulAddToGemm().rule()
transpose_a_matmul_add_to_gemm_rule = TransAMatMulAddToGemm().rule()
transpose_b_matmul_add_to_gemm_rule = TransBMatMulAddToGemm().rule()
transpose_ab_matmul_add_to_gemm_rule = TransABMatMulAddToGemm().rule()


def gemm_rule_set() -> orp.RewriteRuleSet:
    """Returns a set of rewrite rules that fuse MatMul + Add patterns into a single Gemm node,
    handling cases where one or both MatMul inputs are transposed.

    Returns:
        RewriteRuleSet
    """

    # Order is important: the most specific (transposed) patterns are tried
    # first so the plain MatMul+Add rule does not consume them.
    return orp.RewriteRuleSet(
        [
            transpose_ab_matmul_add_to_gemm_rule,
            transpose_a_matmul_add_to_gemm_rule,
            transpose_b_matmul_add_to_gemm_rule,
            matmul_add_to_gemm_rule,
        ]
    )
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
from typing import Sequence

import numpy as np
import onnx
from onnx_ir.passes.common import onnx_checker, shape_inference
from parameterized import parameterized

from onnxscript import ir
from onnxscript.rewriter import matmul_add_to_gemm, testing
from onnxscript.rewriter import pattern as orp
from onnxscript.rewriter.matmul_add_to_gemm import matmul_add_to_gemm_rule


class _MatMulAddToGemmTestBase(unittest.TestCase):
    """Shared helpers for the MatMul + Add -> Gemm fusion tests."""

    @property
    def rng(self):
        # A freshly seeded generator on every access keeps each draw
        # deterministic and independent of call order across helpers.
        return np.random.default_rng(20250607)

    def clone_model(self, model: ir.Model) -> ir.Model:
        """Deep-copy an IR model via a proto round-trip."""
        return ir.from_proto(ir.to_proto(model))

    def get_test_model(
        self,
        input_shape: ir.Shape,
        weight_shape: ir.Shape,
        transA: bool = False,
        transB: bool = False,
        # Immutable defaults: list defaults are shared across calls and are a
        # classic Python pitfall; tuples are equivalent here.
        permA: Sequence[int] = (1, 0),
        permB: Sequence[int] = (1, 0),
        weight_as_inputs: bool = False,
        bias_as_inputs: bool = False,
    ):
        """Returns the following model:

        Y = Add(MatMul(Transpose(X), Transpose(W)), B)

        Where:
            - Transpose(X) is applied only if `transA=True`
            - Transpose(W) is applied only if `transB=True`
            - W and B can be graph inputs or initializers
        """
        tape = ir.tape.Tape()
        inputs = []
        bias_shape = weight_shape[0] if transB else weight_shape[-1]
        output_shape = ir.Shape(("?",) * input_shape.rank())

        x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))

        if weight_as_inputs:
            w = ir.Input("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
            inputs.append(w)
        else:
            w = ir.tensor(
                self.rng.uniform(-0.5, 0.5, weight_shape).astype("float32"), name="W"
            )
            w = tape.initializer(w)

        if bias_as_inputs:
            b = ir.Input(
                "B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
            )
            inputs.append(b)
        else:
            b = ir.tensor(self.rng.uniform(-0.5, 0.5, bias_shape).astype("float32"), name="B")
            b = tape.initializer(b)

        # Bind both locals up front so the MatMul construction below is valid
        # regardless of which transposes are requested.
        x_t, w_t = None, None
        if transA:
            x_t = tape.op("Transpose", inputs=[x], attributes={"perm": list(permA)})

        if transB:
            w_t = tape.op("Transpose", inputs=[w], attributes={"perm": list(permB)})

        y = tape.op("MatMul", inputs=[x_t if transA else x, w_t if transB else w])
        y = tape.op(
            "Add",
            inputs=[y, b],
            output=ir.Input("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
        )

        # Build the model
        ir_model = ir.Model(
            ir.Graph(
                inputs=[x, *inputs],
                outputs=[y],
                nodes=tape.nodes,
                initializers=tape.initializers,
                opset_imports={"": 20},
                name="test_model",
            ),
            ir_version=10,
        )
        onnx_checker.CheckerPass(True)(ir_model)
        ir_model = shape_inference.infer_shapes(ir_model)
        return ir_model

    def check_matmul_add_to_gemm_incompatible_shapes(self, **kwargs):
        """Assert that rank!=2 MatMul inputs are rejected by the rule's check."""
        base_model = self.get_test_model(**kwargs)

        updated_model = self.clone_model(base_model)
        tracer = orp.MatchingTracer()
        count = matmul_add_to_gemm_rule.apply_to_model(updated_model, tracer=tracer)

        # Check that the model is unchanged
        self.assertEqual(count, 0)

        # Check that the error message is the expected one
        tracer_match = tracer.best_matches_map[matmul_add_to_gemm_rule][0]
        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
        self.assertRegex(
            tracer_match.match_result.reason, "Rank of input_a and input_b must be 2"
        )


class MatMulAddToGemmTest(_MatMulAddToGemmTestBase):
    @parameterized.expand(
        [
            ("initializers", False, False),
            ("inputs", True, True),
        ]
    )
    def test_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs):
        base_model = self.get_test_model(
            input_shape=ir.Shape((512, 256)),
            weight_shape=ir.Shape((256, 64)),
            weight_as_inputs=weight_as_inputs,
            bias_as_inputs=bias_as_inputs,
        )
        updated_model = self.clone_model(base_model)
        count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model)

        # Check MatMul + Add are fused into Gemm
        self.assertEqual(count, 1)
        self.assertEqual(len(updated_model.graph), 1)

        # Prepare inputs
        if weight_as_inputs and bias_as_inputs:
            inputs = (
                self.rng.random((512, 256), dtype=np.float32),
                self.rng.random((256, 64), dtype=np.float32),
                self.rng.random((64,), dtype=np.float32),
            )
        else:
            inputs = (self.rng.random((512, 256), dtype=np.float32),)

        # Check inference
        testing.assert_numerically_equal(base_model, updated_model, inputs)

        # Validate serialized model
        output_model_proto = ir.serde.serialize_model(updated_model)
        onnx.checker.check_model(output_model_proto, full_check=True)

    def test_matmul_add_to_gemm_incompatible_shapes(self):
        kwargs = {
            "input_shape": ir.Shape((1, 256, 512)),
            "weight_shape": ir.Shape((1, 512, 64)),
        }
        return super().check_matmul_add_to_gemm_incompatible_shapes(**kwargs)


class TransAMatMulAddToGemmTest(_MatMulAddToGemmTestBase):
    @parameterized.expand(
        [
            ("initializers", False, False),
            ("inputs", True, True),
        ]
    )
    def test_transpose_a_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs):
        base_model = self.get_test_model(
            input_shape=ir.Shape((256, 512)),
            weight_shape=ir.Shape((256, 64)),
            weight_as_inputs=weight_as_inputs,
            bias_as_inputs=bias_as_inputs,
            transA=True,
        )
        updated_model = self.clone_model(base_model)
        count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model)

        # Check MatMul(Transpose, W) + Add are fused into Gemm
        self.assertEqual(count, 1)
        self.assertEqual(len(updated_model.graph), 1)

        # Prepare inputs
        if weight_as_inputs and bias_as_inputs:
            inputs = (
                self.rng.random((256, 512), dtype=np.float32),
                self.rng.random((256, 64), dtype=np.float32),
                self.rng.random((64,), dtype=np.float32),
            )
        else:
            inputs = (self.rng.random((256, 512), dtype=np.float32),)

        # Check inference
        testing.assert_numerically_equal(base_model, updated_model, inputs)

        # Validate serialized model
        output_model_proto = ir.serde.serialize_model(updated_model)
        onnx.checker.check_model(output_model_proto, full_check=True)

    def test_transpose_a_matmul_add_to_gemm_incompatible_shapes(self):
        kwargs = {
            "input_shape": ir.Shape((1, 256, 512)),
            "weight_shape": ir.Shape((1, 256, 64)),
            "transA": True,
            "permA": [0, 2, 1],
        }
        return super().check_matmul_add_to_gemm_incompatible_shapes(**kwargs)


class TransBMatMulAddToGemmTest(_MatMulAddToGemmTestBase):
    @parameterized.expand(
        [
            ("initializers", False, False),
            ("inputs", True, True),
        ]
    )
    def test_transpose_b_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs):
        base_model = self.get_test_model(
            input_shape=ir.Shape((512, 256)),
            weight_shape=ir.Shape((64, 256)),
            weight_as_inputs=weight_as_inputs,
            bias_as_inputs=bias_as_inputs,
            transB=True,
        )
        updated_model = self.clone_model(base_model)
        count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model)

        # Check MatMul(X, Transpose) + Add are fused into Gemm
        self.assertEqual(count, 1)
        self.assertEqual(len(updated_model.graph), 1)

        # Prepare inputs
        if weight_as_inputs and bias_as_inputs:
            inputs = (
                self.rng.random((512, 256), dtype=np.float32),
                self.rng.random((64, 256), dtype=np.float32),
                self.rng.random((64,), dtype=np.float32),
            )
        else:
            inputs = (self.rng.random((512, 256), dtype=np.float32),)

        # Check inference
        testing.assert_numerically_equal(base_model, updated_model, inputs)

        # Validate serialized model
        output_model_proto = ir.serde.serialize_model(updated_model)
        onnx.checker.check_model(output_model_proto, full_check=True)

    def test_transpose_b_matmul_add_to_gemm_incompatible_shapes(self):
        kwargs = {
            "input_shape": ir.Shape((1, 512, 256)),
            "weight_shape": ir.Shape((1, 64, 256)),
            "transB": True,
            "permB": [0, 2, 1],
        }
        return super().check_matmul_add_to_gemm_incompatible_shapes(**kwargs)


class TransABMatMulAddToGemmTest(_MatMulAddToGemmTestBase):
    @parameterized.expand(
        [
            ("initializers", False, False),
            ("inputs", True, True),
        ]
    )
    def test_transpose_ab_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs):
        base_model = self.get_test_model(
            input_shape=ir.Shape((256, 512)),
            weight_shape=ir.Shape((64, 256)),
            weight_as_inputs=weight_as_inputs,
            bias_as_inputs=bias_as_inputs,
            transA=True,
            transB=True,
        )
        updated_model = self.clone_model(base_model)
        count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model)

        # Check MatMul(Transpose, Transpose) + Add are fused into Gemm
        self.assertEqual(count, 1)
        self.assertEqual(len(updated_model.graph), 1)

        # Prepare inputs
        if weight_as_inputs and bias_as_inputs:
            inputs = (
                self.rng.random((256, 512), dtype=np.float32),
                self.rng.random((64, 256), dtype=np.float32),
                self.rng.random((64,), dtype=np.float32),
            )
        else:
            inputs = (self.rng.random((256, 512), dtype=np.float32),)

        # Check inference
        testing.assert_numerically_equal(base_model, updated_model, inputs)

        # Validate serialized model
        output_model_proto = ir.serde.serialize_model(updated_model)
        onnx.checker.check_model(output_model_proto, full_check=True)

    def test_transpose_ab_matmul_add_to_gemm_incompatible_shapes(self):
        kwargs = {
            "input_shape": ir.Shape((1, 256, 512)),
            "weight_shape": ir.Shape((1, 64, 256)),
            "transA": True,
            "transB": True,
            "permA": [0, 2, 1],
            "permB": [0, 2, 1],
        }
        return super().check_matmul_add_to_gemm_incompatible_shapes(**kwargs)


if __name__ == "__main__":
    unittest.main()