
Commit b6ec405

Decompose addmm with Gemm | feat(torchlib) (#1111)
Decompose addmm with Gemm by creating a special variant for `FLOAT` inputs and conditionally checking the ranks of the input tensors. The if branch is expected to be folded away by constant-folding passes. I have not found other instances where Gemm is used in the torch.onnx exporter. Fixes #1089
1 parent 9fb0a7d commit b6ec405
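For context, a minimal numpy sketch (not part of the commit) of why the two code paths agree: ONNX Gemm computes `alpha * (A @ B) + beta * C` for rank-2 inputs, which is exactly the `beta * self + alpha * (mat1 @ mat2)` decomposition that the MatMul path builds, so folding the branch either way preserves the result. The helper names below are illustrative only.

```python
import numpy as np

def addmm_via_matmul(self_, mat1, mat2, beta=1.0, alpha=1.0):
    # The generic decomposition used by aten_addmm: MatMul + two Muls + Add.
    return beta * self_ + alpha * (mat1 @ mat2)

def addmm_via_gemm(self_, mat1, mat2, beta=1.0, alpha=1.0):
    # What ONNX Gemm(A=mat1, B=mat2, C=self_, alpha=alpha, beta=beta)
    # computes when both matrices are rank 2.
    return alpha * (mat1 @ mat2) + beta * self_

rng = np.random.default_rng(0)
self_ = rng.standard_normal((3, 5)).astype(np.float32)
mat1 = rng.standard_normal((3, 4)).astype(np.float32)
mat2 = rng.standard_normal((4, 5)).astype(np.float32)

np.testing.assert_allclose(
    addmm_via_matmul(self_, mat1, mat2, beta=0.5, alpha=2.0),
    addmm_via_gemm(self_, mat1, mat2, beta=0.5, alpha=2.0),
    rtol=1e-5,
)
```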

File tree: 3 files changed (+41, −9 lines)


onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -30,7 +30,6 @@ class TestDeduceTypeConstraints(unittest.TestCase):
         "_aten_embedding_bag_onnx",
         "_aten_embedding_bag_1d_padding_idx_onnx",
     )
-    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ("aten_all",)
 
     @parameterized.parameterized.expand(
         ((op,) for op in torch_lib_onnx_functions_from_registry()),
@@ -41,11 +40,13 @@ def test_deduce_type_constraints_does_not_crash_for_onnx_function(
     ):
         if onnx_function.name in self._SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN:
             self.skipTest("Unimplemented: function contains loop or scan node.")
-        if onnx_function.name in self._SKIP_FUNCTIONS_WITH_NESTED_FUNCTION:
-            self.skipTest("Unimplemented: function contains nested function.")
-        signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
-            onnx_function
-        )
+        try:
+            signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
+                onnx_function
+            )
+        except NotImplementedError as e:
+            if "Nested function" in str(e):
+                self.skipTest("Unimplemented: function contains nested function.")
         logger.info(
             "Original signature: %s%s",
             onnx_function.name,
```

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 25 additions & 3 deletions
```diff
@@ -222,8 +222,8 @@ def aten_addcmul(
 
 @torch_op("aten::addmm")
 def aten_addmm(
-    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
-) -> TReal:
+    self: TInt, mat1: TInt, mat2: TInt, beta: float = 1.0, alpha: float = 1.0
+) -> TInt:
     """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
 
     mat1_mat2 = op.MatMul(mat1, mat2)
@@ -232,6 +232,29 @@ def aten_addmm(
     return op.Add(scaled_self, scaled_mat1_mat2)
 
 
+@torch_op("aten::addmm")
+def aten_addmm_gemm(
+    self: TFloat, mat1: TFloat, mat2: TFloat, beta: float = 1.0, alpha: float = 1.0
+) -> TFloat:
+    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
+
+    # A special case when rank of mat1 and mat2 are 2, we can use Gemm instead of MatMul
+    # We expect the if branches to be folded away by optimization passes
+    # TODO(#1110): Handle Gemm with a graph rewriting pass instead of hard coding the branching logic here
+    use_gemm = op.And(
+        op.Equal(Rank(mat1), op.Constant(value_int=2)),
+        op.Equal(Rank(mat2), op.Constant(value_int=2)),
+    )
+    if use_gemm:
+        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
+    else:
+        mat1_mat2 = op.MatMul(mat1, mat2)
+        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
+        scaled_self = op.Mul(self, beta)
+        result = op.Add(scaled_self, scaled_mat1_mat2)
+    return result
+
+
 @torch_op("aten::addmv")
 def aten_addmv(
     self: TReal, mat: TReal, vec: TReal, beta: float = 1.0, alpha: float = 1.0
@@ -5235,7 +5258,6 @@ def aten_mm(
 ) -> TRealUnlessInt16OrInt8:
     """mm(Tensor self, Tensor mat2) -> Tensor"""
 
-    # TODO(justinchuby): Specify type conversion for uint8/int8/int16
     return op.MatMul(self, mat2)
 
 
```
onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -489,6 +489,14 @@ def _where_input_wrangler(
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
     TorchLibOpInfo("addmm", core_ops.aten_addmm),
+    TorchLibOpInfo("addmm_gemm", core_ops.aten_addmm_gemm).xfail(
+        "decomposed",
+        reason=(
+            "The float attributes alpha/beta come in as int in the test cases, which breaks"
+            "eager mode. We don't need to care about this as long as the full graph tests pass"
+        ),
+        test_class_name="TestOutputConsistencyEager",
+    ),
     TorchLibOpInfo("addmv", core_ops.aten_addmv),
     TorchLibOpInfo(
         "addr",
@@ -1968,6 +1976,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )
 
+ops_test_common.duplicate_opinfo(OPS_DB, "addmm", ("addmm_gemm",))
 ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
```
