@@ -668,66 +668,76 @@ def aten_max_pool1d_with_indices(
    raise NotImplementedError()


-@torch_op("aten::max_pool2d", trace_only=True)
-def aten_max_pool2d(
-    self: TFloatOrUInt8,
+def _adjust_attributes_of_max_pool(
+    expand_size: int,
    kernel_size: Sequence[int],
-    stride: Sequence[int] = (),
-    padding: Sequence[int] = (0, 0),
-    dilation: Sequence[int] = (1, 1),
-    ceil_mode: bool = False,
-) -> TFloatOrUInt8:
-    """max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""
-
-    # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
-    # But ONNX needs pair number [x,y] to specify on each side explicitly
-    # For pool3d, this number should be 3
-    expand_size = 2
-
-    # The dilations should be [x, y]
-    if isinstance(dilation, int):  # x -> [x, x]
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
+    if isinstance(dilation, int):
        dilations = [dilation] * expand_size
-    else:  # already [x, y]
+    else:
        dilations = dilation

-    # The kernel_shape should be [x, y]
-    if isinstance(kernel_size, int):  # x -> [x, x]
+    if isinstance(kernel_size, int):
        kernel_shape = [kernel_size] * expand_size
-    else:  # assert(len(kernel_size)==2), already [x, y]
+    else:
        kernel_shape = kernel_size

-    # The pads should be [w, x, y, z]
-    if isinstance(padding, int):  # w -> [w, w, w, w]
+    if isinstance(padding, int):
        pads = [padding] * expand_size * 2
-    elif len(padding) == 1:  # [w] -> [w, w, w, w]
-        pads = padding * 4
-    elif len(padding) == 2:  # [w, x] -> [w, x, w, x]
-        pads = padding * 2
-    else:  # assert len(padding) == 4, already [w, x, y, z]
+    elif len(padding) == 1:
+        pads = padding * expand_size * 2
+    elif len(padding) == 2:
+        pads = padding * expand_size
+    else:
        pads = padding

-    # The strides should be [x, y]
-    if isinstance(stride, int):  # x -> [x, x]
+    if isinstance(stride, int):
        strides = [stride] * expand_size
    elif stride is None:
        strides = kernel_shape
    else:
        strides = stride

-    return _aten_max_pool2d_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode)
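+    # e.g. with expand_size=2: padding=2 -> pads=[2, 2, 2, 2], padding=[1] -> [1, 1, 1, 1],
+    # and padding=[1, 2] -> [1, 2, 1, 2] (ONNX pads are ordered begin..., end...).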
+    return (kernel_shape, strides, pads, dilations)
+
+
+@torch_op("aten::max_pool2d", trace_only=True)
+def aten_max_pool2d(
+    self: TFloatOrUInt8,
+    kernel_size: Sequence[int],
+    stride: Sequence[int] = (),
+    padding: Sequence[int] = (0, 0),
+    dilation: Sequence[int] = (1, 1),
+    ceil_mode: bool = False,
+) -> TFloatOrUInt8:
+    """max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""
+
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs a pair of numbers [x, y] to specify each side explicitly.
+    expand_size = 2
+
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
+    )
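+    # A rank-3 input is an unbatched (C, H, W) image; _aten_max_pool_onnx uses this
+    # rank to decide when to temporarily insert a batch dimension.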

+    return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 3)

-@torch_op("aten::max_pool2d", private=True)
-def _aten_max_pool2d_onnx(
+
+@torch_op("internal::max_pool", private=True)
+def _aten_max_pool_onnx(
    self: TFloatOrUInt8,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
+    unbatched_rank: int,
) -> TFloatOrUInt8:
    self_rank = op.Size(op.Shape(self))
-    if self_rank == 3:  # C,H,W -> N,C,H,W and N=1
+    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))

    pool_result, _ = op.MaxPool(
@@ -739,122 +749,65 @@ def _aten_max_pool2d_onnx(
        strides=strides,
    )

-    if self_rank == 3:
+    if self_rank == unbatched_rank:
        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))

    return pool_result


-@torch_op("aten::max_pool2d_with_indices", trace_only=True)
-def aten_max_pool2d_with_indices(
+@torch_op("aten::max_pool3d", trace_only=True)
+def aten_max_pool3d(
    self: TFloatOrUInt8,
    kernel_size: Sequence[int],
    stride: Sequence[int] = (),
    padding: Sequence[int] = (0, 0),
    dilation: Sequence[int] = (1, 1),
    ceil_mode: bool = False,
-) -> Tuple[TFloatOrUInt8, INT64]:
-    """max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""
-
-    # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
-    # But ONNX needs pair number [x,y] to specify on each side explicitly
-    # For pool3d, this number should be 3
-    expand_size = 2
-
-    # The dilations should be [x, y]
-    if isinstance(dilation, int):  # x -> [x, x]
-        dilations = [dilation] * expand_size
-    else:  # already [x, y]
-        dilations = dilation
-
-    # The kernel_shape should be [x, y]
-    if isinstance(kernel_size, int):  # x -> [x, x]
-        kernel_shape = [kernel_size] * expand_size
-    else:  # assert(len(kernel_size)==2), already [x, y]
-        kernel_shape = kernel_size
+) -> TFloatOrUInt8:
+    """max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""

-    # The pads should be [w, x, y, z]
-    if isinstance(padding, int):  # w -> [w, w, w, w]
-        pads = [padding] * expand_size * 2
-    elif len(padding) == 1:  # [w] -> [w, w, w, w]
-        pads = padding * 4
-    elif len(padding) == 2:  # [w, x] -> [w, x, w, x]
-        pads = padding * 2
-    else:  # assert len(padding) == 4, already [w, x, y, z]
-        pads = padding
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on all sides implicitly.
+    # But ONNX needs a triple of numbers [x, y, z] to specify each side explicitly.
+    expand_size = 3

-    # The strides should be [x, y]
-    if isinstance(stride, int):  # x -> [x, x]
-        strides = [stride] * expand_size
-    elif stride is None:
-        strides = kernel_shape
-    else:
-        strides = stride
-
-    return _aten_max_pool2d_with_indices_onnx(
-        self, expand_size, kernel_shape, strides, pads, dilations, ceil_mode
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
    )
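+    # For 3-D pooling the unbatched input (C, D, H, W) has rank 4, hence the final argument.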

+    return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 4)
+

-@torch_op("aten::max_pool2d_with_indices", private=True)
-def _aten_max_pool2d_with_indices_onnx(
+@torch_op("aten::max_pool2d_with_indices", trace_only=True)
+def aten_max_pool2d_with_indices(
    self: TFloatOrUInt8,
-    expand_size: INT64,
-    kernel_shape: Sequence[int],
-    strides: Sequence[int],
-    pads: Sequence[int],
-    dilations: Sequence[int],
-    ceil_mode: bool,
+    kernel_size: Sequence[int],
+    stride: Sequence[int] = (),
+    padding: Sequence[int] = (0, 0),
+    dilation: Sequence[int] = (1, 1),
+    ceil_mode: bool = False,
) -> Tuple[TFloatOrUInt8, INT64]:
-    self_rank = op.Size(op.Shape(self))
-    if self_rank == 3:  # C,H,W -> N,C,H,W and N=1
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
-
-    pool_result, indices = op.MaxPool(
-        self,
-        ceil_mode=ceil_mode,
-        dilations=dilations,
-        kernel_shape=kernel_shape,
-        pads=pads,
-        strides=strides,
-    )
+    """max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""

-    if self_rank == 3:
-        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs a pair of numbers [x, y] to specify each side explicitly.
+    expand_size = 2

-    # Torch use relative position number for the second Channel data
-    # If align, need reduce size(Channel)
-    # e.g. [[8,3,10],[30,32,23]]-[0,18] -> [[8,3,10],[12,14,5]]
-    # 18 = H x W = 3 x 6
-    batches = op.Shape(self, start=0, end=1)
-    channels = op.Shape(self, start=1, end=2)
-    end = batches * channels
-    offset = op.Range(0, end, 1)
-    data_shape = op.Shape(self, start=2)
-    data_size = op.ReduceProd(data_shape)
-    offset = offset * data_size
-    new_shape = op.Expand(
-        op.Constant(value_ints=[1]), op.Reshape(expand_size, op.Constant(value_ints=[-1]))
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
    )
-    new_shape = op.Concat(batches, channels, new_shape, axis=0)
-    offset = op.Reshape(offset, new_shape)
-    indices = indices - offset
-    if self_rank == 3:
-        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
-    return pool_result, indices
-

-def aten_max_pool3d(
-    self: TensorType,
-    kernel_size: Sequence[int],
-    stride: Optional[Sequence[int]] = None,
-    padding: Sequence[int] = (0, 0, 0),
-    dilation: Sequence[int] = (1, 1, 1),
-    ceil_mode: bool = False,
-) -> TensorType:
-    """max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""
-
-    raise NotImplementedError()
+    return _aten_max_pool_with_indices_onnx(
+        self,
+        kernel_shape,
+        strides,
+        pads,
+        dilations,
+        ceil_mode,
+        3,
+        ([1] * expand_size),
+        ([0] * expand_size),
+        ([2 + i for i in range(expand_size)]),
+    )


def aten_max_pool2d_with_indices_backward(
@@ -872,17 +825,113 @@ def aten_max_pool2d_with_indices_backward(
    raise NotImplementedError()


+@torch_op("aten::max_pool3d_with_indices", trace_only=True)
def aten_max_pool3d_with_indices(
-    self: TensorType,
+    self: TFloatOrUInt8,
    kernel_size: Sequence[int],
-    stride: Optional[Sequence[int]] = None,
-    padding: Sequence[int] = (0, 0, 0),
-    dilation: Sequence[int] = (1, 1, 1),
+    stride: Sequence[int] = (),
+    padding: Sequence[int] = (0, 0),
+    dilation: Sequence[int] = (1, 1),
    ceil_mode: bool = False,
-) -> tuple[TensorType, TensorType]:
+) -> Tuple[TFloatOrUInt8, INT64]:
    """max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""

-    raise NotImplementedError()
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on all sides implicitly.
+    # But ONNX needs a triple of numbers [x, y, z] to specify each side explicitly.
+    expand_size = 3
+
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
+    )
+
+    return _aten_max_pool_with_indices_onnx(
+        self,
+        kernel_shape,
+        strides,
+        pads,
+        dilations,
+        ceil_mode,
+        4,
+        ([1] * expand_size),
+        ([0] * expand_size),
+        ([2 + i for i in range(expand_size)]),
+    )
+
+
+@torch_op("internal::max_pool_with_indices", private=True)
+def _aten_max_pool_with_indices_onnx(
+    self: TFloatOrUInt8,
+    kernel_size: Sequence[int],
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+    ceil_mode: bool,
+    unbatched_rank: int,
+    n_dims_one: Sequence[int],
+    n_dims_zero: Sequence[int],
+    n_dims_axes: Sequence[int],
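+    # Callers pass n_dims_one=[1]*n, n_dims_zero=[0]*n and n_dims_axes=[2, ..., n+1]
+    # for n spatial dimensions; they parameterize the Slice over the spatial axes below.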
+) -> Tuple[TFloatOrUInt8, INT64]:
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == unbatched_rank:
+        self = op.Unsqueeze(self, axes=0)
+
+    pool_result, indices = op.MaxPool(
+        self,
+        ceil_mode=ceil_mode,
+        dilations=dilation,
+        kernel_shape=kernel_size,
+        pads=padding,
+        strides=stride,
+    )
+
+    # Simple but hacky way to get flattened indices values
+    # to be used to convert the indices values to non-flattened.
+    # In ONNX the indices are computed as a flattened 1-D tensor,
+    # so the values in indices are in [0, N x C x D1 x ... x Dn).
+    # To convert the indices to the same format used by PyTorch,
+    # we first execute a maxpool with a kernel and stride of 1 on the same input.
+    # This results in a tensor of indices in which each index has its own value.
+    # Using this tensor as a reference, we extract the first index of each axis and subtract
+    # it from each index of this axis in the indices to convert.
+    # This step results in a tensor where each dimension holds indices local to
+    # the dimension it is in.
+    # For MaxPool1d(kernel=1, stride=1, return_indices=True) with the input torch.ones(1, 2, 2),
+    # the computed indices are the following:
+    # output indices pytorch:
+    #     [[0, 1],
+    #      [0, 1]]
+    # output indices onnx:
+    #     [[0, 1],
+    #      [2, 3]]
+    # The goal is to convert the indices from one format to the other so the results match.
+    # So flatten_indices holds the value of each index and is equal to:
+    #     [[0, 1],
+    #      [2, 3]]
+    # Then Slice is called to get the first value of each line (so 0 and 2),
+    # and the subtraction executes:
+    #     [[0-0, 1-0],
+    #      [2-2, 3-2]]
+    # So indices results in the expected output, which is:
+    #     [[0, 1],
+    #      [0, 1]]
+    # For more information:
+    # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+    _, flatten_indices = op.MaxPool(
+        self, dilations=dilation, kernel_shape=n_dims_one, strides=n_dims_one
+    )
+
+    ends = op.Constant(value_ints=n_dims_one)
+    starts = op.Constant(value_ints=n_dims_zero)
+    axes = op.Constant(value_ints=n_dims_axes)
+
+    delta = op.Slice(flatten_indices, axes=axes, starts=starts, ends=ends)
+    indices = op.Sub(indices, delta)
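+    # delta has shape [N, C, 1, ..., 1] (one per spatial axis), so the subtraction above
+    # broadcasts the per-channel first index across every spatial position.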
+
+    if self_rank == unbatched_rank:
+        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
+
+    return (pool_result, indices)


def aten_max_pool3d_with_indices_backward(