Skip to content

feat: support dynamic shapes for avg poolNd #3010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2546,13 +2546,19 @@ def avg_pool_param_validator(pool_node: Node) -> bool:

# Note: AvgPool1d uses avg_pool2d as it converts to 2D first.
@dynamo_tensorrt_converter(
torch.ops.aten.avg_pool1d.default, capability_validator=avg_pool_param_validator
torch.ops.aten.avg_pool1d.default,
capability_validator=avg_pool_param_validator,
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten.avg_pool2d.default, capability_validator=avg_pool_param_validator
torch.ops.aten.avg_pool2d.default,
capability_validator=avg_pool_param_validator,
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten.avg_pool3d.default, capability_validator=avg_pool_param_validator
torch.ops.aten.avg_pool3d.default,
capability_validator=avg_pool_param_validator,
supports_dynamic_shapes=True,
)
def aten_ops_avg_pool(
ctx: ConversionContext,
Expand Down
3 changes: 0 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ def avg_poolNd(
count_include_pad: bool = True,
divisor_override: Optional[int] = None,
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling."

if ceil_mode is not False:
raise RuntimeError("ceil_mode is not yet supported!")

Expand Down
144 changes: 144 additions & 0 deletions tests/py/dynamo/conversion/test_pool_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,150 @@ def forward(self, x):
inputs = [torch.randn(1, 3, 32, 32, 32)]
self.run_test(TestModule(), inputs, use_dynamo_tracer=True)

@parameterized.expand(
[
(
(1, 1, 1),
(2, 2, 2),
(3, 3, 3),
torch.float,
(3,),
(1,),
(1,),
),
]
)
def test_dynamic_shape_pool1d(
self,
min_shape,
opt_shape,
max_shape,
type,
kernel_size,
stride=1,
padding=0,
ceil_mode=False,
count_include_pad=True,
):
class pool1d(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.avg_pool1d.default(
x, kernel_size, stride, padding, ceil_mode, count_include_pad
)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]

self.run_test_with_dynamic_shape(pool1d(), input_specs, use_dynamo_tracer=True)

@parameterized.expand(
[
(
(1, 1, 1, 1),
(2, 2, 2, 2),
(3, 3, 3, 3),
torch.float,
3,
1,
1,
),
(
(1, 1, 1, 1),
(2, 2, 2, 2),
(3, 3, 3, 3),
torch.float,
(3, 3),
(1, 1),
(1, 1),
),
]
)
def test_dynamic_shape_pool2d(
self,
min_shape,
opt_shape,
max_shape,
type,
kernel_size,
stride=1,
padding=0,
ceil_mode=False,
count_include_pad=True,
):
class pool2d(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.avg_pool2d.default(
x, kernel_size, stride, padding, ceil_mode, count_include_pad
)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]

self.run_test_with_dynamic_shape(pool2d(), input_specs, use_dynamo_tracer=True)

@parameterized.expand(
[
(
(1, 1, 1, 1, 1),
(2, 2, 2, 2, 2),
(3, 3, 3, 3, 3),
torch.float,
2,
1,
1,
),
(
(1, 1, 1, 1, 1),
(2, 2, 2, 2, 2),
(3, 3, 3, 3, 3),
torch.float,
(2, 2, 2),
(1, 1, 1),
(1, 1, 1),
),
]
)
def test_dynamic_shape_pool3d(
self,
min_shape,
opt_shape,
max_shape,
type,
kernel_size,
stride=1,
padding=0,
ceil_mode=False,
count_include_pad=True,
):
class pool3d(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.avg_pool3d.default(
x, kernel_size, stride, padding, ceil_mode, count_include_pad
)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]

self.run_test_with_dynamic_shape(pool3d(), input_specs, use_dynamo_tracer=True)

@parameterized.expand(
[
(3, 1, 0),
Expand Down
Loading