File tree Expand file tree Collapse file tree 2 files changed +20
-2
lines changed
function_libs/torch_aten/ops
tests/function_libs/torch_aten Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Original file line number Diff line number Diff line change @@ -5053,10 +5053,20 @@ def aten_square(self: TensorType) -> TensorType:
50535053 raise NotImplementedError ()
50545054
50555055
5056- def aten_squeeze (self : TensorType ) -> TensorType :
5056+ @torch_op ("aten::squeeze" , trace_only = True )
5057+ def aten_squeeze (self : TTensor , dim : Optional [int ] = None ) -> TTensor :
50575058 """squeeze(Tensor(a) self) -> Tensor(a)"""
50585059
5059- raise NotImplementedError ()
5060+ if op .OptionalHasElement (dim ):
5061+ rank = op .Size (op .Shape (self ))
5062+ if rank == 0 :
5063+ self = op .Reshape (self , op .Constant (value_ints = [- 1 ]))
5064+ dims = op .Reshape (dim , op .Constant (value_ints = [- 1 ]))
5065+ result = op .Squeeze (self , dims )
5066+ else :
5067+ result = op .Squeeze (self )
5068+
5069+ return result
50605070
50615071
50625072def aten_squeeze_copy (self : TensorType ) -> TensorType :
Original file line number Diff line number Diff line change @@ -412,6 +412,7 @@ def _where_input_wrangler(
412412 ),
413413 "ones_like" : core_ops .aten_ones_like ,
414414 "slice" : core_ops .aten_slice ,
415+ "squeeze" : core_ops .aten_squeeze ,
415416 "sum" : (core_ops .aten_sum_dim_IntList , _sum_input_wrangler ),
416417 "transpose" : core_ops .aten_transpose ,
417418 "zeros_like" : core_ops .aten_zeros_like ,
@@ -556,6 +557,13 @@ def _where_input_wrangler(
556557 matcher = lambda sample : len (sample .args [0 ]) == 0 ,
557558 reason = "Empty perm is not supported" ,
558559 ),
560+ skip (
561+ "squeeze" ,
562+ matcher = lambda sample : len (sample .args ) > 0
563+ and len (sample .input .shape ) > 0
564+ and sample .input .shape [sample .args [0 ]] != 1 ,
565+ reason = "Cannot select an axis to squeeze out which has size not equal to one" ,
566+ ),
559567)
560568
561569duplicate_opinfo (
You can’t perform that action at this time.
0 commit comments