diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index c3892c6cd3..14366fc825 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7861,25 +7861,18 @@ def aten_square(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::squeeze") +@torch_op("aten::squeeze", trace_only=True) def aten_squeeze(self: TTensor) -> TTensor: """squeeze(Tensor(a) self) -> Tensor(a)""" return op.Squeeze(self) -@torch_op("aten::squeeze.dim") +@torch_op("aten::squeeze.dim", trace_only=True) def aten_squeeze_dim(self: TTensor, dim: int) -> TTensor: - result = self - if Rank(self) > 0: # type: ignore[operator] - # check if specified dimension is 1, do squeeze - shape = op.Shape(self) - dim_size = op.Gather(shape, dim, axis=0) - if dim_size == 1: - dims = op.Reshape(dim, op.Constant(value_ints=[-1])) - result = op.Squeeze(self, dims) - - return result + if len(self.shape) == 0: + return self + return op.Squeeze(self, [dim]) @torch_op("aten::squeeze.dim", complex=True, trace_only=True) @@ -7888,6 +7881,9 @@ def aten_squeeze_dim_complex(self: TTensor, dim: int) -> TTensor: # Account for the complex dimension in ONNX dim = dim - 1 + if len(self.shape) == 1: + # The single dimension is the complex dimension + return self return aten_squeeze_dim(self, dim) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 1399264546..78d09c5f3c 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1443,17 +1443,29 @@ def _where_input_wrangler( TorchLibOpInfo( "squeeze_dim", core_ops.aten_squeeze_dim, - ).skip( + ) + .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", + ) + .skip( + matcher=lambda sample: len(sample.input.shape) != 0 + and sample.input.shape[sample.args[0]] != 1, + reason="this Aten overload only support squeeze dim with size 1", ), TorchLibOpInfo( "squeeze_dim", core_ops.aten_squeeze_dim_complex, complex=True, - ).skip( + ) + .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", + ) + .skip( + matcher=lambda sample: len(sample.input.shape) != 0 + and sample.input.shape[sample.args[0]] != 1, + reason="this Aten overload only support squeeze dim with size 1", ), TorchLibOpInfo( "squeeze",