
Commit 2019773

Merge branch 'main' into justinchu/release
2 parents 1c47ff7 + 5eafe2a commit 2019773

File tree

3 files changed: +243 −147 lines changed

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 183 additions & 134 deletions
@@ -668,66 +668,76 @@ def aten_max_pool1d_with_indices(
     raise NotImplementedError()
 
 
-@torch_op("aten::max_pool2d", trace_only=True)
-def aten_max_pool2d(
-    self: TFloatOrUInt8,
+def _adjust_attributes_of_max_pool(
+    expand_size: int,
     kernel_size: Sequence[int],
-    stride: Sequence[int] = (),
-    padding: Sequence[int] = (0, 0),
-    dilation: Sequence[int] = (1, 1),
-    ceil_mode: bool = False,
-) -> TFloatOrUInt8:
-    """max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""
-
-    # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
-    # But ONNX needs pair number [x,y] to specify on each side explicitly
-    # For pool3d, this number should be 3
-    expand_size = 2
-
-    # The dilations should be [x, y]
-    if isinstance(dilation, int):  # x -> [x, x]
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
+    if isinstance(dilation, int):
         dilations = [dilation] * expand_size
-    else:  # already [x, y]
+    else:
         dilations = dilation
 
-    # The kernel_shape should be [x, y]
-    if isinstance(kernel_size, int):  # x -> [x, x]
+    if isinstance(kernel_size, int):
         kernel_shape = [kernel_size] * expand_size
-    else:  # assert(len(kernel_size)==2), already [x, y]
+    else:
         kernel_shape = kernel_size
 
-    # The pads should be [w, x, y, z]
-    if isinstance(padding, int):  # w -> [w, w, w, w]
+    if isinstance(padding, int):
         pads = [padding] * expand_size * 2
-    elif len(padding) == 1:  # [w] -> [w, w, w, w]
-        pads = padding * 4
-    elif len(padding) == 2:  # [w, x] -> [w, x, w, x]
-        pads = padding * 2
-    else:  # assert len(padding) == 4, already [w, x, y, z]
+    elif len(padding) == 1:
+        pads = padding * expand_size * 2
+    elif len(padding) == 2:
+        pads = padding * expand_size
+    else:
         pads = padding
 
-    # The strides should be [x, y]
-    if isinstance(stride, int):  # x -> [x, x]
+    if isinstance(stride, int):
         strides = [stride] * expand_size
     elif stride is None:
         strides = kernel_shape
     else:
         strides = stride
 
-    return _aten_max_pool2d_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode)
+    return (kernel_shape, strides, pads, dilations)
+
+
+@torch_op("aten::max_pool2d", trace_only=True)
+def aten_max_pool2d(
+    self: TFloatOrUInt8,
+    kernel_size: Sequence[int],
+    stride: Sequence[int] = (),
+    padding: Sequence[int] = (0, 0),
+    dilation: Sequence[int] = (1, 1),
+    ceil_mode: bool = False,
+) -> TFloatOrUInt8:
+    """max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""
+
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs to specify a pair of numbers [x,y] on each side explicitly.
+    expand_size = 2
+
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
+    )
 
+    return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 3)
 
-@torch_op("aten::max_pool2d", private=True)
-def _aten_max_pool2d_onnx(
+
+@torch_op("internal::max_pool", private=True)
+def _aten_max_pool_onnx(
     self: TFloatOrUInt8,
     kernel_shape: Sequence[int],
     strides: Sequence[int],
     pads: Sequence[int],
     dilations: Sequence[int],
     ceil_mode: bool,
+    unbatched_rank: int,
 ) -> TFloatOrUInt8:
     self_rank = op.Size(op.Shape(self))
-    if self_rank == 3:  # C,H,W -> N,C,H,W and N=1
+    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
 
     pool_result, _ = op.MaxPool(
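
An aside on the new helper (not part of the diff itself): _adjust_attributes_of_max_pool expands PyTorch's scalar-or-tuple shorthand into the explicit per-dimension attributes ONNX MaxPool expects. Below is a standalone Python sketch of the same expansion rules, with two worked cases; expand_pool_attrs is a hypothetical name for this note, and it folds the helper's stride-is-None branch and the callers' stride=() default into one falsy check.

# Standalone sketch (illustration only) of the expansion performed by
# _adjust_attributes_of_max_pool; pads follow ONNX's begin/end layout.
def expand_pool_attrs(expand_size, kernel_size, stride, padding, dilation):
    dilations = [dilation] * expand_size if isinstance(dilation, int) else list(dilation)
    kernel_shape = [kernel_size] * expand_size if isinstance(kernel_size, int) else list(kernel_size)
    if isinstance(padding, int):
        pads = [padding] * expand_size * 2      # p -> [p, p, p, p] in the 2-D case
    elif len(padding) == 1:
        pads = list(padding) * expand_size * 2  # [p] -> [p, p, p, p]
    elif len(padding) == 2:
        pads = list(padding) * expand_size      # [ph, pw] -> [ph, pw, ph, pw]
    else:
        pads = list(padding)
    if isinstance(stride, int):
        strides = [stride] * expand_size
    elif not stride:                            # stride=() or None: default to the kernel size
        strides = kernel_shape
    else:
        strides = list(stride)
    return kernel_shape, strides, pads, dilations

# max_pool2d-style call: scalars expand to per-dimension lists.
assert expand_pool_attrs(2, 3, None, 1, 1) == ([3, 3], [3, 3], [1, 1, 1, 1], [1, 1])
# max_pool3d-style call: the same rules serve the 3-D case with expand_size=3.
assert expand_pool_attrs(3, 2, 1, 0, 1) == ([2, 2, 2], [1, 1, 1], [0, 0, 0, 0, 0, 0], [1, 1, 1])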
@@ -739,122 +749,65 @@ def _aten_max_pool2d_onnx(
         strides=strides,
     )
 
-    if self_rank == 3:
+    if self_rank == unbatched_rank:
         pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
 
     return pool_result
 
 
-@torch_op("aten::max_pool2d_with_indices", trace_only=True)
-def aten_max_pool2d_with_indices(
+@torch_op("aten::max_pool3d", trace_only=True)
+def aten_max_pool3d(
     self: TFloatOrUInt8,
     kernel_size: Sequence[int],
     stride: Sequence[int] = (),
     padding: Sequence[int] = (0, 0),
     dilation: Sequence[int] = (1, 1),
     ceil_mode: bool = False,
-) -> Tuple[TFloatOrUInt8, INT64]:
-    """max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""
-
-    # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
-    # But ONNX needs pair number [x,y] to specify on each side explicitly
-    # For pool3d, this number should be 3
-    expand_size = 2
-
-    # The dilations should be [x, y]
-    if isinstance(dilation, int):  # x -> [x, x]
-        dilations = [dilation] * expand_size
-    else:  # already [x, y]
-        dilations = dilation
-
-    # The kernel_shape should be [x, y]
-    if isinstance(kernel_size, int):  # x -> [x, x]
-        kernel_shape = [kernel_size] * expand_size
-    else:  # assert(len(kernel_size)==2), already [x, y]
-        kernel_shape = kernel_size
+) -> TFloatOrUInt8:
+    """max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""
 
-    # The pads should be [w, x, y, z]
-    if isinstance(padding, int):  # w -> [w, w, w, w]
-        pads = [padding] * expand_size * 2
-    elif len(padding) == 1:  # [w] -> [w, w, w, w]
-        pads = padding * 4
-    elif len(padding) == 2:  # [w, x] -> [w, x, w, x]
-        pads = padding * 2
-    else:  # assert len(padding) == 4, already [w, x, y, z]
-        pads = padding
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs to specify a tuple of three ints for all sides explicitly.
+    expand_size = 3
 
-    # The strides should be [x, y]
-    if isinstance(stride, int):  # x -> [x, x]
-        strides = [stride] * expand_size
-    elif stride is None:
-        strides = kernel_shape
-    else:
-        strides = stride
-
-    return _aten_max_pool2d_with_indices_onnx(
-        self, expand_size, kernel_shape, strides, pads, dilations, ceil_mode
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
     )
 
+    return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 4)
+
 
-@torch_op("aten::max_pool2d_with_indices", private=True)
-def _aten_max_pool2d_with_indices_onnx(
+@torch_op("aten::max_pool2d_with_indices", trace_only=True)
+def aten_max_pool2d_with_indices(
     self: TFloatOrUInt8,
-    expand_size: INT64,
-    kernel_shape: Sequence[int],
-    strides: Sequence[int],
-    pads: Sequence[int],
-    dilations: Sequence[int],
-    ceil_mode: bool,
+    kernel_size: Sequence[int],
+    stride: Sequence[int] = (),
+    padding: Sequence[int] = (0, 0),
+    dilation: Sequence[int] = (1, 1),
+    ceil_mode: bool = False,
 ) -> Tuple[TFloatOrUInt8, INT64]:
-    self_rank = op.Size(op.Shape(self))
-    if self_rank == 3:  # C,H,W -> N,C,H,W and N=1
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
-
-    pool_result, indices = op.MaxPool(
-        self,
-        ceil_mode=ceil_mode,
-        dilations=dilations,
-        kernel_shape=kernel_shape,
-        pads=pads,
-        strides=strides,
-    )
+    """max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""
 
-    if self_rank == 3:
-        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs to specify a pair of numbers [x,y] on each side explicitly.
+    expand_size = 2
 
-    # Torch use relative position number for the second Channel data
-    # If align, need reduce size(Channel)
-    # e.g. [[8,3,10],[30,32,23]]-[0,18] -> [[8,3,10],[12,14,5]]
-    # 18 = H x W = 3 x 6
-    batches = op.Shape(self, start=0, end=1)
-    channels = op.Shape(self, start=1, end=2)
-    end = batches * channels
-    offset = op.Range(0, end, 1)
-    data_shape = op.Shape(self, start=2)
-    data_size = op.ReduceProd(data_shape)
-    offset = offset * data_size
-    new_shape = op.Expand(
-        op.Constant(value_ints=[1]), op.Reshape(expand_size, op.Constant(value_ints=[-1]))
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
     )
-    new_shape = op.Concat(batches, channels, new_shape, axis=0)
-    offset = op.Reshape(offset, new_shape)
-    indices = indices - offset
-    if self_rank == 3:
-        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
-    return pool_result, indices
-
 
-def aten_max_pool3d(
-    self: TensorType,
-    kernel_size: Sequence[int],
-    stride: Optional[Sequence[int]] = None,
-    padding: Sequence[int] = (0, 0, 0),
-    dilation: Sequence[int] = (1, 1, 1),
-    ceil_mode: bool = False,
-) -> TensorType:
-    """max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""
-
-    raise NotImplementedError()
+    return _aten_max_pool_with_indices_onnx(
+        self,
+        kernel_shape,
+        strides,
+        pads,
+        dilations,
+        ceil_mode,
+        3,
+        ([1] * expand_size),
+        ([0] * expand_size),
+        ([2 + i for i in range(expand_size)]),
+    )
 
 
 def aten_max_pool2d_with_indices_backward(
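
Why _aten_max_pool_onnx takes an unbatched_rank argument: PyTorch pooling accepts both batched (N,C,H,W) and unbatched (C,H,W) inputs, whereas ONNX MaxPool always expects a batch dimension, so the helper wraps unbatched inputs in a singleton batch and unwraps the result. A quick PyTorch check (illustrative only, not from the commit) of the round trip this relies on:

import torch
import torch.nn.functional as F

x = torch.randn(3, 8, 8)  # unbatched C,H,W input

# Wrap in a singleton batch, pool, then drop the batch again ...
wrapped = F.max_pool2d(x.unsqueeze(0), kernel_size=2).squeeze(0)
# ... which matches pooling the unbatched tensor directly.
direct = F.max_pool2d(x, kernel_size=2)
assert torch.equal(wrapped, direct)

aten_max_pool2d passes unbatched_rank=3 (C,H,W) and aten_max_pool3d passes 4 (C,D,H,W), which is how one private ONNX function body serves both operators.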
@@ -872,17 +825,113 @@ def aten_max_pool2d_with_indices_backward(
     raise NotImplementedError()
 
 
+@torch_op("aten::max_pool3d_with_indices", trace_only=True)
 def aten_max_pool3d_with_indices(
-    self: TensorType,
+    self: TFloatOrUInt8,
     kernel_size: Sequence[int],
-    stride: Optional[Sequence[int]] = None,
-    padding: Sequence[int] = (0, 0, 0),
-    dilation: Sequence[int] = (1, 1, 1),
+    stride: Sequence[int] = (),
+    padding: Sequence[int] = (0, 0),
+    dilation: Sequence[int] = (1, 1),
     ceil_mode: bool = False,
-) -> tuple[TensorType, TensorType]:
+) -> Tuple[TFloatOrUInt8, INT64]:
     """max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""
 
-    raise NotImplementedError()
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs to specify a tuple of three ints for all sides explicitly.
+    expand_size = 3
+
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
+    )
+
+    return _aten_max_pool_with_indices_onnx(
+        self,
+        kernel_shape,
+        strides,
+        pads,
+        dilations,
+        ceil_mode,
+        4,
+        ([1] * expand_size),
+        ([0] * expand_size),
+        ([2 + i for i in range(expand_size)]),
+    )
+
+
+@torch_op("internal::max_pool_with_indices", private=True)
+def _aten_max_pool_with_indices_onnx(
+    self: TFloatOrUInt8,
+    kernel_size: Sequence[int],
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+    ceil_mode: bool,
+    unbatched_rank: int,
+    n_dims_one: Sequence[int],
+    n_dims_zero: Sequence[int],
+    n_dims_axes: Sequence[int],
+) -> Tuple[TFloatOrUInt8, INT64]:
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == unbatched_rank:
+        self = op.Unsqueeze(self, axes=0)
+
+    pool_result, indices = op.MaxPool(
+        self,
+        ceil_mode=ceil_mode,
+        dilations=dilation,
+        kernel_shape=kernel_size,
+        pads=padding,
+        strides=stride,
+    )
+
+    # Simple but hacky way to get flattened indices values
+    # to be used to convert the indices values to non-flattened.
+    # In ONNX the indices are computed as a flattened 1-D tensor,
+    # so the values in indices are in [0, N x C x D1 x ... x Dn).
+    # To convert the indices to the same format used by PyTorch,
+    # we first execute a maxpool with a kernel and stride of 1 on the same input.
+    # This will result in a tensor of indices in which each index has its own value.
+    # Using this tensor as a reference, we extract the first index of each axis and subtract
+    # it from each index of this axis in the indices to convert.
+    # This step will result in a tensor where each dimension has values of indices within
+    # the dimension it is in.
+    # For Maxpool1d(kernel=1, stride=1, return_indices=True) with the input torch.ones(1, 2, 2),
+    # the computed indices are the following:
+    # output indices pytorch:
+    #     [[0,1],
+    #      [0,1]]
+    # output indices onnx:
+    #     [[0,1],
+    #      [2,3]]
+    # The purpose is to convert the indices from one format to the other so the results match.
+    # So flattened_indices will have the value of each index and will be equal to:
+    #     [[0,1],
+    #      [2,3]]
+    # Then call Slice to get the first value of each line (so 0 and 2).
+    # And the subtraction executes:
+    #     [[0-0,1-0],
+    #      [2-2,3-2]]
+    # So indices results in the expected output, which is:
+    #     [[0,1],
+    #      [0,1]]
+    # For more information:
+    # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+    _, flatten_indices = op.MaxPool(
+        self, dilations=dilation, kernel_shape=n_dims_one, strides=n_dims_one
+    )
+
+    ends = op.Constant(value_ints=n_dims_one)
+    starts = op.Constant(value_ints=n_dims_zero)
+    axes = op.Constant(value_ints=n_dims_axes)
+
+    delta = op.Slice(flatten_indices, axes=axes, starts=starts, ends=ends)
+    indices = op.Sub(indices, delta)
+
+    if self_rank == unbatched_rank:
+        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
+
+    return (pool_result, indices)
 
 
 def aten_max_pool3d_with_indices_backward(
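
The flattened-index conversion documented in the comment block above can be restated in NumPy; the sketch below (illustrative, not part of the commit) reproduces the MaxPool1d(kernel=1, stride=1) example on an input of shape (1, 2, 2):

import numpy as np

# Indices as ONNX MaxPool reports them: flattened over N x C x L,
# so the second channel continues counting at 2.
onnx_indices = np.array([[[0, 1],
                          [2, 3]]])

# A MaxPool with kernel 1 and stride 1 keeps every element, so its indices
# tensor is simply each position's own flattened index:
flatten_indices = np.arange(4).reshape(1, 2, 2)

# Slice(starts=[0], ends=[1], axes=[2]) extracts the first index of each row,
delta = flatten_indices[:, :, :1]  # [[[0]], [[2]]]

# and Sub rebases each row to start at 0, matching PyTorch's per-channel indices.
print(onnx_indices - delta)  # [[[0, 1], [0, 1]]]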
