From d2817bbbb7e81471fd4cd3bb2aaf1ee5ba4da883 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 25 Aug 2023 05:36:44 +0000
Subject: [PATCH 1/5] Implement `aten::_softmax` | feat(torchlib)

---
 .../function_libs/torch_lib/ops/core.py      | 54 +++++++++++++++++++
 .../function_libs/torch_lib/ops/special.py   | 17 ------
 .../function_libs/torch_lib/ops_test_data.py |  2 +-
 3 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 76d1de9a1a..d782066208 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -67,6 +67,29 @@ def aten__local_scalar_dense_int(self: IntType) -> INT64:
     return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=INT64.dtype)
 
 
+@torch_op("aten::_softmax", trace_only=True)
+def aten__softmax(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT:
+    """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
+
+    # trace_only because we need to cast conditionally based on half_to_float
+    if half_to_float:
+        self = op.Cast(self, to=FLOAT.dtype)
+
+    return aten_softmax_no_dtype(self, dim)
+
+
+@torch_op("aten::_softmax", trace_only=True)
+def aten__softmax(
+    self: TFloatHighPrecision, dim: int, half_to_float: bool
+) -> TFloatHighPrecision:
+    """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
+
+    # trace_only to reuse aten_softmax_no_dtype
+
+    del half_to_float  # Unused
+    return aten_softmax_no_dtype(self, dim)
+
+
 @torch_op("aten::abs")
 def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
     """abs(Tensor self) -> Tensor"""
@@ -6324,6 +6347,37 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
+def aten_softmax(
+    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
+) -> TFloatOrBFloat16:
+    """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
+
+    self_is_scalar = op.Size(op.Shape(self)) == 0
+    if self_is_scalar:
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    result = op.Softmax(self, axis=dim)
+    result = op.Cast(result, to=dtype)
+    if self_is_scalar:  # squeeze to scalar due to input is scalar
+        result = op.Squeeze(result)
+
+    return result
+
+
+@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
+def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
+    """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
+
+    self_is_scalar = op.Size(op.Shape(self)) == 0
+    if self_is_scalar:
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    result = op.Softmax(self, axis=dim)
+    if self_is_scalar:  # squeeze to scalar due to input is scalar
+        result = op.Squeeze(result)
+
+    return result
+
+
 def aten_sort(
     self: TensorType, dim: int = -1, descending: bool = False
 ) -> tuple[TensorType, TensorType]:
diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py
index b7f70d2a99..5872fc3d17 100644
--- a/onnxscript/function_libs/torch_lib/ops/special.py
+++ b/onnxscript/function_libs/torch_lib/ops/special.py
@@ -341,23 +341,6 @@ def aten_special_sinc(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::softmax")
-def aten_special_softmax(
-    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
-) -> TFloatOrBFloat16:
-    """special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
-
-    self_is_scalar = op.Size(op.Shape(self)) == 0
-    if self_is_scalar:
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
-    result = op.Softmax(self, axis=dim)
-    result = op.Cast(result, to=dtype)
-    if self_is_scalar:  # squeeze to scalar due to input is scalar
-        result = op.Squeeze(result)
-
-    return result
-
-
 def aten_special_spherical_bessel_j0(x: TensorType) -> TensorType:
     """special_spherical_bessel_j0(Tensor x) -> Tensor"""
 
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 7d8a544446..12c790f23f 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1197,7 +1197,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("sinh", core_ops.aten_sinh),
     TorchLibOpInfo(
         "softmax",
-        special_ops.aten_special_softmax,
+        special_ops.aten_softmax,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)},
     ).xfail(
         variant_name="with_dtype",

From 0bd49339f986535801fcb019f8e015876df24efc Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 24 Aug 2023 22:40:33 -0700
Subject: [PATCH 2/5] Update ops_test_data.py

---
 onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 12c790f23f..5324c28633 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1197,7 +1197,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("sinh", core_ops.aten_sinh),
     TorchLibOpInfo(
         "softmax",
-        special_ops.aten_softmax,
+        core_ops.aten_softmax,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)},
     ).xfail(
         variant_name="with_dtype",

From fe55ce9e6de2d2a10af563b68c50793a02ab7d11 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 25 Aug 2023 21:51:19 +0000
Subject: [PATCH 3/5] aten__softmax_half

---
 onnxscript/function_libs/torch_lib/ops/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index d782066208..63291bc675 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -68,7 +68,7 @@ def aten__local_scalar_dense_int(self: IntType) -> INT64:
 
 
 @torch_op("aten::_softmax", trace_only=True)
-def aten__softmax(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT:
+def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT:
     """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
 
     # trace_only because we need to cast conditionally based on half_to_float

From 215cd71ffe3b4a3e3bfae5919257d1b2920a10fe Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 28 Aug 2023 19:50:09 +0000
Subject: [PATCH 4/5] Add tests

---
 .../function_libs/torch_lib/extra_opinfo.py  | 35 +++++++++++++++++++
 .../function_libs/torch_lib/ops_test_data.py |  8 +++++
 2 files changed, 43 insertions(+)

diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index b53522e653..dd60688b16 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -780,6 +780,34 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs)
         yield opinfo_core.SampleInput(input_, args=(src, *args))
 
 
+def sample_inputs__softmax(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    **kwargs,
+):
+    del op_info  # Unused
+
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    cases = [
+        ((S,), (0,)),
+        ((S, S), (0,)),
+        ((S, S), (1,)),
+        ((S, S), (-1,)),
+        ((S, M, S), (2,)),
+        ((S, 0, 0), (-1,)),
+    ]
+
+    for (shape, dim), half_to_float in itertools.product(cases, (False,)):
+        # NOTE: softmax with half to float conversion is not supported on CPU
+        # So we don't test it here
+        kwargs = dict(half_to_float=half_to_float)
+        yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs)
+
+
 # NOTE: How to create an OpInfo:
 # 1. Create a function that generates sample inputs for the op.
 #    This function should yield SampleInputs.
@@ -966,4 +994,11 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs)
         sample_inputs_func=sample_inputs_slice_scatter,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._softmax",
+        aten_name="_softmax",
+        dtypes=common_dtype.floating_types_and_half(),
+        sample_inputs_func=sample_inputs__softmax,
+        supports_out=False,
+    ),
 ]
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 5324c28633..7a5aa68c47 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -458,6 +458,13 @@ def _where_input_wrangler(
         "ops.aten._local_scalar_dense",
         core_ops.aten__local_scalar_dense,
     ),
+    TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True),
+    TorchLibOpInfo(
+        "ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True
+    ).xfail(
+        reason="PyTorch does not implement _softmax for float16 on CPU",
+        dtypes=(torch.float16,),
+    ),
     TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
         matcher=lambda sample: not (len(sample.kwargs) > 0),
         reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
@@ -1939,6 +1946,7 @@ def _where_input_wrangler(
         "nn.functional.upsample_nearest3d",
     ),
 )
+ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
 ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction"))
 ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))

From a1368e461077821e8877432d44f55ee99bdfba81 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 28 Aug 2023 19:55:06 +0000
Subject: [PATCH 5/5] Wording

---
 onnxscript/function_libs/torch_lib/ops/core.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 7e75ffad24..afdeb56c12 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -6370,7 +6370,8 @@ def aten_softmax(
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
     result = op.Softmax(self, axis=dim)
     result = op.Cast(result, to=dtype)
-    if self_is_scalar:  # squeeze to scalar due to input is scalar
+    if self_is_scalar:
+        # Convert to scalar when input is scalar
         result = op.Squeeze(result)
 
     return result
@@ -6384,7 +6385,8 @@ def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
     result = op.Softmax(self, axis=dim)
-    if self_is_scalar:  # squeeze to scalar due to input is scalar
+    if self_is_scalar:
+        # Convert to scalar when input is scalar
         result = op.Squeeze(result)
 
     return result
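
Note (not part of the patch series): a minimal sketch of the relationship the new
"ops.aten._softmax" OpInfo exercises, namely that the private op matches the public
softmax when half_to_float is False. This assumes a PyTorch build that exposes the
op through torch.ops.aten; the tensor values below are illustrative only.

    import torch

    x = torch.randn(3, 4)  # float32 input, so half_to_float must stay False on CPU
    expected = torch.softmax(x, dim=1)
    actual = torch.ops.aten._softmax(x, 1, False)  # signature: (self, dim, half_to_float)
    torch.testing.assert_close(actual, expected)

    # A 0-d input is the case handled by the Unsqueeze/Squeeze branch in
    # aten_softmax_no_dtype: softmax of a single element is 1.
    scalar = torch.tensor(1.5)
    print(torch.ops.aten._softmax(scalar, 0, False))  # tensor(1.)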