
[torchlib] Implement aten::_softmax to avoid decomposition in exporter #857


Closed
BowenBao opened this issue Jul 11, 2023 · 4 comments
Labels: good first issue (Good for newcomers), module: torchlib (Related to the torch/aten function lib in development)

Comments

@BowenBao
Contributor

BowenBao commented Jul 11, 2023

Update: the op on the aten side is aten::_softmax, not aten::softmax.

The generated subgraph is unnecessarily large because there is no OnnxFunction registered for aten::_softmax, so the exporter decomposes it. It should instead export as a single onnx::Softmax.

We should probably revisit and identify all aten ops that have a close/1-to-1 mapping to ONNX ops, even when the aten op has a decomposition registered.

import torch
import onnxscript

def func(x):
    return x.softmax(dim=-1)

# Export with the dynamo exporter and print the ONNX text proto;
# the softmax ends up decomposed into amax/sub/exp/sum/div:
print(onnxscript.proto2text(torch.onnx.dynamo_export(func, torch.randn(3, 3)).model_proto))
<
   ir_version: 8,
   opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18],
   producer_name: "pytorch",
   producer_version: "2.1.0"
>
torch_jit (float[3,3] arg0) => (float[3,3] div) {
   1 = Constant <value = int64[1] {-1}> ()
   amax = pkg.onnxscript.torch_lib.aten_amax <keepdim = 1> (arg0, 1)
   sub = pkg.onnxscript.torch_lib.aten_sub <alpha = 1> (arg0, amax)
   exp = pkg.onnxscript.torch_lib.aten_exp (sub)
   5 = Constant <value = int64[1] {-1}> ()
   sum_1 = pkg.onnxscript.torch_lib._aten_sum_dim_onnx <keepdim = 1> (exp, 5)
   div = pkg.onnxscript.torch_lib.aten_div (exp, sum_1)
}
<
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
>
aten_amax (self, dim) => (return_val)
{
   return_val = ReduceMax <keepdims: int = @keepdim> (self, dim)
}
<
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
>
aten_sub (self, other) => (return_val)
{
   alpha = Constant <value_float: float = @alpha> ()
   alpha_0 = CastLike (alpha, other)
   other_1 = Mul (other, alpha_0)
   return_val = Sub (self, other_1)
}
<
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
>
aten_exp (self) => (return_val)
{
   return_val = Exp (self)
}
<
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
>
_aten_sum_dim_onnx (self, dim) => (result_16)
{
   tmp = Shape (self)
   tmp_0 = Size (tmp)
   tmp_1 = Constant <value = int64 tmp_1 {0}> ()
   tmp_1_cast = CastLike (tmp_1, tmp_0)
   self_is_scalar = Equal (tmp_0, tmp_1_cast)
   self_5 = If (self_is_scalar) <then_branch = thenGraph_4 () => ( self_3) {
      tmp_2 = Constant <value_ints = [-1]> ()
      self_3 = Reshape (self, tmp_2)
   }, else_branch = elseGraph_4 () => ( self_4) {
      self_4 = Identity (self)
   }>
   tmp_6 = Shape (dim)
   tmp_7 = Size (tmp_6)
   tmp_8 = Constant <value = int64 tmp_8 {0}> ()
   tmp_8_cast = CastLike (tmp_8, tmp_7)
   cond = Equal (tmp_7, tmp_8_cast)
   dim_13 = If (cond) <then_branch = thenGraph_7 () => ( dim_11) {
      tmp_9 = Constant <value_ints = [-1]> ()
      dim_10 = Reshape (dim, tmp_9)
      dim_11 = Cast <to = 7> (dim_10)
   }, else_branch = elseGraph_7 () => ( dim_12) {
      dim_12 = Identity (dim)
   }>
   result = ReduceSum <keepdims: int = @keepdim> (self_5, dim_13)
   result_16 = If (self_is_scalar) <then_branch = thenGraph_12 () => ( result_14) {
      result_14 = Squeeze (result)
   }, else_branch = elseGraph_12 () => ( result_15) {
      result_15 = Identity (result)
   }>
}
<
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
>
aten_div (self, other) => (return_val)
{
   return_val = Div (self, other)
}
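
For reference, a direct torchlib implementation should collapse all of the above into one Softmax node. A minimal sketch, assuming the torch_op registration decorator and opset18 alias used elsewhere in torchlib (the function name, trace_only option, and Cast handling here are assumptions, not the actual fix):

from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib.registration import torch_op

@torch_op("aten::_softmax", trace_only=True)
def aten__softmax(self, dim: int, half_to_float: bool):
    """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
    if half_to_float:
        # ONNX TensorProto.FLOAT is 1; upcast half input to a float output
        self = op.Cast(self, to=1)
    # A single onnx::Softmax node replaces the amax/sub/exp/sum/div subgraph
    return op.Softmax(self, axis=dim)

With such a registration, the graph exported from the repro above would contain a single Softmax call instead of the five-function decomposition.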
@BowenBao BowenBao added the module: torchlib Related to the torch/aten function lib in development label Jul 11, 2023
@justinchuby justinchuby added good first issue Good for newcomers contribution welcome We welcome code contributions for this labels Jul 11, 2023
@BowenBao BowenBao changed the title [torchlib] Implement aten::softmax to avoid decomposition in exporter [torchlib] Implement aten::_softmax to avoid decomposition in exporter Aug 4, 2023
@justinchuby
Collaborator

How is softmax different from _softmax?

@BowenBao
Contributor Author

BowenBao commented Aug 4, 2023

_softmax has a half_to_float argument, which might be tricky for onnxscript.

@justinchuby
Collaborator

- func: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
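
For context, half_to_float asks the kernel to take a half (float16) input and produce a float32 output directly. A small illustration, assuming a CUDA device since the CPU kernel rejects half_to_float=True:

import torch

x = torch.randn(2, 3, dtype=torch.float16, device="cuda")
# half_to_float=True: the kernel upcasts and returns float32
y = torch.ops.aten._softmax(x, -1, True)
print(y.dtype)  # torch.float32
# half_to_float=False: the output dtype matches the input
z = torch.ops.aten._softmax(x, -1, False)
print(z.dtype)  # torch.float16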

@justinchuby justinchuby self-assigned this Aug 26, 2023
@justinchuby justinchuby removed the contribution welcome We welcome code contributions for this label Aug 26, 2023
@justinchuby
Collaborator

#1024
