@@ -503,12 +503,74 @@ def aten_argwhere(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::as_strided", trace_only=True)
 def aten_as_strided(
-    self: TensorType, size: INT64, stride: INT64, storage_offset: Optional[INT64] = None
-) -> TensorType:
+    self: TTensor, size: INT64, stride: INT64, storage_offset: int = 0
+) -> TTensor:
     """as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"""
 
-    raise NotImplementedError()
+    rank = len(stride)
+    return _aten_as_strided_onnx(self, size, stride, storage_offset, rank)
+
+
+@torch_op("aten::as_strided", private=True)
+def _aten_as_strided_onnx(
+    self: TTensor, size: INT64, stride: INT64, storage_offset: int = 0, rank: int = 0
+) -> TTensor:
+    # e.g. when size=[2,3,4], stride=[2,1,3], indices=[0]
+    # i = 0
+    # indices=[0], add_value=[0,3,6,9]
+    # expand(shape=[4]) to [0,0,0,0]
+    # then + add_value = [0,3,6,9]
+    # i = 1
+    # indices=[0,3,6,9], add_value=[0,1,2]
+    # expand(shape=[3,4]) to [[0,3,6,9],[0,3,6,9],[0,3,6,9]]
+    # indices + add_value = [[0,3,6,9],[1,4,7,10],[2,5,8,11]]
+    # i = 2
+    # indices=[[0,3,6,9],[1,4,7,10],[2,5,8,11]], add_value=[0,2]
+    # expand(shape=[2,3,4]) to [[[0,3,6,9],[1,4,7,10],[2,5,8,11]],[[0,3,6,9],[1,4,7,10],[2,5,8,11]]]
+    # indices + add_value = [[[0,3,6,9],[1,4,7,10],[2,5,8,11]],[[2,5,8,11],[3,6,9,12],[4,7,10,13]]]
+    neg_1 = op.Constant(value_ints=[-1])
+    rank_tensor = op.Reshape(rank, neg_1)  # should be 3
+    # The final indices for op.Gather(data, indices); updated on each loop iteration
+    indices = op.Constant(value_int=0)
+    one_seq = op.SequenceEmpty()
+    for i in range(rank):
+        # Walk the dims from back to front: j should be 2,1,0 when i=0,1,2
+        j = rank - i - 1
+        j_tensor = op.Reshape(j, neg_1)
+        # Get the size of dim j, should be 4,3,2 when i=0,1,2
+        size_dim_j = op.Gather(size, j_tensor, axis=0)
+        # Get the sizes from dim j onward, should be [4],[3,4],[2,3,4] when i=0,1,2
+        size_after_j = op.Slice(size, j_tensor, rank_tensor)
+        # Get the stride of dim j, should be 3,1,2 when i=0,1,2
+        stride_dim_j = op.Gather(stride, j_tensor, axis=0)
+        indices = op.Expand(indices, size_after_j)
+        # When size[j]=4, stride[j]=3, then add_value = [0,1,2,3] * 3 = [0,3,6,9]
+        # When size[j]=3, stride[j]=1, then add_value = [0,1,2] * 1 = [0,1,2]
+        # When size[j]=2, stride[j]=2, then add_value = [0,1] * 2 = [0,2]
+        add_value = op.Range(0, size_dim_j, 1) * stride_dim_j
+        # Compute the shape of add_value for correct broadcasting
+        if i == 0:
+            # shape = [dim_size]
+            shape = size_dim_j
+        else:
+            # shape = [dim_size, 1, 1, ...], with as many 1s as i
+            ones = op.ConcatFromSequence(one_seq, axis=0)
+            shape = op.Concat(op.Cast(size_dim_j, to=FLOAT.dtype), ones, axis=0)
+            shape = op.Cast(shape, to=INT64.dtype)
+
+        add_value = op.Reshape(add_value, shape)
+        # Broadcast add_value onto indices according to the size and stride values
+        indices = indices + add_value
+        # Accumulate trailing 1-dims for reshaping add_value in later iterations: [1],[1,1],[1,1,1] when i=0,1,2
+        one_seq = op.SequenceInsert(one_seq, op.Constant(value_floats=[1.0]))
+
+    self_flatten = op.Reshape(self, op.Constant(value_ints=[-1]))
+    indices = op.Add(indices, storage_offset)
+    result = op.Gather(self_flatten, indices)
+
+    return result
 
 
 def aten_as_strided_copy(
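For reviewers, here is a minimal standalone sketch (not part of this diff) that mirrors the index arithmetic `_aten_as_strided_onnx` builds up, using the same example values from its comments (size=[2,3,4], stride=[2,1,3]) and an assumed storage_offset of 0. The NumPy/PyTorch names are illustrative only; it checks the resulting gather indices against `torch.as_strided`:

```python
import numpy as np
import torch

size, stride, storage_offset = [2, 3, 4], [2, 1, 3], 0
rank = len(stride)

# Build the flat gather indices the same way the ONNX loop does:
# walk the dims from back to front and broadcast as we go.
indices = np.array(0)
for i in range(rank):
    j = rank - i - 1
    indices = np.broadcast_to(indices, size[j:])        # mirrors op.Expand
    add_value = np.arange(size[j]) * stride[j]          # mirrors op.Range * stride
    add_value = add_value.reshape([size[j]] + [1] * i)  # trailing 1s for broadcasting
    indices = indices + add_value

data = torch.arange(24)
expected = torch.as_strided(data, size, stride, storage_offset)
actual = data.flatten()[torch.from_numpy(indices + storage_offset)]
assert torch.equal(actual, expected)
```

The key point is that the flat index for output position (i0, i1, i2) is storage_offset + i0*stride[0] + i1*stride[1] + i2*stride[2], which is exactly what the accumulated `indices` tensor holds before the final Gather.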