Implement aten::_softmax | feat(torchlib) #1024

Merged (7 commits) on Aug 29, 2023
onnxscript/function_libs/torch_lib/ops/core.py (56 additions, 0 deletions)
@@ -67,6 +67,29 @@ def aten__local_scalar_dense_int(self: IntType) -> INT64:
return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=INT64.dtype)


@torch_op("aten::_softmax", trace_only=True)
def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT:
"""_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""

# trace_only because we need to cast conditionally based on half_to_float
if half_to_float:
self = op.Cast(self, to=FLOAT.dtype)

return aten_softmax_no_dtype(self, dim)


@torch_op("aten::_softmax", trace_only=True)
def aten__softmax(
self: TFloatHighPrecision, dim: int, half_to_float: bool
Contributor:

Why do these two _softmax() functions have different dtypes for self?

Collaborator (Author):

This is because we want to cast float16 to float32 when half_to_float is True. Since we have no idea what the input dtype will be within the function body, we rely on the dispatcher to pick the overload whose declared input dtypes are already restricted, so we know we are dealing with float16 types.

) -> TFloatHighPrecision:
"""_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""

# trace_only to reuse aten_softmax_no_dtype

del half_to_float # Unused
return aten_softmax_no_dtype(self, dim)
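
An illustrative sketch of the dispatch idea discussed in the review thread above. This is a hypothetical stand-in, not the actual torchlib dispatcher, and pick_softmax_overload is an invented helper name: it only shows that overload selection is keyed off the input dtype each registered function declares, so the float16/bfloat16 overload is the only place where the half_to_float cast can apply.

# Hypothetical sketch only; the real dispatcher lives in torchlib, not here.
import torch

def pick_softmax_overload(x: torch.Tensor) -> str:
    """Choose the overload whose declared input dtypes match x.dtype."""
    if x.dtype in (torch.float16, torch.bfloat16):
        # Matches aten__softmax_half, which may Cast to float32 when half_to_float=True.
        return "aten__softmax_half"
    # Matches aten__softmax, where half_to_float is ignored.
    return "aten__softmax"

print(pick_softmax_overload(torch.randn(3, dtype=torch.float16)))  # aten__softmax_half
print(pick_softmax_overload(torch.randn(3)))                       # aten__softmax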


@torch_op("aten::abs")
def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
"""abs(Tensor self) -> Tensor"""
@@ -6336,6 +6359,39 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
def aten_softmax(
self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
) -> TFloatOrBFloat16:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = op.Size(op.Shape(self)) == 0
if self_is_scalar:
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
result = op.Softmax(self, axis=dim)
result = op.Cast(result, to=dtype)
if self_is_scalar:
# Convert to scalar when input is scalar
result = op.Squeeze(result)

return result


@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = op.Size(op.Shape(self)) == 0
if self_is_scalar:
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
result = op.Softmax(self, axis=dim)
if self_is_scalar:
# Convert to scalar when input is scalar
result = op.Squeeze(result)

return result
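
A quick numerical sanity check of the scalar handling above, written in plain PyTorch for illustration only (it is not the exported ONNX graph): a rank-0 input is promoted to a 1-element vector before Softmax and squeezed back afterwards, which is what the Unsqueeze -> Softmax -> Squeeze pattern encodes.

import torch

x = torch.tensor(2.0)                               # rank-0 (scalar) input
y = torch.softmax(x.unsqueeze(0), dim=0).squeeze()  # unsqueeze, softmax, squeeze back
print(y)                                            # tensor(1.) -- softmax over a single element is 1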


def aten_sort(
self: TensorType, dim: int = -1, descending: bool = False
) -> tuple[TensorType, TensorType]:
onnxscript/function_libs/torch_lib/ops/special.py (0 additions, 17 deletions)
@@ -341,23 +341,6 @@ def aten_special_sinc(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::softmax")
def aten_special_softmax(
self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
) -> TFloatOrBFloat16:
"""special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = op.Size(op.Shape(self)) == 0
if self_is_scalar:
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
result = op.Softmax(self, axis=dim)
result = op.Cast(result, to=dtype)
if self_is_scalar: # squeeze to scalar due to input is scalar
result = op.Squeeze(result)

return result


def aten_special_spherical_bessel_j0(x: TensorType) -> TensorType:
"""special_spherical_bessel_j0(Tensor x) -> Tensor"""

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py (35 additions, 0 deletions)
@@ -780,6 +780,34 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs)
yield opinfo_core.SampleInput(input_, args=(src, *args))


def sample_inputs__softmax(
op_info,
device,
dtype,
requires_grad,
**kwargs,
):
del op_info # Unused

make_arg = functools.partial(
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
cases = [
((S,), (0,)),
((S, S), (0,)),
((S, S), (1,)),
((S, S), (-1,)),
((S, M, S), (2,)),
((S, 0, 0), (-1,)),
]

for (shape, dim), half_to_float in itertools.product(cases, (False,)):
# NOTE: softmax with half to float conversion is not supported on CPU
# So we don't test it here
kwargs = dict(half_to_float=half_to_float)
yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs)


# NOTE: How to create an OpInfo:
# 1. Create a function that generates sample inputs for the op.
# This function should yield SampleInputs.
@@ -966,4 +994,11 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs)
sample_inputs_func=sample_inputs_slice_scatter,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten._softmax",
aten_name="_softmax",
dtypes=common_dtype.floating_types_and_half(),
sample_inputs_func=sample_inputs__softmax,
supports_out=False,
),
]
onnxscript/tests/function_libs/torch_lib/ops_test_data.py (9 additions, 1 deletion)
@@ -458,6 +458,13 @@ def _where_input_wrangler(
"ops.aten._local_scalar_dense",
core_ops.aten__local_scalar_dense,
),
TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True),
TorchLibOpInfo(
"ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True
).xfail(
reason="PyTorch does not implement _softmax for float16 on CPU",
dtypes=(torch.float16,),
),
TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
matcher=lambda sample: not (len(sample.kwargs) > 0),
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
@@ -1195,7 +1202,7 @@ def _where_input_wrangler(
TorchLibOpInfo("sinh", core_ops.aten_sinh),
TorchLibOpInfo(
"softmax",
special_ops.aten_special_softmax,
core_ops.aten_softmax,
tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)},
).xfail(
variant_name="with_dtype",
@@ -1937,6 +1944,7 @@ def _where_input_wrangler(
"nn.functional.upsample_nearest3d",
),
)
ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction"))