From b6dc968c7fa8388bacf5768b6d616cd2616d901e Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 24 Oct 2023 23:09:22 +0000
Subject: [PATCH 1/5] Decompose addmm with Gemm | feat(torchlib)

---
 .../function_libs/torch_lib/graph_building.py |  1 +
 .../function_libs/torch_lib/ops/core.py       | 50 +++++++++++++++++--
 .../function_libs/torch_lib/ops_test_data.py  |  9 ++++
 3 files changed, 55 insertions(+), 5 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index b873d310f9..8b76690ff7 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -282,6 +282,7 @@ def eval_function(  # type: ignore[override]
             param = name_to_schema[name]
             # Cast int to float if needed
             if param.type in {float, "float"}:
+                print(name, param.type)
                 # FIXME(justinchuby): Create invariant on the type of param.type to simplify this
                 attributes[name] = float(value)
         return self._graph.add_function_call(function, inputs, attributes)
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 6d770254e6..0bd40d4253 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -12,7 +12,7 @@
 from __future__ import annotations
 
 import math
-from typing import Any, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Sequence, Tuple, TypeVar, Union
 
 from onnxscript import (
     BFLOAT16,
@@ -50,6 +50,19 @@
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
+TRealUnlessFloat32 = TypeVar(
+    "TRealUnlessFloat32",
+    bound=Union[
+        BFLOAT16,
+        FLOAT16,
+        DOUBLE,
+        INT8,
+        INT16,
+        INT32,
+        INT64,
+    ],
+)
+
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
@@ -222,14 +235,42 @@ def aten_addcmul(
 
 @torch_op("aten::addmm")
 def aten_addmm(
-    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
-) -> TReal:
+    self: TRealUnlessFloat32,
+    mat1: TRealUnlessFloat32,
+    mat2: TRealUnlessFloat32,
+    beta: float = 1.0,
+    alpha: float = 1.0,
+) -> TRealUnlessFloat32:
     """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
 
     mat1_mat2 = op.MatMul(mat1, mat2)
     scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
     scaled_self = op.Mul(self, beta)
-    return op.Add(scaled_self, scaled_mat1_mat2)
+    result = op.Add(scaled_self, scaled_mat1_mat2)
+    return result
+
+
+@torch_op("aten::addmm")
+def aten_addmm_gemm(
+    self: FLOAT, mat1: FLOAT, mat2: FLOAT, beta: float = 1.0, alpha: float = 1.0
+) -> FLOAT:
+    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
+
+    # A special case: when mat1 and mat2 are both rank 2, we can use Gemm instead of MatMul
+    # We expect the if branches to be folded away by optimization passes
+    # TODO(#1110): Handle Gemm with a graph rewriting pass instead of hard coding the branching logic here
+    use_gemm = op.And(
+        op.Equal(Rank(mat1), op.Constant(value_int=2)),
+        op.Equal(Rank(mat2), op.Constant(value_int=2)),
+    )
+    if use_gemm:
+        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
+    else:
+        mat1_mat2 = op.MatMul(mat1, mat2)
+        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
+        scaled_self = op.Mul(self, beta)
+        result = op.Add(scaled_self, scaled_mat1_mat2)
+    return result
 
 
 @torch_op("aten::addmv")
@@ -5235,7 +5276,6 @@ def aten_mm(
 ) -> TRealUnlessInt16OrInt8:
     """mm(Tensor self, Tensor mat2) -> Tensor"""
 
- 
     # TODO(justinchuby): Specify type conversion for uint8/int8/int16
     return op.MatMul(self, mat2)
 
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 7c6a64b49b..7304aa8ba6 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -489,6 +489,14 @@ def _where_input_wrangler(
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
     TorchLibOpInfo("addmm", core_ops.aten_addmm),
+    TorchLibOpInfo("addmm_gemm", core_ops.aten_addmm_gemm).xfail(
+        "decomposed",
+        reason=(
+            "The float attributes alpha/beta come in as int in the test cases, which breaks "
+            "eager mode. We don't need to care about this as long as the full graph tests pass"
+        ),
+        test_class_name="TestOutputConsistencyEager",
+    ),
     TorchLibOpInfo("addmv", core_ops.aten_addmv),
     TorchLibOpInfo(
         "addr",
@@ -1968,6 +1976,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )
 
+ops_test_common.duplicate_opinfo(OPS_DB, "addmm", ("addmm_gemm",))
 ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))

From 9a8aba6068e830fa236684b4befbb374ec85c6b1 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 24 Oct 2023 16:11:43 -0700
Subject: [PATCH 2/5] Apply suggestions from code review

---
 onnxscript/function_libs/torch_lib/graph_building.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index 8b76690ff7..b873d310f9 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -282,7 +282,6 @@ def eval_function(  # type: ignore[override]
             param = name_to_schema[name]
             # Cast int to float if needed
             if param.type in {float, "float"}:
-                print(name, param.type)
                 # FIXME(justinchuby): Create invariant on the type of param.type to simplify this
                 attributes[name] = float(value)
         return self._graph.add_function_call(function, inputs, attributes)

From d731ea58843122673a1965cc19603e289eaab370 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 24 Oct 2023 16:12:23 -0700
Subject: [PATCH 3/5] Update onnxscript/function_libs/torch_lib/ops/core.py

---
 onnxscript/function_libs/torch_lib/ops/core.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 0bd40d4253..28955aa234 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -246,8 +246,7 @@ def aten_addmm(
     mat1_mat2 = op.MatMul(mat1, mat2)
     scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
     scaled_self = op.Mul(self, beta)
-    result = op.Add(scaled_self, scaled_mat1_mat2)
-    return result
+    return op.Add(scaled_self, scaled_mat1_mat2)
 
 
 @torch_op("aten::addmm")

From 0b298306ff6b7e907da1a9913749a178e279b43f Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 25 Oct 2023 15:59:59 +0000
Subject: [PATCH 4/5] Skip test

---
 .../tools/torch_lib/deduce_type_constraints_test.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
index a2882d283e..4e01d37acc 100644
--- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
+++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
@@ -30,7 +30,6 @@ class TestDeduceTypeConstraints(unittest.TestCase):
         "_aten_embedding_bag_onnx",
         "_aten_embedding_bag_1d_padding_idx_onnx",
     )
-    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ("aten_all",)
 
     @parameterized.parameterized.expand(
         ((op,) for op in torch_lib_onnx_functions_from_registry()),
@@ -41,11 +40,14 @@ def test_deduce_type_constraints_does_not_crash_for_onnx_function(
     ):
         if onnx_function.name in self._SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN:
             self.skipTest("Unimplemented: function contains loop or scan node.")
-        if onnx_function.name in self._SKIP_FUNCTIONS_WITH_NESTED_FUNCTION:
-            self.skipTest("Unimplemented: function contains nested function.")
-        signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
-            onnx_function
-        )
+        try:
+            signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
+                onnx_function
+            )
+        except NotImplementedError as e:
+            if "Nested function" in str(e):
+                self.skipTest("Unimplemented: function contains nested function.")
+            raise
         logger.info(
             "Original signature: %s%s",
             onnx_function.name,

From b448bc0764fee6c361ea25dd29c3e952e9babc60 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 25 Oct 2023 16:03:14 +0000
Subject: [PATCH 5/5] Simplify

---
 onnxscript/function_libs/torch_lib/ops/core.py | 27 ++++---------------
 1 file changed, 5 insertions(+), 22 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 28955aa234..775e149da3 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -12,7 +12,7 @@
 from __future__ import annotations
 
 import math
-from typing import Any, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, Optional, Sequence, Tuple, Union
 
 from onnxscript import (
     BFLOAT16,
@@ -50,19 +50,6 @@
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
-TRealUnlessFloat32 = TypeVar(
-    "TRealUnlessFloat32",
-    bound=Union[
-        BFLOAT16,
-        FLOAT16,
-        DOUBLE,
-        INT8,
-        INT16,
-        INT32,
-        INT64,
-    ],
-)
-
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
@@ -235,12 +222,8 @@ def aten_addcmul(
 
 @torch_op("aten::addmm")
 def aten_addmm(
-    self: TRealUnlessFloat32,
-    mat1: TRealUnlessFloat32,
-    mat2: TRealUnlessFloat32,
-    beta: float = 1.0,
-    alpha: float = 1.0,
-) -> TRealUnlessFloat32:
+    self: TInt, mat1: TInt, mat2: TInt, beta: float = 1.0, alpha: float = 1.0
+) -> TInt:
     """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
 
     mat1_mat2 = op.MatMul(mat1, mat2)
@@ -251,8 +234,8 @@ def aten_addmm(
 
 @torch_op("aten::addmm")
 def aten_addmm_gemm(
-    self: FLOAT, mat1: FLOAT, mat2: FLOAT, beta: float = 1.0, alpha: float = 1.0
-) -> FLOAT:
+    self: TFloat, mat1: TFloat, mat2: TFloat, beta: float = 1.0, alpha: float = 1.0
+) -> TFloat:
     """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
 
     # A special case: when mat1 and mat2 are both rank 2, we can use Gemm instead of MatMul