22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5-
6-
75from typing import cast , Sequence , Set , Type , TypeAlias
86
97import torch .fx
1311 expand_around_channel ,
1412)
1513from executorch .backends .arm ._passes .rewrite_conv_pass import RewriteConvPass
14+ from executorch .backends .arm .tosa .specification import get_context_shape_env
1615from executorch .exir .dialects ._ops import ops as exir_ops
1716from executorch .exir .pass_base import ExportPass , PassResult
1817
1918Slices : TypeAlias = list [tuple [int , int , int ]]
19+ SymIntLike = int | torch .SymInt
2020
2121conv2d_op = exir_ops .edge .aten .convolution .default
2222max_pooling_op = exir_ops .edge .aten .max_pool2d .default
2626valid_operators = [conv2d_op , max_pooling_op , avg_pooling_op ]
2727
2828
29- def conv_remainder (input_length , pad , dilation , weight , stride ) -> int :
29+ def conv_remainder (
30+ input_length : SymIntLike , pad : int , dilation : int , weight : int , stride : int
31+ ) -> SymIntLike :
3032 """Returns the remainder of input_length; given the padding, dilation,
3133 stride, and kernel size.
3234 """
3335 return (input_length + 2 * pad - dilation * (weight - 1 ) - 1 ) % stride
3436
3537
36- def pooling_remainder (input_size , pad , kernel_size , stride ) -> int :
38+ def pooling_remainder (
39+ input_size : SymIntLike , pad : int , kernel_size : int , stride : int
40+ ) -> SymIntLike :
3741 """Returns the remainder of input_length; given the padding, stride, and
3842 kernel size.
3943 """
4044 return (input_size + 2 * pad - kernel_size ) % stride
4145
4246
47+ def _greater_than (input : SymIntLike , other : int ) -> bool | torch .SymBool :
48+ """Returns whether an int or SymInt is greater than another value."""
49+ if isinstance (input , torch .SymInt ):
50+ shape_env = get_context_shape_env ()
51+ value_ranges = shape_env .bound_sympy (input .node .expr )
52+ return value_ranges .upper > other
53+ else :
54+ return input > other
55+
56+
4357def get_slices_convolution (conv_node : torch .fx .Node ) -> Slices :
4458 slices = []
4559
@@ -59,7 +73,7 @@ def get_slices_convolution(conv_node: torch.fx.Node) -> Slices:
5973 remainder = conv_remainder (
6074 input_shape [dim ], pad , dilation , weight_shape [dim ], stride
6175 )
62- if remainder > pad :
76+ if _greater_than ( remainder , pad ) :
6377 adjustment = remainder - pad
6478 args = (dim , 0 , input_shape [dim ] - adjustment )
6579 slices .append (args )
@@ -87,7 +101,7 @@ def get_slices_pooling(pooling_node: torch.fx.Node) -> Slices:
87101 remainder = pooling_remainder (
88102 input_shape [dim ], pad_size , kernel_length , stride_length
89103 )
90- if remainder > pad_size :
104+ if _greater_than ( remainder , pad_size ) :
91105 adjustment = remainder - pad_size
92106 args = (dim , 0 , input_shape [dim ] - adjustment )
93107 slices .append (args )
0 commit comments