Decompose addmm with Gemm | feat(torchlib) #1111

Merged
merged 5 commits into from
Oct 25, 2023

Conversation

justinchuby
Collaborator

@justinchuby justinchuby commented Oct 24, 2023

Decompose addmm with Gemm by creating a special variant for FLOAT that conditionally checks the ranks of the input tensors. The If branch is expected to be folded away by constant-folding passes.

I have not found other instances where Gemm is used in the torch.onnx exporter.
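The mapping can be sketched in plain numpy (reference functions only, not the actual torchlib code): PyTorch's addmm(self, mat1, mat2, beta, alpha) computes beta * self + alpha * mat1 @ mat2, which is exactly ONNX Gemm(A=mat1, B=mat2, C=self) when both matrices are rank 2.

```python
import numpy as np

def addmm_reference(self_, mat1, mat2, beta=1.0, alpha=1.0):
    # PyTorch aten::addmm semantics: beta * self + alpha * (mat1 @ mat2)
    return beta * self_ + alpha * (mat1 @ mat2)

def gemm_reference(a, b, c, alpha=1.0, beta=1.0):
    # ONNX Gemm semantics (transA=transB=0): alpha * A @ B + beta * C,
    # where C is unidirectionally broadcast to the output shape.
    return alpha * (a @ b) + beta * c

def addmm_via_gemm(self_, mat1, mat2, beta=1.0, alpha=1.0):
    # The rank check mirrors my reading of the If branch in the PR: Gemm
    # requires rank-2 A and B, so fall back to the generic decomposition
    # (scaled Add of a MatMul) otherwise.
    if mat1.ndim == 2 and mat2.ndim == 2:
        return gemm_reference(mat1, mat2, self_, alpha=alpha, beta=beta)
    return addmm_reference(self_, mat1, mat2, beta=beta, alpha=alpha)
```

In the exported function the rank check becomes an ONNX If node; once the input ranks are statically known, a constant-folding pass can evaluate the condition and keep only the taken branch.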

Fixes #1089

cc @baijumeswani

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Oct 24, 2023
@codecov

codecov bot commented Oct 24, 2023

Codecov Report

Merging #1111 (b448bc0) into main (9fb0a7d) will increase coverage by 0.02%.
The diff coverage is 62.50%.

@@            Coverage Diff             @@
##             main    #1111      +/-   ##
==========================================
+ Coverage   78.13%   78.16%   +0.02%     
==========================================
  Files         117      117              
  Lines       14954    14966      +12     
  Branches     1585     1586       +1     
==========================================
+ Hits        11685    11698      +13     
  Misses       2900     2900              
+ Partials      369      368       -1     
Files Coverage Δ
...ipt/tests/function_libs/torch_lib/ops_test_data.py 96.15% <100.00%> (+0.01%) ⬆️
...bs/tools/torch_lib/deduce_type_constraints_test.py 90.32% <80.00%> (-3.02%) ⬇️
onnxscript/function_libs/torch_lib/ops/core.py 79.61% <50.00%> (-0.10%) ⬇️

... and 4 files with indirect coverage changes

@BowenBao
Contributor

aten::mm is also exported as Gemm by torchscript exporter. Although I'm not sure if it is better to export as Gemm or Matmul in this case.

@@ -232,6 +249,29 @@ def aten_addmm(
    return op.Add(scaled_self, scaled_mat1_mat2)


@torch_op("aten::addmm")
def aten_addmm_gemm(
    self: FLOAT, mat1: FLOAT, mat2: FLOAT, beta: float = 1.0, alpha: float = 1.0
Contributor

Why only do 'gemm' variant on float inputs?

Collaborator Author


ORT implements Gemm for float inputs only, so I set it accordingly for practicality.
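The practicality point can be illustrated with a hypothetical dispatch helper (pick_addmm_variant and the dtype set are illustrative, not torchlib API): route through the Gemm variant only for dtypes the target runtime is assumed to have a Gemm kernel for.

```python
import numpy as np

# Hypothetical: dtypes assumed to have a Gemm kernel in the target runtime.
# After the change from FLOAT to TFloat, float16 is covered as well.
GEMM_DTYPES = {np.dtype(np.float16), np.dtype(np.float32), np.dtype(np.float64)}

def pick_addmm_variant(dtype) -> str:
    """Choose a torchlib overload name for the given input dtype."""
    if np.dtype(dtype) in GEMM_DTYPES:
        return "aten_addmm_gemm"
    # Integer and other dtypes keep the generic Add/MatMul decomposition.
    return "aten_addmm"
```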

Collaborator Author


Changed to TFloat because it supports float16 as well

@github-actions

github-actions bot commented Oct 24, 2023

Test Results

18 files ±0   18 suites ±0   1h 45m 59s ⏱️ +21m 35s
11 101 tests +19   8 331 ✔️ ±0   2 732 💤 −1   37 +19   1 🔥 +1
172 643 runs +16 742   39 714 ✔️ +3 925   131 074 💤 +12 429   1 854 +387   1 🔥 +1

For more details on these failures and errors, see this check.

Results for commit b448bc0. ± Comparison against base commit 9fb0a7d.

This pull request removes 520 and adds 539 tests. Note that renamed tests count towards both.
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_017_aten_addmv
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_018_aten_addr
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_019_aten_amax
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_020_aten_amin
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_021_aten_any
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_022_aten_any_dim
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_023_aten_asin
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_024_aten_asinh
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_025_aten_atan
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_026_aten_atan2
…
onnxscript.function_libs.tools.torch_lib.deduce_type_constraints_test.TestDeduceTypeConstraints ‑ test_deduce_type_constraints_does_not_crash_for_onnx_function_aten_addmm_gemm
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_017_aten_addmm_gemm
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_018_aten_addmv
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_019_aten_addr
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_020_aten_amax
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_021_aten_amin
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_022_aten_any
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_023_aten_any_dim
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_024_aten_asin
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_025_aten_asinh
…
This pull request skips 8 and un-skips 20 tests.
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__addmm_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__addmm_cpu_float32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__addmm_decomposed_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__addmm_decomposed_cpu_float32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU ‑ test_output_match_opinfo__addmm_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU ‑ test_output_match_opinfo__addmm_cpu_float32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU ‑ test_output_match_opinfo__addmm_decomposed_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU ‑ test_output_match_opinfo__addmm_decomposed_cpu_float32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__all_dim_cpu_bool
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__all_dim_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__all_dim_cpu_float32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__all_dim_cpu_int32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__all_dim_cpu_int64
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__any_dim_cpu_bool
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__any_dim_cpu_float16
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__any_dim_cpu_float32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__any_dim_cpu_int32
onnxscript.tests.function_libs.torch_lib.ops_test.TestOutputConsistencyEagerCPU ‑ test_output_match_opinfo__any_dim_cpu_int64
…


@justinchuby
Collaborator Author

aten::mm is also exported as Gemm by torchscript exporter. Although I'm not sure if it is better to export as Gemm or Matmul in this case.

Yeah Gemm seems to be more specialized for that case. Backends should just do the right thing for matmul imo
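For context, aten::mm fits the same pattern as a degenerate case (numpy reference, not exporter code): Gemm with beta = 0 and no bias input reduces to a plain matrix product, which is why an exporter can reasonably choose either Gemm or MatMul for it.

```python
import numpy as np

def gemm(a, b, c=None, alpha=1.0, beta=0.0):
    # ONNX Gemm reference (transA=transB=0): Y = alpha * A @ B + beta * C,
    # with C optional.
    y = alpha * (a @ b)
    if c is not None and beta != 0.0:
        y = y + beta * c
    return y

def mm_via_gemm(a, b):
    # aten::mm(a, b) as Gemm with alpha=1, beta=0: just the matrix product.
    return gemm(a, b)
```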

@justinchuby justinchuby requested a review from BowenBao October 25, 2023 16:00
Contributor

@BowenBao BowenBao left a comment


LGTM, thanks!

@justinchuby justinchuby merged commit b6ec405 into main Oct 25, 2023
@justinchuby justinchuby deleted the justinchu/gemm branch October 25, 2023 17:00
Successfully merging this pull request may close these issues.

Linear from PyTorch must map to Gemm in ONNX