diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index f430d77d48..afdeb56c12 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -67,6 +67,29 @@ def aten__local_scalar_dense_int(self: IntType) -> INT64:
     return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=INT64.dtype)
 
 
+@torch_op("aten::_softmax", trace_only=True)
+def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT:
+    """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
+
+    # trace_only because we need to cast conditionally based on half_to_float
+    if half_to_float:
+        self = op.Cast(self, to=FLOAT.dtype)
+
+    return aten_softmax_no_dtype(self, dim)
+
+
+@torch_op("aten::_softmax", trace_only=True)
+def aten__softmax(
+    self: TFloatHighPrecision, dim: int, half_to_float: bool
+) -> TFloatHighPrecision:
+    """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
+
+    # trace_only to reuse aten_softmax_no_dtype
+
+    del half_to_float  # Unused
+    return aten_softmax_no_dtype(self, dim)
+
+
 @torch_op("aten::abs")
 def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
     """abs(Tensor self) -> Tensor"""
@@ -6336,6 +6359,39 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
+def aten_softmax(
+    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
+) -> TFloatOrBFloat16:
+    """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
+
+    self_is_scalar = op.Size(op.Shape(self)) == 0
+    if self_is_scalar:
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    result = op.Softmax(self, axis=dim)
+    result = op.Cast(result, to=dtype)
+    if self_is_scalar:
+        # Convert to scalar when input is scalar
+        result = op.Squeeze(result)
+
+    return result
+
+
+@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
+def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
+    """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
+
+    self_is_scalar = op.Size(op.Shape(self)) == 0
+    if self_is_scalar:
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    result = op.Softmax(self, axis=dim)
+    if self_is_scalar:
+        # Convert to scalar when input is scalar
+        result = op.Squeeze(result)
+
+    return result
+
+
 def aten_sort(
     self: TensorType, dim: int = -1, descending: bool = False
 ) -> tuple[TensorType, TensorType]:
diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py
index b7f70d2a99..5872fc3d17 100644
--- a/onnxscript/function_libs/torch_lib/ops/special.py
+++ b/onnxscript/function_libs/torch_lib/ops/special.py
@@ -341,23 +341,6 @@ def aten_special_sinc(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::softmax")
-def aten_special_softmax(
-    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
-) -> TFloatOrBFloat16:
-    """special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
-
-    self_is_scalar = op.Size(op.Shape(self)) == 0
-    if self_is_scalar:
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
-    result = op.Softmax(self, axis=dim)
-    result = op.Cast(result, to=dtype)
-    if self_is_scalar:  # squeeze to scalar due to input is scalar
-        result = op.Squeeze(result)
-
-    return result
-
-
 def aten_special_spherical_bessel_j0(x: TensorType) -> TensorType:
     """special_spherical_bessel_j0(Tensor x) -> Tensor"""
 
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index b53522e653..dd60688b16 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -780,6 +780,34 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs)
         yield opinfo_core.SampleInput(input_, args=(src, *args))
 
 
+def sample_inputs__softmax(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    **kwargs,
+):
+    del op_info  # Unused
+
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    cases = [
+        ((S,), (0,)),
+        ((S, S), (0,)),
+        ((S, S), (1,)),
+        ((S, S), (-1,)),
+        ((S, M, S), (2,)),
+        ((S, 0, 0), (-1,)),
+    ]
+
+    for (shape, dim), half_to_float in itertools.product(cases, (False,)):
+        # NOTE: softmax with half to float conversion is not supported on CPU
+        # So we don't test it here
+        kwargs = dict(half_to_float=half_to_float)
+        yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs)
+
+
 # NOTE: How to create an OpInfo:
 # 1. Create a function that generates sample inputs for the op.
 #    This function should yield SampleInputs.
@@ -966,4 +994,11 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs)
         sample_inputs_func=sample_inputs_slice_scatter,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._softmax",
+        aten_name="_softmax",
+        dtypes=common_dtype.floating_types_and_half(),
+        sample_inputs_func=sample_inputs__softmax,
+        supports_out=False,
+    ),
 ]
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index e014df392c..207b8459e6 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -458,6 +458,13 @@ def _where_input_wrangler(
         "ops.aten._local_scalar_dense",
         core_ops.aten__local_scalar_dense,
     ),
+    TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True),
+    TorchLibOpInfo(
+        "ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True
+    ).xfail(
+        reason="PyTorch does not implement _softmax for float16 on CPU",
+        dtypes=(torch.float16,),
+    ),
     TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
         matcher=lambda sample: not (len(sample.kwargs) > 0),
         reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
@@ -1195,7 +1202,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("sinh", core_ops.aten_sinh),
     TorchLibOpInfo(
         "softmax",
-        special_ops.aten_special_softmax,
+        core_ops.aten_softmax,
        tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)},
     ).xfail(
         variant_name="with_dtype",
@@ -1937,6 +1944,7 @@ def _where_input_wrangler(
         "nn.functional.upsample_nearest3d",
     ),
 )
+ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
 ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
 ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction"))