Skip to content

Commit a6350f1

Browse files
authored
Revert "feat(atenlib): add ops(squeeze) (#460)"
This reverts commit 87226d2.
1 parent fb9c3a5 commit a6350f1

File tree

2 files changed

+2
-20
lines changed

2 files changed

+2
-20
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5053,20 +5053,10 @@ def aten_square(self: TensorType) -> TensorType:
50535053
raise NotImplementedError()
50545054

50555055

5056-
@torch_op("aten::squeeze", trace_only=True)
5057-
def aten_squeeze(self: TTensor, dim: Optional[int] = None) -> TTensor:
5056+
def aten_squeeze(self: TensorType) -> TensorType:
50585057
"""squeeze(Tensor(a) self) -> Tensor(a)"""
50595058

5060-
if op.OptionalHasElement(dim):
5061-
rank = op.Size(op.Shape(self))
5062-
if rank == 0:
5063-
self = op.Reshape(self, op.Constant(value_ints=[-1]))
5064-
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
5065-
result = op.Squeeze(self, dims)
5066-
else:
5067-
result = op.Squeeze(self)
5068-
5069-
return result
5059+
raise NotImplementedError()
50705060

50715061

50725062
def aten_squeeze_copy(self: TensorType) -> TensorType:

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ def _where_input_wrangler(
412412
),
413413
"ones_like": core_ops.aten_ones_like,
414414
"slice": core_ops.aten_slice,
415-
"squeeze": core_ops.aten_squeeze,
416415
"sum": (core_ops.aten_sum_dim_IntList, _sum_input_wrangler),
417416
"transpose": core_ops.aten_transpose,
418417
"zeros_like": core_ops.aten_zeros_like,
@@ -557,13 +556,6 @@ def _where_input_wrangler(
557556
matcher=lambda sample: len(sample.args[0]) == 0,
558557
reason="Empty perm is not supported",
559558
),
560-
skip(
561-
"squeeze",
562-
matcher=lambda sample: len(sample.args) > 0
563-
and len(sample.input.shape) > 0
564-
and sample.input.shape[sample.args[0]] != 1,
565-
reason="Cannot select an axis to squeeze out which has size not equal to one",
566-
),
567559
)
568560

569561
duplicate_opinfo(

0 commit comments

Comments
 (0)