Skip to content

Commit 7d43f92

Browse files
authored
[torchlib] Simplify squeeze (#2047)
It was too complicated
1 parent 9245ea2 commit 7d43f92

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7861,25 +7861,18 @@ def aten_square(self: TensorType) -> TensorType:
78617861
raise NotImplementedError()
78627862

78637863

7864-
@torch_op("aten::squeeze")
7864+
@torch_op("aten::squeeze", trace_only=True)
78657865
def aten_squeeze(self: TTensor) -> TTensor:
78667866
"""squeeze(Tensor(a) self) -> Tensor(a)"""
78677867

78687868
return op.Squeeze(self)
78697869

78707870

7871-
@torch_op("aten::squeeze.dim")
7871+
@torch_op("aten::squeeze.dim", trace_only=True)
78727872
def aten_squeeze_dim(self: TTensor, dim: int) -> TTensor:
7873-
result = self
7874-
if Rank(self) > 0: # type: ignore[operator]
7875-
# check if specified dimension is 1, do squeeze
7876-
shape = op.Shape(self)
7877-
dim_size = op.Gather(shape, dim, axis=0)
7878-
if dim_size == 1:
7879-
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
7880-
result = op.Squeeze(self, dims)
7881-
7882-
return result
7873+
if len(self.shape) == 0:
7874+
return self
7875+
return op.Squeeze(self, [dim])
78837876

78847877

78857878
@torch_op("aten::squeeze.dim", complex=True, trace_only=True)
@@ -7888,6 +7881,9 @@ def aten_squeeze_dim_complex(self: TTensor, dim: int) -> TTensor:
78887881
# Account for the complex dimension in ONNX
78897882
dim = dim - 1
78907883

7884+
if len(self.shape) == 1:
7885+
# The single dimension is the complex dimension
7886+
return self
78917887
return aten_squeeze_dim(self, dim)
78927888

78937889

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,17 +1443,29 @@ def _where_input_wrangler(
14431443
TorchLibOpInfo(
14441444
"squeeze_dim",
14451445
core_ops.aten_squeeze_dim,
1446-
).skip(
1446+
)
1447+
.skip(
14471448
matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)),
14481449
reason="this Aten overload only support one tensor as input and one int as args by design",
1450+
)
1451+
.skip(
1452+
matcher=lambda sample: len(sample.input.shape) != 0
1453+
and sample.input.shape[sample.args[0]] != 1,
1454+
reason="this Aten overload only support squeeze dim with size 1",
14491455
),
14501456
TorchLibOpInfo(
14511457
"squeeze_dim",
14521458
core_ops.aten_squeeze_dim_complex,
14531459
complex=True,
1454-
).skip(
1460+
)
1461+
.skip(
14551462
matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)),
14561463
reason="this Aten overload only support one tensor as input and one int as args by design",
1464+
)
1465+
.skip(
1466+
matcher=lambda sample: len(sample.input.shape) != 0
1467+
and sample.input.shape[sample.args[0]] != 1,
1468+
reason="this Aten overload only support squeeze dim with size 1",
14571469
),
14581470
TorchLibOpInfo(
14591471
"squeeze",

0 commit comments

Comments
 (0)