Implement aten::_softmax | feat(torchlib) #1024


Merged: 7 commits, Aug 29, 2023

Conversation

justinchuby
Collaborator

@justinchuby justinchuby commented Aug 25, 2023

Move aten_special_softmax -> aten_softmax; support the alias and the case when the dtype input is None. Implement _softmax using softmax.

An alternative would be to create a decomp in PyTorch from _softmax to softmax. But since PyTorch's optimization passes tend to create decomps to private ops (and not the other way around), as @BowenBao pointed out, we still need to support _softmax here.

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Aug 25, 2023
@codecov

codecov bot commented Aug 25, 2023

Codecov Report

Merging #1024 (a1368e4) into main (74c068c) will decrease coverage by 0.01%.
The diff coverage is 80.55%.

@@            Coverage Diff             @@
##             main    #1024      +/-   ##
==========================================
- Coverage   77.51%   77.50%   -0.01%     
==========================================
  Files         114      114              
  Lines       14309    14335      +26     
  Branches     1521     1525       +4     
==========================================
+ Hits        11091    11110      +19     
- Misses       2854     2859       +5     
- Partials      364      366       +2     
Files Changed                                          Coverage Δ
onnxscript/function_libs/torch_lib/ops/special.py      61.94% <ø> (-2.65%) ⬇️
onnxscript/function_libs/torch_lib/ops/core.py         78.44% <75.00%> (-0.04%) ⬇️
...ript/tests/function_libs/torch_lib/extra_opinfo.py  98.03% <100.00%> (+0.05%) ⬆️
...ipt/tests/function_libs/torch_lib/ops_test_data.py  96.00% <100.00%> (+0.01%) ⬆️


@torch_op("aten::_softmax", trace_only=True)
def aten__softmax(
self: TFloatHighPrecision, dim: int, half_to_float: bool
Contributor

Why do these two _softmax() functions have different dtypes for self?

Collaborator Author

This is because we want to cast float16 to float32 when half_to_float is true. The input dtype is not observable within the function body, so we rely on the dispatcher: it picks the overload whose input type constraints only accept float16, which is how we know we are dealing with float16 inputs.
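This dispatch idea can be sketched roughly with NumPy stand-ins rather than the real torchlib dispatcher (all function and helper names below are hypothetical): one overload is constrained to float16 inputs, so the cast decision needs no runtime dtype inspection inside the function body.

```python
import numpy as np

def _softmax_impl(x, axis):
    # Numerically stable softmax along `axis`.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def aten__softmax_half(x, dim, half_to_float):
    # Hypothetical overload whose type constraint admits only float16 inputs;
    # because the dispatcher guarantees float16 here, the cast can be decided
    # without inspecting the runtime dtype.
    if half_to_float:
        x = x.astype(np.float32)  # cast float16 -> float32 up front
    return _softmax_impl(x, dim)

def aten__softmax(x, dim, half_to_float):
    # Hypothetical overload for higher-precision floats; half_to_float is
    # irrelevant because the input is already float32/float64.
    return _softmax_impl(x, dim)

def dispatch(x, dim, half_to_float):
    # Mimics the dispatcher selecting an overload from the input dtype.
    fn = aten__softmax_half if x.dtype == np.float16 else aten__softmax
    return fn(x, dim, half_to_float)
```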

@github-actions

github-actions bot commented Aug 25, 2023

Test Results

18 files ±0    18 suites ±0    1h 18m 35s ⏱️ (-9m 21s)
10 566 tests +11:    7 848 ✔️ +6     2 714 💤 +6     3 ±0    1 🔥 -1
156 089 runs +111:   36 576 ✔️ +51   119 503 💤 +61   9 ±0    1 🔥 -1

For more details on these failures and errors, see this check.

Results for commit a1368e4. ± Comparison against base commit 74c068c.

This pull request removes 266 and adds 277 tests. Note that renamed tests count towards both.
onnxscript.function_libs.tools.torch_lib.deduce_type_constraints_test.TestDeduceTypeConstraints ‑ test_deduce_type_constraints_does_not_crash_for_onnx_function_aten_special_softmax
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_001_aten_all_dim
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_002_aten_allclose
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_003_aten_all
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_004_aten_abs
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_005_aten_abs_complex
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_006_aten_acos
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_007_aten_acosh
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_008_aten_add
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_009_aten_addbmm
…
onnxscript.function_libs.tools.torch_lib.deduce_type_constraints_test.TestDeduceTypeConstraints ‑ test_deduce_type_constraints_does_not_crash_for_onnx_function_aten_softmax
onnxscript.function_libs.tools.torch_lib.deduce_type_constraints_test.TestDeduceTypeConstraints ‑ test_deduce_type_constraints_does_not_crash_for_onnx_function_aten_softmax_no_dtype
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_001_aten__softmax
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_002_aten__softmax_half
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_003_aten_all_dim
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_004_aten_allclose
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_005_aten_all
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_006_aten_abs
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_007_aten_abs_complex
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_008_aten_acos
…

♻️ This comment has been updated with latest results.

@justinchuby justinchuby marked this pull request as draft August 28, 2023 15:47
@justinchuby justinchuby marked this pull request as ready for review August 28, 2023 19:50