
Decompose addmm with Gemm | feat(torchlib) #1111


Merged 5 commits on Oct 25, 2023
@@ -30,7 +30,6 @@ class TestDeduceTypeConstraints(unittest.TestCase):
         "_aten_embedding_bag_onnx",
         "_aten_embedding_bag_1d_padding_idx_onnx",
     )
-    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ("aten_all",)

     @parameterized.parameterized.expand(
         ((op,) for op in torch_lib_onnx_functions_from_registry()),
@@ -41,11 +40,13 @@ def test_deduce_type_constraints_does_not_crash_for_onnx_function(
     ):
         if onnx_function.name in self._SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN:
             self.skipTest("Unimplemented: function contains loop or scan node.")
-        if onnx_function.name in self._SKIP_FUNCTIONS_WITH_NESTED_FUNCTION:
-            self.skipTest("Unimplemented: function contains nested function.")
-        signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
-            onnx_function
-        )
+        try:
+            signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
+                onnx_function
+            )
+        except NotImplementedError as e:
+            if "Nested function" in str(e):
+                self.skipTest("Unimplemented: function contains nested function.")
         logger.info(
             "Original signature: %s%s",
             onnx_function.name,
28 changes: 25 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -222,8 +222,8 @@ def aten_addcmul(

 @torch_op("aten::addmm")
 def aten_addmm(
-    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
-) -> TReal:
+    self: TInt, mat1: TInt, mat2: TInt, beta: float = 1.0, alpha: float = 1.0
+) -> TInt:
     """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

     mat1_mat2 = op.MatMul(mat1, mat2)
@@ -232,6 +232,29 @@ def aten_addmm(
     return op.Add(scaled_self, scaled_mat1_mat2)


+@torch_op("aten::addmm")
+def aten_addmm_gemm(
+    self: TFloat, mat1: TFloat, mat2: TFloat, beta: float = 1.0, alpha: float = 1.0
+) -> TFloat:
+    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
+
+    # A special case: when mat1 and mat2 are both rank 2, we can use Gemm instead of MatMul.
+    # We expect the If branches to be folded away by optimization passes.
+    # TODO(#1110): Handle Gemm with a graph rewriting pass instead of hard coding the branching logic here.
+    use_gemm = op.And(
+        op.Equal(Rank(mat1), op.Constant(value_int=2)),
+        op.Equal(Rank(mat2), op.Constant(value_int=2)),
+    )
+    if use_gemm:
+        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
+    else:
+        mat1_mat2 = op.MatMul(mat1, mat2)
+        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
+        scaled_self = op.Mul(self, beta)
+        result = op.Add(scaled_self, scaled_mat1_mat2)
+    return result
+
+
 @torch_op("aten::addmv")
 def aten_addmv(
     self: TReal, mat: TReal, vec: TReal, beta: float = 1.0, alpha: float = 1.0
@@ -5235,7 +5258,6 @@ def aten_mm(
 ) -> TRealUnlessInt16OrInt8:
     """mm(Tensor self, Tensor mat2) -> Tensor"""

-    # TODO(justinchuby): Specify type conversion for uint8/int8/int16
     return op.MatMul(self, mat2)

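Side note (an illustration only, not part of this change set): for rank-2 inputs, the Gemm branch and the MatMul/Mul/Add branch of aten_addmm_gemm compute the same value, namely beta * self + alpha * (mat1 @ mat2). A minimal PyTorch sketch of that equivalence, using arbitrary example names and shapes:

import torch

self_ = torch.randn(3, 5)
mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 5)
alpha, beta = 2.0, 0.5

# Reference result from PyTorch's addmm
expected = torch.addmm(self_, mat1, mat2, beta=beta, alpha=alpha)
# Gemm semantics: alpha * (A @ B) + beta * C
gemm_like = alpha * (mat1 @ mat2) + beta * self_
# Decomposed path (MatMul/Mul/Add), as in the else branch of aten_addmm_gemm
decomposed = (beta * self_) + (alpha * (mat1 @ mat2))
torch.testing.assert_close(expected, gemm_like)
torch.testing.assert_close(expected, decomposed)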
9 changes: 9 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -489,6 +489,14 @@ def _where_input_wrangler(
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
     TorchLibOpInfo("addmm", core_ops.aten_addmm),
+    TorchLibOpInfo("addmm_gemm", core_ops.aten_addmm_gemm).xfail(
+        "decomposed",
+        reason=(
+            "The float attributes alpha/beta come in as int in the test cases, which breaks "
+            "eager mode. We don't need to care about this as long as the full graph tests pass"
+        ),
+        test_class_name="TestOutputConsistencyEager",
+    ),
     TorchLibOpInfo("addmv", core_ops.aten_addmv),
     TorchLibOpInfo(
         "addr",
@@ -1968,6 +1976,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )

+ops_test_common.duplicate_opinfo(OPS_DB, "addmm", ("addmm_gemm",))
 ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))