Skip to content

feat: dynamic shape support for adaptive_avg_poolNd (partially) #3021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2593,7 +2593,9 @@ def aten_ops_avg_pool(
)


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool1d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand All @@ -2616,10 +2618,18 @@ def aten_ops_adaptive_avg_pool1d(
)


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool2d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten._adaptive_avg_pool2d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool3d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten._adaptive_avg_pool3d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand Down
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
"""Calculate the end index of each pooling window"""
return math.ceil((float(idx + 1) * float(in_dim)) / out_dim)

if has_dynamic_shape(input.shape):
assert (
input.shape[-1] != -1 and input.shape[-2] != -1
), "Last 2 dimensions can't be dynamic for adaptive_avg_pool1d."

in_dim = input.shape[-1]
out_dim = output_size if isinstance(output_size, int) else output_size[0]
output_list = []
Expand Down Expand Up @@ -182,6 +187,18 @@ def adaptive_avg_poolNd(
input: TRTTensor,
output_size: Sequence[int],
) -> TRTTensor:
if has_dynamic_shape(input.shape):
if len(output_size) == 2: # adaptive_avg_pool2d
assert (
input.shape[-1] != -1 and input.shape[-2] != -1
), "Last 2 dimensions can't be dynamic for adaptive_avg_pool2d."
elif len(output_size) == 3: # adaptive_avg_pool3d
assert (
input.shape[-1] != -1
and input.shape[-2] != -1
and input.shape[-3] != -1
), "Last 3 dimensions can't be dynamic for adaptive_avg_pool3d."

input_shape = input.shape
input_rank = len(input_shape)
output_rank = len(output_size)
Expand Down
110 changes: 80 additions & 30 deletions tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,40 @@ def forward(self, x):
enable_passes=True,
)

@parameterized.expand(
[
(
(1, 3, 3),
(2, 3, 3),
(3, 3, 3),
torch.float,
(2,),
),
]
)
def test_dynamic_shape_adaptive_pool1d(
self,
min_shape,
opt_shape,
max_shape,
type,
output_size,
):
class adaptive_pool1d(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.adaptive_avg_pool1d.default(x, output_size)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]

self.run_test_with_dynamic_shape(adaptive_pool1d(), input_specs)

@parameterized.expand(
[
# 3d input
Expand Down Expand Up @@ -159,29 +193,37 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2),),
(
(1, 1, 3, 3),
(2, 2, 3, 3),
(3, 3, 3, 3),
torch.float,
(2, 2),
),
]
)
def test_adaptive_avg_pool2d_dynamic(self, output_size):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def test_dynamic_shape_adaptive_pool2d(
self,
min_shape,
opt_shape,
max_shape,
type,
output_size,
):
class adaptive_pool2d(torch.nn.Module):
def forward(self, x):
out = torch.ops.aten.adaptive_avg_pool2d.default(x, output_size)
return out
return torch.ops.aten.adaptive_avg_pool2d.default(x, output_size)

input_specs = [
Input(
shape=(-1, 2, 3, 2),
dtype=torch.float32,
shape_ranges=[((1, 2, 3, 2), (3, 2, 3, 2), (10, 2, 3, 2))],
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
TestModule(),
input_specs,
)

self.run_test_with_dynamic_shape(adaptive_pool2d(), input_specs)

@parameterized.expand(
[
Expand Down Expand Up @@ -271,29 +313,37 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2, 3),),
(
(1, 1, 3, 3, 3),
(2, 2, 3, 3, 3),
(3, 3, 3, 3, 3),
torch.float,
(2, 2, 2),
),
]
)
def test_adaptive_avg_pool3d_dynamic(self, output_size):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def test_dynamic_shape_adaptive_pool3d(
self,
min_shape,
opt_shape,
max_shape,
type,
output_size,
):
class adaptive_pool3d(torch.nn.Module):
def forward(self, x):
out = torch.ops.aten.adaptive_avg_pool3d.default(x, output_size)
return out
return torch.ops.aten.adaptive_avg_pool3d.default(x, output_size)

input_specs = [
Input(
shape=(-1, 2, 3, 1, 4),
dtype=torch.float32,
shape_ranges=[((1, 2, 3, 1, 4), (3, 2, 3, 1, 4), (10, 2, 3, 1, 4))],
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
TestModule(),
input_specs,
)

self.run_test_with_dynamic_shape(adaptive_pool3d(), input_specs)


if __name__ == "__main__":
Expand Down
Loading