Skip to content

Commit 5e265e6

Browse files
Arm backend: Support SymInts in size_adjust_input (#18226)
Make sure SizeAdjustInputPass can handle inputs with symbolic shapes. Also adds tests for SizeAdjustInputPass. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 511b4fb commit 5e265e6

File tree

2 files changed

+440
-6
lines changed

2 files changed

+440
-6
lines changed

backends/arm/_passes/size_adjust_input_pass.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
6-
75
from typing import cast, Sequence, Set, Type, TypeAlias
86

97
import torch.fx
@@ -13,10 +11,12 @@
1311
expand_around_channel,
1412
)
1513
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
14+
from executorch.backends.arm.tosa.specification import get_context_shape_env
1615
from executorch.exir.dialects._ops import ops as exir_ops
1716
from executorch.exir.pass_base import ExportPass, PassResult
1817

1918
Slices: TypeAlias = list[tuple[int, int, int]]
19+
SymIntLike = int | torch.SymInt
2020

2121
conv2d_op = exir_ops.edge.aten.convolution.default
2222
max_pooling_op = exir_ops.edge.aten.max_pool2d.default
@@ -26,20 +26,34 @@
2626
valid_operators = [conv2d_op, max_pooling_op, avg_pooling_op]
2727

2828

29-
def conv_remainder(input_length, pad, dilation, weight, stride) -> int:
29+
def conv_remainder(
30+
input_length: SymIntLike, pad: int, dilation: int, weight: int, stride: int
31+
) -> SymIntLike:
3032
"""Returns the remainder of input_length; given the padding, dilation,
3133
stride, and kernel size.
3234
"""
3335
return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride
3436

3537

36-
def pooling_remainder(input_size, pad, kernel_size, stride) -> int:
38+
def pooling_remainder(
39+
input_size: SymIntLike, pad: int, kernel_size: int, stride: int
40+
) -> SymIntLike:
3741
"""Returns the remainder of input_length; given the padding, stride, and
3842
kernel size.
3943
"""
4044
return (input_size + 2 * pad - kernel_size) % stride
4145

4246

47+
def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool:
48+
"""Returns whether an int or SymInt is greater than another value."""
49+
if isinstance(input, torch.SymInt):
50+
shape_env = get_context_shape_env()
51+
value_ranges = shape_env.bound_sympy(input.node.expr)
52+
return value_ranges.upper > other
53+
else:
54+
return input > other
55+
56+
4357
def get_slices_convolution(conv_node: torch.fx.Node) -> Slices:
4458
slices = []
4559

@@ -59,7 +73,7 @@ def get_slices_convolution(conv_node: torch.fx.Node) -> Slices:
5973
remainder = conv_remainder(
6074
input_shape[dim], pad, dilation, weight_shape[dim], stride
6175
)
62-
if remainder > pad:
76+
if _greater_than(remainder, pad):
6377
adjustment = remainder - pad
6478
args = (dim, 0, input_shape[dim] - adjustment)
6579
slices.append(args)
@@ -87,7 +101,7 @@ def get_slices_pooling(pooling_node: torch.fx.Node) -> Slices:
87101
remainder = pooling_remainder(
88102
input_shape[dim], pad_size, kernel_length, stride_length
89103
)
90-
if remainder > pad_size:
104+
if _greater_than(remainder, pad_size):
91105
adjustment = remainder - pad_size
92106
args = (dim, 0, input_shape[dim] - adjustment)
93107
slices.append(args)

0 commit comments

Comments
 (0)