feat: dynamic shape support for aten.select.int #2990

Merged · 3 commits · Jul 31, 2024
Changes from 2 commits
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -794,7 +794,7 @@ def aten_ops_scatter(
)


-@dynamo_tensorrt_converter(torch.ops.aten.select.int)
+@dynamo_tensorrt_converter(torch.ops.aten.select.int, supports_dynamic_shapes=True)
def aten_ops_select(
    ctx: ConversionContext,
    target: Target,
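For context, the converter is a thin wrapper that forwards to the impl changed below. A sketch of its full shape, with the signature following the truncated hunk above and the body assumed from the call pattern this review discusses:

# Types (ConversionContext, Target, Argument, TRTTensor, SourceIR) and impl
# come from the module's existing imports.
@dynamo_tensorrt_converter(torch.ops.aten.select.int, supports_dynamic_shapes=True)
def aten_ops_select(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # args follows the aten schema: (input, dim, index)
    return impl.select.select(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2])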
16 changes: 3 additions & 13 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -47,26 +47,16 @@ def select(
    if dynamic_shape:
        # Check whether slice target dim is dynamic shape dim
        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
    index = index

    if index >= input.shape[dim]:
        raise RuntimeError(
            f"cannot have index greater than the dimension length! {input.shape[dim]}"
        )
    output_shape = list(input.shape)
    output_shape[dim] = 1
    if dynamic_shape > 0:
        output_shape = get_shape_with_dynamic_shape(
            ctx, target, source_ir, name, output_shape, input
        )

Collaborator:
If the above assert()s are removed and a fully dynamic shape is used (e.g. (-1, -1, -1)), the test works.
I'm wondering whether select on a dynamic dim can be supported.

Collaborator Author:

I ran the test you mentioned and confirmed that using (-1, -1, -1) also passes the test cases successfully.

However, I didn't remove the assert()s, because the test cases fail when the index is out of range, i.e. when the index is larger than the corresponding dimension in the dynamic input shape.

            (
                "success case",
                (1, 1, 1),
                (2, 2, 2),
                (3, 3, 3),
                torch.float,
                0,
                1,
            ),
            (
                "fail case",
                (1, 1, 1),
                (2, 2, 2),
                (3, 3, 3),
                torch.float,
                0,
                3,
            ),

It seems that normalizing the index the way the dim is normalized (dim = get_positive_dim(cast(int, dim), ranks)) when the index is greater than the size might solve the issue. Do you have an example that handles it this way?
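(For reference, get_positive_dim is roughly the following; a sketch, not the repo's exact code.)

def get_positive_dim(dim: int, dim_size: int) -> int:
    # Normalize a possibly-negative axis (e.g. -1) into the range [0, dim_size).
    if dim < 0:
        return dim % dim_size
    return dim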

Collaborator:

I think we cannot check for an invalid index on a dynamic dim; the error will happen at runtime.
Maybe we can check the index only for static dims:
if DYNAMIC_DIM != input.shape[dim] and index >= input.shape[dim]:
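A minimal sketch of that suggestion, assuming DYNAMIC_DIM is the -1 sentinel that marks a dynamic dimension in compile-time shapes:

from typing import Sequence

DYNAMIC_DIM = -1  # assumed sentinel for a dynamic dimension

def validate_index(input_shape: Sequence[int], dim: int, index: int) -> None:
    # Validate only when the selected dim is static; a dynamic dim's size is
    # unknown until runtime, so an invalid index can only surface then.
    if input_shape[dim] != DYNAMIC_DIM and index >= input_shape[dim]:
        raise RuntimeError(
            f"cannot have index {index} greater than the dimension length {input_shape[dim]}!"
        )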

Collaborator Author:

I also tested by removing the assert() and raise RuntimeError() statements and running with (-1, -1, -1).

The test case mentioned as a success case above passes, but the one mentioned as a fail case fails. The reason is that in the fail case the size of the 0th dimension is 3, but the index selects a position beyond it (index=3, i.e. the 4th position). The current converter does not handle the case where the index is larger than the size of the dimension specified by dim, as in the 'fail case'. (Note: it might be possible to handle this using the slice layer with lt (less than) and div operations, but currently it is handled with an assert statement.)

Therefore, I left the assert() and raise RuntimeError() statements in place.

Collaborator Author (@chohk88, Jul 16, 2024):

Oh! I misunderstood. If the index is larger than the size of the dimension specified by dim, it won't work in PyTorch either, so we don't need to handle this case, and we don't need test cases like the 'fail case' above. This means dynamic shape is supported for all dimensions.

PyTorch example of correct usage of select.int: [screenshot]

PyTorch example of select.int raising an IndexError: [screenshot]
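A minimal sketch reproducing what the two screenshots showed:

import torch

x = torch.randn(3, 3, 3)

# Correct usage: the index must be smaller than the size of the selected dim.
y = torch.ops.aten.select.int(x, 0, 2)
print(y.shape)  # torch.Size([3, 3])

# Out-of-range index: PyTorch itself raises, so the converter never sees it.
torch.ops.aten.select.int(x, 0, 3)
# IndexError: select(): index 3 out of range for tensor of size [3, 3, 3] at dimension 0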

Collaborator Author:

@keehyuna Thanks for the suggestion!

Since PyTorch already raises an error when the index exceeds the input size, there's no need to check this here, so I removed those checks. Additionally, following your suggestion, I've added test cases to fully cover dynamic shapes.

    index_value = np.array(index, dtype=np.int32)
Collaborator:

Can index itself be dynamic (an ITensor)?

Collaborator Author:

Based on the PyTorch docs and the op schema, the index is always a single integer and cannot be a list or tuple of integers. The indices_tensor created from index / index_value is therefore always a scalar, so it cannot be dynamic.
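A quick way to confirm this from the registered schema (printed output is approximate and varies by PyTorch version):

import torch

print(torch.ops.aten.select.int._schema)
# select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)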

-    indices_tensor = ctx.net.add_constant(
-        index_value.shape, to_numpy(index_value)
-    ).get_output(0)
+    indices_tensor = ctx.net.add_constant(index_value.shape, index_value).get_output(0)
Collaborator:

Can you use a get_trt_tensor call here?

Collaborator Author:

Thanks for the suggestion!

dim and index are now int (not Shape), and indices_tensor is now created with get_trt_tensor.

    layer = ctx.net.add_gather(input, indices_tensor, dim)
    out = layer.get_output(0)
    if len(out.shape) != 1:
        layer = ctx.net.add_shuffle(out)

    return layer.get_output(0)
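Putting the thread together, a sketch of what the impl looks like after these review rounds; helper names follow the discussion above, the tensor name suffix is illustrative, and the exact final code may differ:

from typing import Optional

import numpy as np

# ConversionContext, Target, SourceIR, TRTTensor, get_positive_dim and
# get_trt_tensor come from the module's existing imports.

def select(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: int,
) -> TRTTensor:
    ranks = len(input.shape)
    dim = get_positive_dim(dim, ranks)

    # Scalar (0-d) index constant, materialized as an ITensor.
    index_value = np.array(index, dtype=np.int32)
    indices_tensor = get_trt_tensor(ctx, index_value, f"{name}_index")

    # IGatherLayer output rank = rank(input) + rank(indices) - 1, so a 0-d
    # index drops the gathered axis, matching aten.select.int semantics.
    layer = ctx.net.add_gather(input, indices_tensor, dim)
    return layer.get_output(0)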


57 changes: 44 additions & 13 deletions tests/py/dynamo/conversion/test_select_aten.py
@@ -1,4 +1,5 @@
import torch
+import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
@@ -13,7 +14,7 @@ class TestSelectConverterOne(DispatchTestCase):
        ]
    )
    def test_select(self, _, dim, index):
-        class TestModule(torch.nn.Module):
+        class select(nn.Module):
            def __init__(self):
                super().__init__()

@@ -22,7 +23,7 @@ def forward(self, input):

        input = [torch.randn(1, 2)]
        self.run_test(
-            TestModule(),
+            select(),
            input,
        )

@@ -34,7 +35,7 @@ class TestSelectConverterTwo(DispatchTestCase):
        ]
    )
    def test_select(self, _, dim, index):
-        class TestModule(torch.nn.Module):
+        class select(nn.Module):
            def __init__(self):
                super().__init__()

@@ -43,33 +44,63 @@ def forward(self, input):

        input = [torch.randn(4, 4, 4, 4)]
        self.run_test(
-            TestModule(),
+            select(),
            input,
        )


-class TestSelectConverterWithDynamicShape(DispatchTestCase):
+class TestSelectConverterDynamicShape(DispatchTestCase):
    @parameterized.expand(
        [
-            ("select_dim_index", 1, 0),
+            (
+                "select_dim_index",
+                (1, 3, 3),
+                (2, 3, 3),
+                (3, 3, 3),
+                torch.int32,
+                1,
+                0,
+            ),
+            (
+                "select_dim_index",
+                (1, 1, 3),
+                (2, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                2,
+                0,
+            ),
+            (
+                "select_dim_index",
+                (3, 1, 1),
+                (3, 2, 2),
+                (3, 3, 3),
+                torch.float,
+                0,
+                2,
+            ),
        ]
    )
-    def test_select_with_dynamic_shape(self, _, dim, index):
-        class TestModule(torch.nn.Module):
+    def test_dynamic_shape_select(
+        self, _, min_shape, opt_shape, max_shape, type, dim, index
+    ):
+        class select(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                return torch.ops.aten.select.int(input, dim, index)

-        input_spec = [
+        input_specs = [
            Input(
-                shape=(-1, 3, 3),
-                dtype=torch.float32,
-                shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
            ),
        ]
-        self.run_test_with_dynamic_shape(TestModule(), input_spec)
+        self.run_test_with_dynamic_shape(select(), input_specs)


if __name__ == "__main__":
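For reference, a standalone sketch of building one of the dynamic-shape Input specs exercised above (values taken from the first parameterization):

import torch
from torch_tensorrt import Input

spec = Input(
    min_shape=(1, 3, 3),  # smallest shape the engine must accept
    opt_shape=(2, 3, 3),  # shape TensorRT optimizes for
    max_shape=(3, 3, 3),  # largest shape the engine must accept
    dtype=torch.int32,
)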