Skip to content

Commit 87226d2

Browse files
authored
feat(atenlib): add ops(squeeze) (#460)
1 parent 925e81c commit 87226d2

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5053,10 +5053,20 @@ def aten_square(self: TensorType) -> TensorType:
50535053
raise NotImplementedError()
50545054

50555055

5056-
def aten_squeeze(self: TensorType) -> TensorType:
5056+
@torch_op("aten::squeeze", trace_only=True)
5057+
def aten_squeeze(self: TTensor, dim: Optional[int] = None) -> TTensor:
50575058
"""squeeze(Tensor(a) self) -> Tensor(a)"""
50585059

5059-
raise NotImplementedError()
5060+
if op.OptionalHasElement(dim):
5061+
rank = op.Size(op.Shape(self))
5062+
if rank == 0:
5063+
self = op.Reshape(self, op.Constant(value_ints=[-1]))
5064+
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
5065+
result = op.Squeeze(self, dims)
5066+
else:
5067+
result = op.Squeeze(self)
5068+
5069+
return result
50605070

50615071

50625072
def aten_squeeze_copy(self: TensorType) -> TensorType:

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ def _where_input_wrangler(
412412
),
413413
"ones_like": core_ops.aten_ones_like,
414414
"slice": core_ops.aten_slice,
415+
"squeeze": core_ops.aten_squeeze,
415416
"sum": (core_ops.aten_sum_dim_IntList, _sum_input_wrangler),
416417
"transpose": core_ops.aten_transpose,
417418
"zeros_like": core_ops.aten_zeros_like,
@@ -556,6 +557,13 @@ def _where_input_wrangler(
556557
matcher=lambda sample: len(sample.args[0]) == 0,
557558
reason="Empty perm is not supported",
558559
),
560+
skip(
561+
"squeeze",
562+
matcher=lambda sample: len(sample.args) > 0
563+
and len(sample.input.shape) > 0
564+
and sample.input.shape[sample.args[0]] != 1,
565+
reason="Cannot select an axis to squeeze out which has size not equal to one",
566+
),
559567
)
560568

561569
duplicate_opinfo(

0 commit comments

Comments
 (0)