@@ -6718,10 +6718,47 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
67186718 return op .Reshape (self , final_shape )
67196719
67206720
6721- def aten_unfold (self : TensorType , dimension : int , size : int , step : int ) -> TensorType :
6721+ @torch_op ("aten::unfold" , trace_only = True )
6722+ def aten_unfold (self : TTensor , dimension : int , size : int , step : int ) -> TTensor :
67226723 """unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)"""
67236724
6724- raise NotImplementedError ()
6725+ self_rank = len (self .shape )
6726+ if self_rank == 0 :
6727+ result = op .Unsqueeze (self , 0 )
6728+ else :
6729+ dim_size = self .shape [dimension ]
6730+ target_end = (dim_size - size ) // step + 1
6731+ if target_end > 1 : # the rank of final reuslt will be self_rank + 1
6732+ self_rank = self_rank + 1
6733+ # perm need to be list[int], so have to be generated in trace_only mode
6734+ perm = list (range (self_rank ))
6735+ # from [0,1,2,3,4] -> [0,1,3,4,2] when dimension=1
6736+ perm .append (perm .pop (dimension + 1 ))
6737+ result = _aten_unfold_onnx (self , dimension , size , step , target_end , perm )
6738+ return result
6739+
6740+
6741+ @torch_op ("aten::unfold" , private = True )
6742+ def _aten_unfold_onnx (
6743+ self : TTensor , dim : int , size : int , step : int , target_end : int , perm : Sequence [int ]
6744+ ) -> TTensor :
6745+ dims = op .Reshape (op .Constant (value_int = dim ), op .Constant (value_ints = [- 1 ]))
6746+ # FIXME: the dtype for this function cannot work, default to float
6747+ seq_result = op .SequenceEmpty ()
6748+ i = op .Constant (value_ints = [0 ])
6749+ cond = i < target_end
6750+ while cond : # because for loop cannot work here, so use while loop
6751+ starts = i * step # starts is [0, step, step*2, step*3, ...]
6752+ ends = starts + size # ends is [0+size, step+size, step*2+size, step*3+size, ...]
6753+ slice_result = op .Slice (self , starts , ends , dims )
6754+ # sequence only support float32
6755+ slice_result_float32 = op .Cast (slice_result , to = FLOAT .dtype )
6756+ seq_result = op .SequenceInsert (seq_result , slice_result_float32 )
6757+ i = i + 1
6758+ cond = i < target_end
6759+ concat_result = op .ConcatFromSequence (seq_result , axis = dim , new_axis = 1 )
6760+ result = op .Transpose (concat_result , perm = perm )
6761+ return op .CastLike (result , self )
67256762
67266763
67276764def aten_unfold_backward (
0 commit comments